Skip to content

Commit 0956764

Browse files
authored
feat: Let AgentFlow use is_validation and preserve TerminationReason in metrics (#523)
* add is_validation to AgentConfig and let TerminationReason flow to metrics
* clean up tests
1 parent 1eb76a4 commit 0956764

6 files changed

Lines changed: 96 additions & 11 deletions

File tree

rllm/experimental/engine/agent_flow_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ async def _run_single(self, task: dict, uid: str, is_validation: bool = False) -
231231
base_url=session_url,
232232
model=self.model,
233233
session_uid=uid,
234+
is_validation=is_validation,
234235
)
235236

236237
# 3. Run agent flow (prefers arun if available, else run in executor)
@@ -276,7 +277,8 @@ async def _run_single(self, task: dict, uid: str, is_validation: bool = False) -
276277
for signal in eval_output.signals:
277278
enriched.metrics[signal.name] = signal.value
278279

279-
enriched.termination_reason = TerminationReason.ENV_DONE
280+
if enriched.termination_reason is None:
281+
enriched.termination_reason = TerminationReason.ENV_DONE
280282
return enriched
281283

282284
def _enrich_episode(

rllm/experimental/unified_trainer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,12 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N
418418
return
419419

420420
workflow_metrics, termination_counts = self._collect_workflow_metrics_from_episodes(trainer_state.episodes)
421+
for key, value in workflow_metrics.items():
422+
trainer_state.metrics[f"batch/{key}"] = np.mean(value)
423+
424+
total_counts = max(sum(termination_counts.values()), 1)
425+
for r in TerminationReason:
426+
trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts
421427

422428
# stage 2: transform episodes to trajectory groups (sync)
423429
trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups(trainer_state.episodes, self.transform_config, self.cf_config, traj_grouping_hook=self.traj_grouping_hook)
@@ -461,13 +467,6 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N
461467
show_workflow_metadata=True,
462468
)
463469

464-
for key, value in workflow_metrics.items():
465-
trainer_state.metrics[f"batch/{key}"] = np.mean(value)
466-
467-
total_counts = max(sum(termination_counts.values()), 1)
468-
for r in TerminationReason:
469-
trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts
470-
471470
# =========================================================================
472471
# Fully-asynchronous training pipeline
473472
# =========================================================================
@@ -809,14 +808,15 @@ def shutdown(self):
809808
# =========================================================================
810809
# Helper functions
811810
# =========================================================================
812-
def _collect_workflow_metrics_from_episodes(self, episodes: list[Episode]) -> tuple[dict, Counter]:
811+
@staticmethod
812+
def _collect_workflow_metrics_from_episodes(episodes: list[Episode]) -> tuple[dict, Counter]:
813813
workflow_metrics = defaultdict(list)
814814
termination_counts = Counter()
815815
for episode in episodes:
816816
for k, v in episode.metrics.items():
817817
workflow_metrics[k].append(v)
818-
if episode.termination_reason is not None:
819-
termination_counts[episode.termination_reason.value] += 1
818+
reason = episode.termination_reason or TerminationReason.UNKNOWN
819+
termination_counts[getattr(reason, "value", reason)] += 1
820820
# reduce the metrics to a scalar value, with error handling
821821
reduced_workflow_metrics = {}
822822
for k, v in workflow_metrics.items():

rllm/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class AgentConfig:
144144
model: str
145145
session_uid: str
146146
metadata: dict = field(default_factory=dict)
147+
is_validation: bool = False
147148

148149

149150
@runtime_checkable
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import asyncio
2+
3+
from rllm.agents.agent import Episode, Trajectory
4+
from rllm.eval.types import EvalOutput
5+
from rllm.experimental.engine.agent_flow_engine import AgentFlowEngine
6+
from rllm.workflows.workflow import TerminationReason
7+
8+
9+
class _Agent:
10+
def __init__(self):
11+
self.config = None
12+
13+
async def arun(self, task, config):
14+
self.config = config
15+
return Episode(
16+
id=task.id,
17+
termination_reason=TerminationReason.ERROR,
18+
trajectories=[Trajectory(name="solver")],
19+
)
20+
21+
22+
class _Evaluator:
23+
def evaluate(self, task, episode):
24+
return EvalOutput(reward=0.0, is_correct=False)
25+
26+
27+
class _Gateway:
28+
def __init__(self):
29+
self.created = None
30+
self.deleted = None
31+
32+
async def acreate_session(self, session_id, is_validation=False):
33+
self.created = (session_id, is_validation)
34+
35+
def get_session_url(self, session_id):
36+
return f"http://gateway/{session_id}"
37+
38+
async def aget_traces(self, session_id):
39+
return []
40+
41+
async def adelete_session(self, session_id):
42+
self.deleted = session_id
43+
44+
45+
def test_run_single_passes_validation_flag_and_preserves_termination_reason():
46+
agent = _Agent()
47+
gateway = _Gateway()
48+
engine = AgentFlowEngine(
49+
agent_flow=agent,
50+
evaluator=_Evaluator(),
51+
gateway=gateway,
52+
model="test-model",
53+
n_parallel_tasks=1,
54+
)
55+
56+
try:
57+
episode = asyncio.run(engine._run_single({"question": "q"}, "task:0", is_validation=True))
58+
finally:
59+
engine.shutdown()
60+
61+
assert gateway.created == ("task:0", True)
62+
assert gateway.deleted == "task:0"
63+
assert agent.config.is_validation is True
64+
assert agent.config.session_uid == "task:0"
65+
assert episode.termination_reason == TerminationReason.ERROR

tests/eval/test_eval_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def test_signals(self):
188188
def test_agent_config_defaults():
189189
config = AgentConfig(base_url="http://localhost:8000", model="test-model", session_uid="s1")
190190
assert config.metadata == {}
191+
assert config.is_validation is False
191192

192193

193194
# ---------------------------------------------------------------------------
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from rllm.agents.agent import Episode
2+
from rllm.experimental.unified_trainer import UnifiedTrainer
3+
from rllm.workflows.workflow import TerminationReason
4+
5+
6+
def test_collect_workflow_metrics_counts_unknown_termination_reason():
7+
episodes = [
8+
Episode(id="task:0", termination_reason=None, metrics={"custom": 1.0}),
9+
Episode(id="task:1", termination_reason=TerminationReason.ERROR, metrics={"custom": 3.0}),
10+
]
11+
12+
workflow_metrics, termination_counts = UnifiedTrainer._collect_workflow_metrics_from_episodes(episodes)
13+
14+
assert workflow_metrics["custom"] == 2.0
15+
assert termination_counts[TerminationReason.UNKNOWN.value] == 1
16+
assert termination_counts[TerminationReason.ERROR.value] == 1

0 commit comments

Comments (0)