[multi-tenant-lora] Expose per-LoRA pause state over HTTP #1666

erictang000 wants to merge 1 commit into
Conversation
Adds three control-plane endpoints on the vLLM server actor and a submission-gate middleware so out-of-process callers (e.g. the Tinker `SkyRLTrainInferenceForwardingClient`) can observe the same pause window as in-process `RemoteInferenceClient.sample_with_retry` callers.

- `POST /skyrl/v1/abort_lora_requests`: now also clears a per-LoRA `asyncio.Event` in `app.state.paused_loras` before the `engine.abort()` fan-out.
- `POST /skyrl/v1/resume_lora_requests`: sets the event.
- `POST /skyrl/v1/wait_lora_unpaused`: long-polls until the event is set or the caller-supplied timeout elapses.
- Pure-ASGI middleware blocks fresh `/v1/completions` and `/v1/chat/completions` requests for a paused LoRA until resume, closing the race between `abort_lora_requests` and `load_lora_adapter` where a new request could observe torn adapter weights.

`RemoteInferenceClient.resume_generation` now also POSTs the new `/skyrl/v1/resume_lora_requests` so the server-side flag matches the in-process gate. Existing `test_pause_lora.py` callers don't observe a return-value contract change.

All endpoints carry the same TRANSIENT marker as the existing abort endpoint: delete the stack when vLLM ships native per-LoRA pause.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
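As a rough illustration of how the three endpoints and the shared event dict could fit together, here is a minimal sketch in FastAPI style. The `_add_pause_endpoints` and `_get_pause_event` names, the exact request parsing, and the `engine` parameter are assumptions for the example, not the PR's actual code; note the wait handler deliberately mirrors the unvalidated `float(timeout_s)` conversion discussed in the review below.

```python
import asyncio
from fastapi import FastAPI, Request


def _add_pause_endpoints(app: FastAPI, engine) -> None:
    # One event per LoRA: set = running, cleared = paused.
    app.state.paused_loras: dict[str, asyncio.Event] = {}

    def _get_pause_event(lora_name: str) -> asyncio.Event:
        ev = app.state.paused_loras.get(lora_name)
        if ev is None:
            ev = asyncio.Event()
            ev.set()  # LoRAs start unpaused
            app.state.paused_loras[lora_name] = ev
        return ev

    @app.post("/skyrl/v1/abort_lora_requests")
    async def abort_lora_requests(request: Request):
        body = await request.json()
        _get_pause_event(body["lora_name"]).clear()  # pause before aborting
        # ... existing engine.abort() fan-out for in-flight requests ...
        return {"paused": True}

    @app.post("/skyrl/v1/resume_lora_requests")
    async def resume_lora_requests(request: Request):
        body = await request.json()
        _get_pause_event(body["lora_name"]).set()
        return {"paused": False}

    @app.post("/skyrl/v1/wait_lora_unpaused")
    async def wait_lora_unpaused(request: Request):
        body = await request.json()
        timeout_s = body.get("timeout_s", 60.0)
        ev = _get_pause_event(body["lora_name"])
        try:
            await asyncio.wait_for(ev.wait(), timeout=float(timeout_s))
            return {"paused": False}
        except asyncio.TimeoutError:
            return {"paused": True}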
Code Review
This pull request implements a per-LoRA pause and resume mechanism to prevent request processing during adapter weight updates. It introduces an ASGI middleware to gate completion requests and adds new endpoints for managing and waiting on LoRA pause states. The review feedback suggests improving the middleware's resilience by handling client disconnections during request buffering and adding validation for the timeout parameter in the long-poll endpoint to avoid potential runtime errors.
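For context on the snippets below, a pure-ASGI gate of the kind under review could take roughly the following shape. This is a minimal sketch: the `LoraSubmissionGate` name, constructor signature, and replay helper are illustrative, not the PR's code. The body-buffering loop in the middle is the exact pattern the first review comment critiques.

```python
import asyncio
import json

GATED_PATHS = {"/v1/completions", "/v1/chat/completions"}


class LoraSubmissionGate:
    """Pure-ASGI middleware: hold a paused LoRA's fresh completions until resume."""

    def __init__(self, inner_app, paused_loras: dict[str, asyncio.Event]):
        self.inner_app = inner_app
        self.paused_loras = paused_loras  # set = running, cleared = paused

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http" or scope["path"] not in GATED_PATHS:
            await self.inner_app(scope, receive, send)
            return

        # Buffer the request body so the JSON "model" field can be inspected.
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            msg = await receive()
            chunks.append(msg.get("body", b""))
            more_body = msg.get("more_body", False)
        body = b"".join(chunks)

        try:
            lora_name = json.loads(body).get("model")
        except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
            lora_name = None

        # Wait out the pause window before letting the request through.
        ev = self.paused_loras.get(lora_name) if lora_name else None
        if ev is not None:
            await ev.wait()

        # Replay the buffered body to the inner app.
        replayed = False

        async def replay_receive():
            nonlocal replayed
            if not replayed:
                replayed = True
                return {"type": "http.request", "body": body, "more_body": False}
            return await receive()

        await self.inner_app(scope, replay_receive, send)
```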
```python
while more_body:
    msg = await receive()
    chunks.append(msg.get("body", b""))
    more_body = msg.get("more_body", False)
body = b"".join(chunks)
```
The middleware buffers the entire request body in memory to inspect the model field. While this is necessary for peeking at the JSON, it does not handle http.disconnect messages during the buffering loop. If a client disconnects while the body is being uploaded, receive() will return a disconnect message, which this loop will treat as a body chunk with more_body=False, potentially leading to a JSONDecodeError or processing a partial request.
Suggested change:

```python
while more_body:
    msg = await receive()
    if msg.get("type") == "http.disconnect":
        # ASGI receive callables must be async, so wrap the disconnect
        # message before handing it to the inner app, then stop buffering.
        async def _receive_disconnect():
            return msg
        await self.inner_app(scope, _receive_disconnect, send)
        return
    chunks.append(msg.get("body", b""))
    more_body = msg.get("more_body", False)
body = b"".join(chunks)
```
```python
timeout_s = body.get("timeout_s", 60.0)
ev = _get_pause_event(lora_name)
try:
    await asyncio.wait_for(ev.wait(), timeout=float(timeout_s))
```
The timeout_s parameter is converted to a float without validation. If a client provides a non-numeric value, this will raise a ValueError and return a 500 error. It's safer to validate the input or provide a fallback.
Suggested change:

```python
try:
    timeout_s = float(body.get("timeout_s", 60.0))
except (ValueError, TypeError):
    timeout_s = 60.0
ev = _get_pause_event(lora_name)
try:
    await asyncio.wait_for(ev.wait(), timeout=timeout_s)
```
Summary
Exposes the per-LoRA pause state from `RemoteInferenceClient.sample_with_retry` over HTTP so out-of-process callers (the Tinker `SkyRLTrainInferenceForwardingClient` from #1638) can participate in the same pause/resume cycle as in-process callers.

This is the first of two PRs that finish the fully-async multi-tenant Tinker path. The follow-up PR will use these endpoints to add `sample_with_retry`-equivalent abort recovery to the Tinker forwarding client and add Tinker-side tests that mirror `test_pause_lora.py`.

Background
Today `RemoteInferenceClient.sample_with_retry` (added in #1657) uses an in-process `_lora_pause_events: dict[str, asyncio.Event]` to gate retries during a per-LoRA pause window. The Tinker API process runs in a separate Python process and cannot observe that dict, so an in-flight Tinker `asample` for a paused LoRA today:

1. receives `finish_reason="abort"` in the vLLM response but maps it to `stop_reason="length"` (silent corruption — fix lands in the follow-up PR), and
2. can race with `load_lora_adapter` and can observe torn adapter weights.

Changes
`skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py`:

- `app.state.paused_loras: dict[str, asyncio.Event]`, initialised in `_add_custom_endpoints`. Mirrors the in-process gate.
- `POST /skyrl/v1/abort_lora_requests` ([lora][tinker] Add pause and resume for multi-tenant lora #1657) now also clears the per-LoRA event before the `engine.abort()` fan-out.
- `POST /skyrl/v1/resume_lora_requests {lora_name}`: sets the event.
- `POST /skyrl/v1/wait_lora_unpaused {lora_name, timeout_s?}`: long-polls; returns `{paused: false}` once the event is set or `{paused: true}` after timeout so the caller can loop.
- Submission-gate middleware: fresh `/v1/completions` and `/v1/chat/completions` requests for a currently-paused LoRA block until resume. Closes the race in (2) above. Pure ASGI (not `@app.middleware("http")`) because `BaseHTTPMiddleware` buffers streaming responses and would regress vLLM's SSE completions path.

`skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py`:

- `resume_generation(lora_name=X)` now also POSTs `/skyrl/v1/resume_lora_requests` to all `server_urls` so the server-side flag matches the in-process gate. The local set still happens first, so in-process retries don't wait on the HTTP round-trip; brief inconsistency between the two gates just causes one extra retry-loop iteration.

All new endpoints carry the same TRANSIENT marker as the existing abort endpoint: delete the stack when vLLM ships native per-LoRA pause.
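For a sense of how an out-of-process caller might consume the long-poll endpoint, here is a hedged sketch using httpx; the `wait_until_unpaused` helper, its arguments, and the 30-second poll interval are invented for the example, not code from this PR or the follow-up.

```python
import httpx


async def wait_until_unpaused(server_url: str, lora_name: str) -> None:
    """Long-poll /skyrl/v1/wait_lora_unpaused until the LoRA resumes."""
    async with httpx.AsyncClient(timeout=None) as client:
        while True:
            resp = await client.post(
                f"{server_url}/skyrl/v1/wait_lora_unpaused",
                json={"lora_name": lora_name, "timeout_s": 30.0},
            )
            resp.raise_for_status()
            if not resp.json()["paused"]:
                return  # resumed: safe to resubmit the aborted sample
```

Because the server answers `{paused: true}` on timeout rather than erroring, the caller's loop stays a plain retry with no special-case handling.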
Test plan
- `ruff` + `black` clean
- `tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_pause_lora.py` (existing pause/resume suite) passes — every test exercises `pause_generation` and `resume_generation`, so my changes are validated indirectly.
- New endpoints (`/skyrl/v1/wait_lora_unpaused`, submission gate) are exercised by the follow-up PR's Tinker-side test.

🤖 Generated with Claude Code