
[multi-tenant-lora] Expose per-LoRA pause state over HTTP#1666

Open
erictang000 wants to merge 1 commit into NovaSky-AI:main from erictang000:multi-tenant-async-pause-server

Conversation

@erictang000
Collaborator

Summary

Exposes the per-LoRA pause state from RemoteInferenceClient.sample_with_retry over HTTP so out-of-process callers (the Tinker SkyRLTrainInferenceForwardingClient from #1638) can participate in the same pause/resume cycle as in-process callers.

This is the first of two PRs that finish the fully-async multi-tenant Tinker path. A follow-up PR will use these endpoints to add sample_with_retry-equivalent abort recovery to the Tinker forwarding client and add Tinker-side tests that mirror test_pause_lora.py.

Background

Today RemoteInferenceClient.sample_with_retry (added in #1657) uses an in-process _lora_pause_events: dict[str, asyncio.Event] to gate retries during a per-LoRA pause window (see the sketch after this list). The Tinker API process runs in a separate Python process and cannot observe that dict, so today:

  1. An in-flight Tinker asample for a paused LoRA sees finish_reason="abort" in the vLLM response but maps it to stop_reason="length" (silent corruption; the fix lands in the follow-up PR), and
  2. A fresh Tinker asample for the same LoRA submitted during the pause window races with load_lora_adapter and can observe torn adapter weights.
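
For context, a minimal sketch of that in-process gate, assuming the shapes from #1657; retry/backoff details are omitted, and _sample_once is a hypothetical stand-in for the real engine call:

import asyncio

_lora_pause_events: dict[str, asyncio.Event] = {}

def _get_pause_event(lora_name: str) -> asyncio.Event:
    ev = _lora_pause_events.get(lora_name)
    if ev is None:
        ev = asyncio.Event()
        ev.set()  # set == unpaused; a never-paused LoRA starts runnable
        _lora_pause_events[lora_name] = ev
    return ev

async def sample_with_retry(lora_name: str, prompt: str):
    while True:
        # pause_generation clears the event; resume_generation sets it.
        await _get_pause_event(lora_name).wait()
        out = await _sample_once(lora_name, prompt)  # hypothetical engine call
        if out.finish_reason != "abort":
            return out
        # Aborted inside a pause window: loop back and wait for the resume.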

Changes

skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py:

  • app.state.paused_loras: dict[str, asyncio.Event] initialised in _add_custom_endpoints. Mirrors the in-process gate.
  • Existing POST /skyrl/v1/abort_lora_requests (#1657) now also clears the per-LoRA event before the engine.abort() fan-out.
  • New POST /skyrl/v1/resume_lora_requests {lora_name}: sets the event.
  • New POST /skyrl/v1/wait_lora_unpaused {lora_name, timeout_s?}: long-polls; returns {paused: false} once the event is set or {paused: true} after timeout so the caller can loop.
  • New pure-ASGI submission-gate middleware: pending /v1/completions and /v1/chat/completions for a currently-paused LoRA block until resume. Closes the race in (2) above. Pure ASGI (not @app.middleware("http")) because BaseHTTPMiddleware buffers streaming responses and would regress vLLM's SSE completions path. A combined sketch of the endpoints and gate follows this list.
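
A hedged sketch of the server-side pieces above, assuming a FastAPI app as in vLLM's api_server; handler names, request shapes, and the _get_pause_event helper are illustrative rather than the exact code in vllm_server_actor.py:

import asyncio
import json

from fastapi import FastAPI, Request

app = FastAPI()
app.state.paused_loras = {}  # dict[str, asyncio.Event]; a set event means "unpaused"

def _get_pause_event(lora_name: str) -> asyncio.Event:
    ev = app.state.paused_loras.get(lora_name)
    if ev is None:
        ev = asyncio.Event()
        ev.set()  # never-paused LoRAs are runnable by default
        app.state.paused_loras[lora_name] = ev
    return ev

@app.post("/skyrl/v1/abort_lora_requests")
async def abort_lora_requests(request: Request):
    body = await request.json()
    _get_pause_event(body["lora_name"]).clear()  # close the gate first
    # ... existing engine.abort() fan-out from #1657 goes here ...
    return {"paused": True}

@app.post("/skyrl/v1/resume_lora_requests")
async def resume_lora_requests(request: Request):
    body = await request.json()
    _get_pause_event(body["lora_name"]).set()
    return {"paused": False}

@app.post("/skyrl/v1/wait_lora_unpaused")
async def wait_lora_unpaused(request: Request):
    body = await request.json()
    ev = _get_pause_event(body["lora_name"])
    try:
        await asyncio.wait_for(ev.wait(), timeout=float(body.get("timeout_s", 60.0)))
        return {"paused": False}
    except asyncio.TimeoutError:
        return {"paused": True}  # caller loops and long-polls again

class LoraSubmissionGate:
    # Pure ASGI on purpose: BaseHTTPMiddleware would buffer streaming (SSE) responses.
    def __init__(self, app):
        self.next_app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http" or scope["path"] not in (
            "/v1/completions",
            "/v1/chat/completions",
        ):
            return await self.next_app(scope, receive, send)
        # Buffer the request body so the "model" field can be inspected, then replayed.
        chunks, more_body = [], True
        while more_body:
            msg = await receive()
            chunks.append(msg.get("body", b""))
            more_body = msg.get("more_body", False)
        body = b"".join(chunks)
        try:
            lora_name = json.loads(body).get("model", "")
        except (json.JSONDecodeError, AttributeError):
            lora_name = ""
        paused = scope["app"].state.paused_loras.get(lora_name)
        if paused is not None:
            await paused.wait()  # hold the submission until resume sets the event

        async def replay_receive():
            return {"type": "http.request", "body": body, "more_body": False}

        await self.next_app(scope, replay_receive, send)

app.add_middleware(LoraSubmissionGate)

Note the event polarity throughout: abort clears, resume sets, and waiters block only while the event is unset.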

skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py:

  • resume_generation(lora_name=X) now also POSTs /skyrl/v1/resume_lora_requests to all server_urls so the server-side flag matches the in-process gate. The local set still happens first, so in-process retries don't wait on the HTTP round-trip; brief inconsistency between the two gates just causes one extra retry-loop iteration. A sketch of this fan-out follows.
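
A sketch of that fan-out, assuming aiohttp for the HTTP client and written as a free function over the client's server_urls and _lora_pause_events for brevity:

import asyncio
import aiohttp

async def resume_generation(
    server_urls: list[str],
    lora_pause_events: dict[str, asyncio.Event],
    lora_name: str,
) -> None:
    # Local gate first: in-process retries proceed without waiting on HTTP;
    # a brief mismatch between the two gates costs at most one extra retry iteration.
    lora_pause_events.setdefault(lora_name, asyncio.Event()).set()
    async with aiohttp.ClientSession() as session:
        for url in server_urls:
            async with session.post(
                f"{url}/skyrl/v1/resume_lora_requests",
                json={"lora_name": lora_name},
            ) as resp:
                resp.raise_for_status()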

All new endpoints carry the same TRANSIENT marker as the existing abort endpoint: delete the stack when vLLM ships native per-LoRA pause.
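
For reference, this is roughly how an out-of-process caller (e.g. the follow-up PR's forwarding client) might consume the long-poll endpoint; httpx and the function shape are assumptions, not code from this PR:

import httpx

async def wait_until_unpaused(base_url: str, lora_name: str) -> None:
    async with httpx.AsyncClient() as client:
        while True:
            resp = await client.post(
                f"{base_url}/skyrl/v1/wait_lora_unpaused",
                json={"lora_name": lora_name, "timeout_s": 30.0},
                timeout=60.0,  # HTTP timeout must outlast the server's long-poll window
            )
            if resp.json().get("paused") is False:
                return  # event is set; safe to resubmit the aborted sample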

Test plan

  • ruff + black clean
  • tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_pause_lora.py (existing pause/resume suite) passes — every test exercises pause_generation and resume_generation, so my changes are validated indirectly. New endpoints (/skyrl/v1/wait_lora_unpaused, submission gate) are exercised by the follow-up PR's Tinker-side test.

🤖 Generated with Claude Code

Adds three control-plane endpoints on the vLLM server actor and a
submission-gate middleware so out-of-process callers (e.g. the Tinker
SkyRLTrainInferenceForwardingClient) can observe the same pause window
as in-process RemoteInferenceClient.sample_with_retry callers.

- POST /skyrl/v1/abort_lora_requests: now also clears a per-LoRA
  asyncio.Event in app.state.paused_loras before the engine.abort()
  fan-out.
- POST /skyrl/v1/resume_lora_requests: sets the event.
- POST /skyrl/v1/wait_lora_unpaused: long-polls until the event is set
  or the caller-supplied timeout elapses.
- Pure-ASGI middleware blocks fresh /v1/completions and
  /v1/chat/completions for a paused LoRA until resume, closing the race
  between abort_lora_requests and load_lora_adapter where a new request
  could observe torn adapter weights.

RemoteInferenceClient.resume_generation now also POSTs the new
/skyrl/v1/resume_lora_requests so the server-side flag matches the
in-process gate. Existing test_pause_lora.py callers don't observe a
return-value contract change.

All endpoints carry the same TRANSIENT marker as the existing abort
endpoint: delete the stack when vLLM ships native per-LoRA pause.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements a per-LoRA pause and resume mechanism to prevent request processing during adapter weight updates. It introduces an ASGI middleware to gate completion requests and adds new endpoints for managing and waiting on LoRA pause states. The review feedback suggests improving the middleware's resilience by handling client disconnections during request buffering and adding validation for the timeout parameter in the long-poll endpoint to avoid potential runtime errors.

Comment on lines +413 to +417
while more_body:
    msg = await receive()
    chunks.append(msg.get("body", b""))
    more_body = msg.get("more_body", False)
body = b"".join(chunks)
Severity: medium

The middleware buffers the entire request body in memory to inspect the model field. While this is necessary for peeking at the JSON, it does not handle http.disconnect messages during the buffering loop. If a client disconnects while the body is being uploaded, receive() will return a disconnect message, which this loop will treat as a body chunk with more_body=False, potentially leading to a JSONDecodeError or processing a partial request.

Suggested change
-while more_body:
-    msg = await receive()
-    chunks.append(msg.get("body", b""))
-    more_body = msg.get("more_body", False)
-body = b"".join(chunks)
+while more_body:
+    msg = await receive()
+    if msg.get("type") == "http.disconnect":
+        # Hand the disconnect to the inner app rather than parsing a partial body.
+        async def replay_disconnect():
+            return msg
+        await self.inner_app(scope, replay_disconnect, send)
+        return
+    chunks.append(msg.get("body", b""))
+    more_body = msg.get("more_body", False)
+body = b"".join(chunks)

Comment on lines +548 to +551
timeout_s = body.get("timeout_s", 60.0)
ev = _get_pause_event(lora_name)
try:
    await asyncio.wait_for(ev.wait(), timeout=float(timeout_s))

Severity: medium

The timeout_s parameter is converted to a float without validation. If a client provides a non-numeric value, this will raise a ValueError and return a 500 error. It's safer to validate the input or provide a fallback.

Suggested change
-timeout_s = body.get("timeout_s", 60.0)
-ev = _get_pause_event(lora_name)
-try:
-    await asyncio.wait_for(ev.wait(), timeout=float(timeout_s))
+try:
+    timeout_s = float(body.get("timeout_s", 60.0))
+except (ValueError, TypeError):
+    timeout_s = 60.0
+ev = _get_pause_event(lora_name)
+try:
+    await asyncio.wait_for(ev.wait(), timeout=timeout_s)
