Fully Async Trainer #394
Pull request overview
Adds a new experimental “fully async” PPO training pipeline that decouples rollout generation from training via a Ray-based message queue, plus parameter sync/validation plumbing and a DeepResearch example stack.
Changes:
- Introduce `rllm.experimental.fully_async` components: rollouter, trainer, message queue, param sync, HTTP client, metrics + batching utilities.
- Add docs/config/scripts for installing dependencies and applying the required `verl` patches.
- Add a runnable DeepResearch example (tooling, refine service client, RAG server, and launch scripts).
Reviewed changes
Copilot reviewed 27 out of 28 changed files in this pull request and generated 25 comments.
| File | Description |
|---|---|
| rllm/experimental/fully_async/README.md | Documents architecture/installation/patch application for fully-async mode. |
| rllm/experimental/fully_async/client.py | Async HTTP rollout client with abort/continue support and chat-completions wrapper. |
| rllm/experimental/fully_async/config/__init__.py | Marks the config package for Hydra discovery. |
| rllm/experimental/fully_async/config/fully_async_ppo_trainer.yaml | Example Hydra config for fully-async PPO trainer + rollout settings. |
| rllm/experimental/fully_async/fully_async_trainer.py | Trainer consuming samples from MQ, running PPO updates, triggering param sync, logging/ckpt. |
| rllm/experimental/fully_async/inference_manager.py | Manages SGLang server workers/router + cache clearing for async rollouts. |
| rllm/experimental/fully_async/install_vllm_sglang_mcore_updated_sglang.sh | Convenience install script for inference/training dependencies. |
| rllm/experimental/fully_async/message_queue.py | Ray-actor queue between rollouter and trainer (+ client wrapper). |
| rllm/experimental/fully_async/message_utils.py | Converts model token outputs into OpenAI message/tool-call structures. |
| rllm/experimental/fully_async/metric_utils.py | Step-wise metrics aggregation + validation metrics container. |
| rllm/experimental/fully_async/param_sync.py | Unified parameter sync actor coordinating pause/clear-cache/sync/resume/validation. |
| rllm/experimental/fully_async/protocol.py | Dataclasses for streamed outputs, sequences, trajectories, and trajectory groups. |
| rllm/experimental/fully_async/rollout_executor.py | Async rollouter that generates trajectories concurrently and drains to MQ; runs validation. |
| rllm/experimental/fully_async/runner.py | Entry wiring: starts inference manager, rollouter, trainer, MQ, and synchronizer. |
| rllm/experimental/fully_async/utils.py | Batch assembly into DataProto, rejection sampling, checkpoint helpers, metric reduction, HTTP helpers. |
| rllm/experimental/fully_async/verl_dp_actor.patch | Patch file for upstream verl actor behavior required by fully-async training. |
| rllm/experimental/fully_async/verl_patch.md | Describes the required upstream verl patch intent and how to apply. |
| examples/fully_async/deepresearch/config/8b_stale05_rs.sh | Example launch configuration for DeepResearch training. |
| examples/fully_async/deepresearch/data/prepare_browsecomp_plus.py | Dataset prep/decrypt script and DatasetRegistry registration. |
| examples/fully_async/deepresearch/data/prepare_cut_the_bill.py | DatasetRegistry registration helper for a custom dataset. |
| examples/fully_async/deepresearch/rag/launch_rag.sh | Launch helper for the RAG server with batching/sharding knobs. |
| examples/fully_async/deepresearch/rag/rag_server.py | FastAPI retrieval server with GPU sharding + request auto-batching. |
| examples/fully_async/deepresearch/refine_agent.py | Refine-service client with multi-endpoint load balancing + stats. |
| examples/fully_async/deepresearch/scripts/launch_refine.sh | Launch helper for multi-GPU vLLM refine servers. |
| examples/fully_async/deepresearch/search_agent.py | Search agent performing tool calls, refinement, and reward computation. |
| examples/fully_async/deepresearch/tool.py | Async local retrieval tool with client-side failover/load balancing. |
| examples/fully_async/deepresearch/train.py | Hydra entry point wiring DeepResearch rollout functions into AsyncAgentTrainer. |
| examples/fully_async/deepresearch/util.py | Helpers to normalize messages into simple dict format. |
```python
try:
    iteration = 0
    while self.global_steps < self.total_rollout_steps:
```
Off-by-one in rollout termination: with `global_steps` initialized to 1, `while self.global_steps < self.total_rollout_steps` runs only `total_rollout_steps - 1` iterations (since `global_steps` is incremented once per datum). If `total_rollout_steps` is meant to be an exact count, adjust the initialization/condition, e.g. start at 0 or use `<=`:

```suggestion
while self.global_steps <= self.total_rollout_steps:
```
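To make the off-by-one concrete, here is a minimal counting sketch (the function name and values are illustrative, not taken from the PR):

```python
def count_iterations(start: int, total_rollout_steps: int, inclusive: bool) -> int:
    """Count loop bodies executed when the counter increments once per datum."""
    global_steps = start
    iterations = 0
    while (global_steps <= total_rollout_steps) if inclusive else (global_steps < total_rollout_steps):
        iterations += 1
        global_steps += 1
    return iterations

# Starting at 1 with `<` runs one iteration short of the requested count:
print(count_iterations(start=1, total_rollout_steps=10, inclusive=False))  # 9
print(count_iterations(start=1, total_rollout_steps=10, inclusive=True))   # 10
```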
```python
results = await asyncio.gather(*[run_one(batch[0]) for batch in self.val_dataloader])
```
`asyncio.gather(*[run_one(...) for batch in self.val_dataloader])` creates one coroutine per validation sample up front. Even with the semaphore, this can cause a large memory/overhead spike for big validation sets. Prefer a bounded-concurrency/streaming pattern (e.g., `asyncio.as_completed` over a limited task set):

```suggestion
# Stream tasks with bounded concurrency instead of creating all coroutines up front.
results = []
concurrency_limit = 1024
pending = set()
for batch in self.val_dataloader:
    task = asyncio.create_task(run_one(batch[0]))
    pending.add(task)
    if len(pending) >= concurrency_limit:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        for d in done:
            results.append(d.result())
if pending:
    done, _ = await asyncio.wait(pending)
    for d in done:
        results.append(d.result())
```
## dp_actor_functional_changes.patch

**File:** `verl/workers/actor/dp_actor.py`

**Purpose:** Modifications to `DataParallelPPOActor.update_actor()` for token-mean loss scaling and single mini-batch enforcement.
This doc section header and the apply instructions refer to `dp_actor_functional_changes.patch`, but the repository includes `verl_dp_actor.patch` (and the README instructs applying `verl_dp_actor.patch`). Align the filename references here so users don't apply a wrong or missing patch.
```python
from rllm.experimental.fully_async.rollout_executor import RolloutExecutor
from rllm.experimental.fully_async.utils import calculate_max_concurrency
```
Unused imports: `ResourcePoolManager` and `need_reference_policy` are imported but never referenced in this module. Removing them avoids lint failures and keeps the dependency graph clearer.
```python
group = TrajectoryGroup(trajectories=[res for res in self.result_dict[idx] if res is not None])
serialized = ray.cloudpickle.dumps(group)
await self.trajectory_group_queue.put(serialized)
del self.result_dict[idx]
self.active_sample -= 1
self.enqueued_sample += 1
```
If all `n` rollouts for an index fail (all results are `None`), this still enqueues a `TrajectoryGroup` with `trajectories=[]`. Downstream batch assembly/training is likely to fail or produce empty batches. Consider dropping empty groups (and tracking them as dropped) rather than enqueueing them:

```suggestion
trajectories = [res for res in self.result_dict[idx] if res is not None]
if trajectories:
    group = TrajectoryGroup(trajectories=trajectories)
    serialized = ray.cloudpickle.dumps(group)
    await self.trajectory_group_queue.put(serialized)
    self.enqueued_sample += 1
else:
    # All rollouts for this index failed; drop this sample instead of enqueuing an empty group.
    self.dropped_samples += 1
del self.result_dict[idx]
self.active_sample -= 1
```
```python
# Wait for either timeout or batch_event (triggered when queue is full)
try:
    await asyncio.wait_for(self.batch_event.wait(), timeout=self.batch_timeout)
except asyncio.TimeoutError:
```
The `except asyncio.TimeoutError:` clause does nothing but `pass`, with no explanatory comment. Add a short note explaining why the timeout is expected and safe to ignore.
```python
except asyncio.CancelledError:
    pass
drain_task.cancel()
try:
    await drain_task
except asyncio.CancelledError:
```
The `except asyncio.CancelledError:` clauses do nothing but `pass` and carry no explanatory comment:

```suggestion
except asyncio.CancelledError:
    # Task cancellation is expected during cleanup; safely ignore.
    pass
drain_task.cancel()
try:
    await drain_task
except asyncio.CancelledError:
    # Task cancellation is expected during cleanup; safely ignore.
```
```python
except asyncio.CancelledError:
    pass
drain_task.cancel()
try:
    await drain_task
except asyncio.CancelledError:
```
The `except asyncio.CancelledError:` clauses here likewise do nothing but `pass` and carry no explanatory comment:

```suggestion
except asyncio.CancelledError:
    # Task was cancelled as part of normal shutdown; ignore.
    pass
drain_task.cancel()
try:
    await drain_task
except asyncio.CancelledError:
    # Task was cancelled as part of normal shutdown; ignore.
```
```python
duplicate_search_detected = True
break
executed_search_calls.add(call_key)
except (KeyError, TypeError):
```
The `except (KeyError, TypeError):` clause does nothing but `pass` and there is no explanatory comment. Note why malformed tool-call entries are safe to skip here, or at least count/log them.
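A hedged sketch of documenting and counting the swallowed exception; `dedupe_search_calls`, the `(name, arguments)` key shape, and the malformed counter are assumptions for illustration, not names from the PR:

```python
def dedupe_search_calls(tool_calls: list[dict]) -> tuple[list[dict], int]:
    """Keep only the first occurrence of each (name, arguments) search call."""
    executed, kept, malformed = set(), [], 0
    for call in tool_calls:
        try:
            call_key = (call["name"], call["arguments"])
        except (KeyError, TypeError):
            # Malformed tool-call entry (missing keys or a non-dict payload);
            # skip it, but count it so silent data issues stay visible in stats.
            malformed += 1
            continue
        if call_key in executed:
            continue
        executed.add(call_key)
        kept.append(call)
    return kept, malformed
```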
```python
# base_url = "http://localhost:30001"
api_key = ""
else:
    base_url = "http://localhost:30001"
```
This statement is unreachable: remove the dead `base_url` assignment in the `else` branch, or fix the preceding branching so the fallback can actually run.