@@ -418,6 +418,12 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N
418418 return
419419
420420 workflow_metrics , termination_counts = self ._collect_workflow_metrics_from_episodes (trainer_state .episodes )
421+ for key , value in workflow_metrics .items ():
422+ trainer_state .metrics [f"batch/{ key } " ] = np .mean (value )
423+
424+ total_counts = max (sum (termination_counts .values ()), 1 )
425+ for r in TerminationReason :
426+ trainer_state .metrics [f"batch/termination_reason/{ r .value } " ] = termination_counts [r .value ] / total_counts
421427
422428 # stage 2: transform episodes to trajectory groups (sync)
423429 trajectory_groups , transform_metrics = transform_episodes_to_trajectory_groups (trainer_state .episodes , self .transform_config , self .cf_config , traj_grouping_hook = self .traj_grouping_hook )
@@ -461,13 +467,6 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N
461467 show_workflow_metadata = True ,
462468 )
463469
464- for key , value in workflow_metrics .items ():
465- trainer_state .metrics [f"batch/{ key } " ] = np .mean (value )
466-
467- total_counts = max (sum (termination_counts .values ()), 1 )
468- for r in TerminationReason :
469- trainer_state .metrics [f"batch/termination_reason/{ r .value } " ] = termination_counts [r .value ] / total_counts
470-
471470 # =========================================================================
472471 # Fully-asynchronous training pipeline
473472 # =========================================================================
@@ -809,14 +808,15 @@ def shutdown(self):
809808 # =========================================================================
810809 # Helper functions
811810 # =========================================================================
812- def _collect_workflow_metrics_from_episodes (self , episodes : list [Episode ]) -> tuple [dict , Counter ]:
811+ @staticmethod
812+ def _collect_workflow_metrics_from_episodes (episodes : list [Episode ]) -> tuple [dict , Counter ]:
813813 workflow_metrics = defaultdict (list )
814814 termination_counts = Counter ()
815815 for episode in episodes :
816816 for k , v in episode .metrics .items ():
817817 workflow_metrics [k ].append (v )
818- if episode .termination_reason is not None :
819- termination_counts [episode . termination_reason . value ] += 1
818+ reason = episode .termination_reason or TerminationReason . UNKNOWN
819+ termination_counts [getattr ( reason , " value" , reason ) ] += 1
820820 # reduce the metrics to a scalar value, with error handling
821821 reduced_workflow_metrics = {}
822822 for k , v in workflow_metrics .items ():
0 commit comments