Merged
29 changes: 26 additions & 3 deletions rllm/agents/agent.py
@@ -3,6 +3,7 @@
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
@@ -12,7 +13,8 @@

@dataclass
class Step:
prompt_ids: list[int] = field(default_factory=list)
# this accommodates backends like `tinker`, whose prompt_ids may contain special image blocks
prompt_ids: list[int] | list[Any] = field(default_factory=list)
response_ids: list[int] = field(default_factory=list)
logprobs: list[float] = field(default_factory=list)

@@ -93,6 +95,19 @@ def from_dict(cls, data: dict) -> Step:
advantage=data["advantage"],
)

@classmethod
def from_model_output(cls, model_output: ModelOutput, messages: list[dict] | None = None, action: Any | None = None) -> Step:
return cls(
prompt_ids=model_output.prompt_ids or [],
response_ids=model_output.completion_ids or [],
logprobs=model_output.logprobs or [],
chat_completions=(messages or []) + [{"role": "assistant", "content": model_output.content, "reasoning": model_output.reasoning}],
thought=model_output.reasoning or "",
action=action,
model_response=model_output.content or "",
model_output=model_output,
)


@dataclass
class Action:
@@ -199,6 +214,14 @@ def from_dict(cls, data: dict) -> Episode:
info=data.get("info", {}),
)

@cached_property
def task_id(self) -> str:
return self.id.split(":")[0]

@cached_property
def rollout_idx(self) -> str:
return self.id.split(":")[1]


@dataclass
class TrajectoryGroup:
@@ -219,11 +242,11 @@ class TrajectoryGroup:
group_id: str = ""
metadata: list[dict] = field(default_factory=list)

@property
@cached_property
def group_role(self) -> str:
return self.group_id.split(":")[1] if ":" in self.group_id[:-1] else "all_groups"

@property
@cached_property
def task_id(self) -> str:
return self.group_id.split(":")[0]
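For context on the additions above: `Step.from_model_output` centralizes the Step construction previously duplicated in the completers (see completer.py below), and `Episode.task_id`/`rollout_idx` parse the `"task_id:rollout_idx"` id convention. A minimal sketch, assuming `Step` and `Episode` are importable from this module and that `Episode` can be constructed with only an `id`; `ModelOutput` is duck-typed with a hypothetical stub since its real definition is not in this diff:

```python
from types import SimpleNamespace

from rllm.agents.agent import Episode, Step

# Hypothetical stub standing in for ModelOutput, which is defined elsewhere in rllm.
output = SimpleNamespace(
    content="4",
    reasoning="2 + 2 = 4",
    prompt_ids=[1, 2, 3],
    completion_ids=[19, 4],
    logprobs=[-0.1, -0.2],
)

step = Step.from_model_output(output, messages=[{"role": "user", "content": "2+2?"}])
assert step.thought == "2 + 2 = 4"
assert step.chat_completions[-1]["role"] == "assistant"

# id convention: "<task_id>:<rollout_idx>"; both accessors memoize via cached_property.
episode = Episode(id="task_42:3")  # assumes other Episode fields have defaults
assert episode.task_id == "task_42"
assert episode.rollout_idx == "3"
```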

65 changes: 35 additions & 30 deletions rllm/experimental/common/advantage.py
@@ -35,41 +35,37 @@ def _calculate_reinforce_advantages(rewards: np.ndarray) -> np.ndarray:
return rewards


def _check_advantage_already_computed(group: TrajectoryGroup, group_role: str) -> tuple[bool, list[float]]:
"""Check if the advantage has already been computed for all steps in the trajectory group.
def _collect_precomputed_advantages(group: TrajectoryGroup, group_role: str) -> list[float]:
"""Collect pre-computed per-token advantages from all steps.

Returns True only when every step has a non-None advantage and, for list-valued
advantages, the length matches the corresponding logprobs length.
Called when use_precomputed_advantage is True. Steps with None or length-mismatched
advantages are defaulted to zero lists. Raises if step.advantage is a scalar float
(pre-computed advantages must be per-token lists).
"""
total_steps = 0
steps_with_advantage = 0
flattened_advantages = []
steps_missing = 0
total_steps = 0

for traj in group.trajectories:
if total_steps > steps_with_advantage:
break
for step in traj.steps:
total_steps += 1
if step.advantage is None:
break
# validate list-valued advantages against logprobs length
if isinstance(step.advantage, list):
if len(step.advantage) != len(step.logprobs):
logger.warning(f"[group={group_role}] Detected a step has advantage length {len(step.advantage)} but logprobs length {len(step.logprobs)}. Fall back to re-compute all advantages.")
break
else:
flattened_advantages.extend(step.advantage)
step.advantage = [0.0] * len(step.response_ids)
flattened_advantages.extend(step.advantage)
steps_missing += 1
elif isinstance(step.advantage, list):
if len(step.advantage) != len(step.response_ids):
logger.warning(f"[group={group_role}] Step has advantage length {len(step.advantage)} but response_ids length {len(step.response_ids)}. Defaulting to zeros.")
step.advantage = [0.0] * len(step.response_ids)
steps_missing += 1
flattened_advantages.extend(step.advantage)
else:
flattened_advantages.append(step.advantage)
steps_with_advantage += 1
raise ValueError(f"[group={group_role}] step.advantage must be a list when use_precomputed_advantage is True, got {type(step.advantage)}")

if steps_with_advantage < total_steps:
# give a warning if at least one step has advantage
if steps_with_advantage > 0:
logger.warning(f"[group={group_role}] Detected some steps have advantages already computed, while others do not. Fall back to re-compute all advantages. Please check the pre-computed advantage in workflow.")
return False, flattened_advantages
# all steps have advantage
return True, flattened_advantages
if steps_missing > 0:
logger.warning(f"[group={group_role}] {steps_missing}/{total_steps} steps missing pre-computed advantages, defaulted to zeros.")

return flattened_advantages


def collect_reward_and_advantage_from_trajectory_groups(
@@ -106,9 +102,20 @@ def collect_reward_and_advantage_from_trajectory_groups(
# TODO(listar2000): in the future, we should support per-trajectory-group advantage modes
for group in groups:
group_role = group.group_role
# check if the advantage has already been computed for all steps in the trajectory group
advantages_already_computed, flattened_advantages = _check_advantage_already_computed(group, group_role)
if not advantages_already_computed:

if algorithm_config.use_precomputed_advantage:
# Distillation mode: always use pre-computed per-token advantages from the workflow.
if collect_advantage:
flattened_advantages = _collect_precomputed_advantages(group, group_role)
advantages_by_group[group_role].extend(flattened_advantages)
else:
# RL mode: compute advantages from trajectory rewards.
if collect_advantage:
# Warn if steps have pre-computed advantages that will be overwritten.
has_any = any(step.advantage is not None for traj in group.trajectories for step in traj.steps)
if has_any:
logger.warning(f"[group={group_role}] Steps have pre-computed advantages but use_precomputed_advantage is False. Overwriting with {algorithm_config.estimator.value}.")

assert all(traj.reward is not None for traj in group.trajectories), "Trajectory reward cannot be None in broadcast mode"
traj_rewards = np.array([traj.reward for traj in group.trajectories])
rewards_by_group[group_role].extend(traj_rewards)
@@ -120,8 +127,6 @@
for traj, advantage in zip(group.trajectories, advantages, strict=False):
for step in traj.steps:
step.advantage = advantage
elif collect_advantage: # we simply need to collect the advantage
advantages_by_group[group_role].extend(flattened_advantages)

# reduce metrics by group
final_metrics = {}
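The new contract in `_collect_precomputed_advantages` is easiest to see in isolation. A self-contained sketch with a minimal stand-in `Step` (hypothetical, not the real rllm class): per-token list advantages pass through, `None` or length-mismatched lists are zero-filled and counted, and scalars raise:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class Step:  # minimal stand-in for rllm.agents.agent.Step
    def __init__(self, response_ids, advantage=None):
        self.response_ids = response_ids
        self.advantage = advantage


def collect_precomputed_advantages(steps, group_role="demo"):
    flattened, missing = [], 0
    for step in steps:
        if step.advantage is None:
            step.advantage = [0.0] * len(step.response_ids)  # default to zeros
            missing += 1
        elif isinstance(step.advantage, list):
            if len(step.advantage) != len(step.response_ids):  # length mismatch
                step.advantage = [0.0] * len(step.response_ids)
                missing += 1
        else:  # scalar advantages are rejected outright
            raise ValueError(f"[group={group_role}] advantage must be a list, got {type(step.advantage)}")
        flattened.extend(step.advantage)
    if missing:
        logger.warning(f"[group={group_role}] {missing}/{len(steps)} steps zero-filled")
    return flattened


steps = [Step([1, 2], advantage=[0.5, 0.5]), Step([3, 4])]  # second step has no advantage
assert collect_precomputed_advantages(steps) == [0.5, 0.5, 0.0, 0.0]
```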
5 changes: 5 additions & 0 deletions rllm/experimental/common/config.py
@@ -111,6 +111,10 @@ class AlgorithmConfig:
stepwise_advantage_mode: Literal["broadcast", "per_step"] = "broadcast"
norm_adv_by_std_in_grpo: bool = True
use_rllm: bool = False
# When True, always use pre-computed step.advantage from the workflow and skip
# advantage computation (GRPO/REINFORCE). Steps missing advantages default to 0.0.
# When False (default), always compute advantages normally.
use_precomputed_advantage: bool = False
# for tinker backend only
loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro", "cross_entropy"] | None = None
lr_schedule: Literal["linear", "cosine", "constant"] = "constant"
@@ -130,6 +134,7 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig":
stepwise_advantage_mode=config.rllm.stepwise_advantage.mode,
norm_adv_by_std_in_grpo=config.rllm.stepwise_advantage.get("norm_adv_by_std_in_grpo", True),
use_rllm=config.rllm.stepwise_advantage.get("use_rllm", False),
use_precomputed_advantage=config.rllm.algorithm.get("use_precomputed_advantage", False),
loss_fn=config.rllm.algorithm.get("loss_fn", None),
lr_schedule=config.rllm.algorithm.get("lr_schedule", "constant"),
warmup_steps_ratio=config.rllm.algorithm.get("warmup_steps_ratio", 0.0),
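`from_config` reads the new flag with `config.rllm.algorithm.get(...)`, so configs that omit the key keep the old behavior. A small sketch of toggling it, assuming an OmegaConf config shaped like the YAML in base.yaml below (only `use_precomputed_advantage` itself is taken from this diff):

```python
from omegaconf import OmegaConf

# Hypothetical minimal config; real configs carry many more rllm.algorithm keys.
config = OmegaConf.create({"rllm": {"algorithm": {"use_precomputed_advantage": True}}})
assert config.rllm.algorithm.get("use_precomputed_advantage", False) is True

# Omitting the key falls back to the default, i.e. normal advantage computation.
empty = OmegaConf.create({"rllm": {"algorithm": {}}})
assert empty.rllm.algorithm.get("use_precomputed_advantage", False) is False
```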
2 changes: 1 addition & 1 deletion rllm/experimental/common/rejection_sampling.py
@@ -87,7 +87,7 @@ def update_episode_metrics(
if len(episode.trajectories) == 0:
continue
# Extract task_id from episode.id format "task_id:rollout_idx"
task_id = episode.id.split(":")[0]
task_id = episode.task_id
if task_id not in episodes_by_task:
episodes_by_task[task_id] = []
episodes_by_task[task_id].append(episode)
5 changes: 3 additions & 2 deletions rllm/experimental/common/transform.py
@@ -118,14 +118,15 @@ def _build_trajectory_groups(episodes: list[Episode], compact_filtering_config:
# skip episode if it should be masked by compact filtering
if compact_filtering_config and compact_filtering_config.should_mask(termination_reason):
continue
task_id = episode.id.split(":")[0]
task_id = episode.task_id
for trajectory in episode.trajectories:
if len(trajectory.steps) == 0:
continue
trajectories_by_name[f"{task_id}:{trajectory.name}"].append(trajectory)
metadata_by_name[f"{task_id}:{trajectory.name}"].append(
{
"episode_id": episode.id,
"task_id": episode.task_id,
"rollout_idx": episode.rollout_idx,
"termination_reason": episode.termination_reason,
"is_correct": episode.is_correct,
}
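The `f"{task_id}:{trajectory.name}"` keys built here are the same `group_id` strings that `TrajectoryGroup.task_id`/`group_role` (agent.py above) parse back apart. A toy sketch of the grouping with hypothetical episode data:

```python
from collections import defaultdict

trajectories_by_name: dict[str, list[str]] = defaultdict(list)

# Hypothetical episodes: id follows "task_id:rollout_idx", each with named trajectories.
for episode_id, traj_names in [("task_7:0", ["solver"]), ("task_7:1", ["solver"])]:
    task_id = episode_id.split(":")[0]
    for name in traj_names:
        trajectories_by_name[f"{task_id}:{name}"].append(episode_id)

# Rollouts of the same task and trajectory name land in one group.
assert dict(trajectories_by_name) == {"task_7:solver": ["task_7:0", "task_7:1"]}
```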
7 changes: 3 additions & 4 deletions rllm/experimental/common/visualization.py
@@ -80,10 +80,9 @@ def _visualize_metadata(trajectory: Trajectory, metadata: dict, config: Visualiz
Visualizes workflow metadata for a given trajectory.
"""
header_parts = []
if "episode_id" in metadata:
task_id, rollout_idx = metadata["episode_id"].split(":")
header_parts.append(f"Task ID: {task_id}")
header_parts.append(f"Rollout: #{rollout_idx}")
if "task_id" in metadata and "rollout_idx" in metadata:
header_parts.append(f"Task ID: {metadata['task_id']}")
header_parts.append(f"Rollout: #{metadata['rollout_idx']}")
header_parts.append(f"Trajectory: {trajectory.name}")
colorful_print(" | ".join(header_parts), **config.header_style)

1 change: 1 addition & 0 deletions rllm/experimental/config/rllm/backend/tinker.yaml
@@ -63,6 +63,7 @@ rollout_engine:
accumulate_reasoning: false
disable_thinking: false
bypass_render_with_parser: true
renderer_name: null

# Data Configuration
data:
3 changes: 3 additions & 0 deletions rllm/experimental/config/rllm/base.yaml
@@ -53,6 +53,9 @@ algorithm:
lam: 0.95
norm_adv_by_std_in_grpo: true
use_rllm: false
# When true, always use pre-computed step.advantage from the workflow (e.g. distillation)
# and skip advantage computation (GRPO/REINFORCE). Missing advantages default to 0.
use_precomputed_advantage: false
# for tinker backend only (available options: importance_sampling, ppo, cispo, dro, cross_entropy)
loss_fn: null

32 changes: 10 additions & 22 deletions rllm/experimental/rollout/completer.py
@@ -45,17 +45,8 @@ async def complete(self, messages: list[dict], action_hook: Callable[[ModelOutpu
model_output: ModelOutput = await self.rollout_engine.get_model_response(messages, **kwargs)

# construct the step
chat_completions = messages + [{"role": "assistant", "content": model_output.content or "", "reasoning": model_output.reasoning or ""}]
action = action_hook(model_output) if action_hook is not None else None
return Step(
prompt_ids=model_output.prompt_ids or [],
response_ids=model_output.completion_ids or [],
logprobs=model_output.logprobs or [],
chat_completions=chat_completions,
thought=model_output.reasoning or "",
action=action,
model_output=model_output, # type: ignore
)
return Step.from_model_output(model_output, messages, action) # type: ignore

def reset(self):
"""Reset the completer to its initial state."""
@@ -123,21 +114,18 @@ async def complete(self, messages: list[dict], action_hook: Callable[[ModelOutpu

# update the previous messages and token input
self._prev_messages_str = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, accumulate_reasoning=True)
self._prev_token_input = curr_token_input + curr_token_output.completion_ids
# backend-specific handling for retrieving the completion ids
if hasattr(curr_token_output, "token_ids"): # Verl
curr_completion_ids: list[int] = curr_token_output.token_ids # type: ignore[assignment]
elif hasattr(curr_token_output, "tokens"): # Tinker
curr_completion_ids: list[int] = curr_token_output.tokens
else:
raise ValueError(f"Unsupported token output type: {type(curr_token_output)}")
# update the number of completions and prefixes
self._n_completions += 1
self._n_prefixes += int(is_prefix)

return Step(
prompt_ids=model_output.prompt_ids or [],
response_ids=model_output.completion_ids or [],
logprobs=model_output.logprobs or [],
chat_completions=messages + [{"role": "assistant", "content": model_output.content, "reasoning": model_output.reasoning}],
thought=model_output.reasoning or "",
action=action,
model_response=model_output.content or "",
model_output=model_output, # type: ignore
)
self._prev_token_input = curr_token_input + curr_completion_ids
return Step.from_model_output(model_output, messages, action) # type: ignore

def reset(self):
"""Reset the completer to its initial state."""