Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion rllm/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

from .agent_execution_engine import AgentExecutionEngine, AsyncAgentExecutionEngine
from .agent_workflow_engine import AgentWorkflowEngine
from .rollout.openai_engine import OpenAIEngine
from .rollout.rollout_engine import RolloutEngine

Expand All @@ -23,3 +22,11 @@
__all__.append("VerlEngine")
except Exception:
VerlEngine = None


def __getattr__(name):
if name == "AgentWorkflowEngine":
from .agent_workflow_engine import AgentWorkflowEngine as _AgentWorkflowEngine

return _AgentWorkflowEngine
raise AttributeError(name)
7 changes: 6 additions & 1 deletion rllm/trainer/agent_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def __init__(

def train(self):
if not ray.is_initialized():
ray.init(runtime_env=get_ppo_ray_runtime_env(), num_cpus=self.config.ray_init.num_cpus)
# read off all the `ray_init` settings from the config
if self.config is not None and hasattr(self.config, "ray_init"):
ray_init_settings = {k: v for k, v in self.config.ray_init.items() if v is not None}
else:
ray_init_settings = {}
ray.init(runtime_env=get_ppo_ray_runtime_env(), **ray_init_settings)

runner = TaskRunner.remote()

Expand Down