rllm/trainer/verl/sft_dataset.py (38 additions, 2 deletions)
@@ -9,8 +9,8 @@


 class RLLMSFTDataset(MultiTurnSFTDataset):
-    def __init__(self, parquet_files: str | list[str], tokenizer, config=None):
-        super().__init__(parquet_files, tokenizer, config)
+    def __init__(self, parquet_files: str | list[str], tokenizer, config=None, processor=None, max_samples=-1):
+        super().__init__(parquet_files, tokenizer, config, processor=processor, max_samples=max_samples)

         self.tokenize_and_mask_method = config.rllm.tokenize_and_mask_method
         logger.info(f"Using {self.tokenize_and_mask_method} tokenization and masking method")
@@ -22,6 +22,8 @@ def _tokenize_and_mask(self, messages):
             return self._tokenize_and_mask_cumulative(messages)
         elif self.tokenize_and_mask_method == "stepwise":
            return self._tokenize_and_mask_stepwise(messages)
+        elif self.tokenize_and_mask_method == "hf_template":
+            return self._tokenize_and_mask_hf_template(messages)
         else:
             raise ValueError(f"Unknown tokenize_and_mask_method {self.tokenize_and_mask_method}")
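
The dispatch above keys off `config.rllm.tokenize_and_mask_method`. A minimal sketch of selecting the new "hf_template" path follows; the OmegaConf config shape, the model name, and the parquet path are assumptions inferred from the attribute access above, not part of this PR, and the base MultiTurnSFTDataset may expect additional config keys (max length, truncation, message column), so treat this as a shape sketch rather than a definitive recipe:

# Hypothetical usage sketch; the config schema is inferred from the
# `config.rllm.tokenize_and_mask_method` access in __init__.
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from rllm.trainer.verl.sft_dataset import RLLMSFTDataset

config = OmegaConf.create({"rllm": {"tokenize_and_mask_method": "hf_template"}})
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # any chat model

dataset = RLLMSFTDataset(
    parquet_files=["train.parquet"],  # hypothetical path
    tokenizer=tokenizer,
    config=config,
)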

@@ -40,6 +42,40 @@ def _tokenize_and_mask_cumulative(self, messages):

         return tokens, loss_mask

+    def _tokenize_and_mask_hf_template(self, messages):
+        """Use HF tokenizer.apply_chat_template for native tool-call rendering.
+
+        Renders the conversation incrementally (messages[0:i] vs. messages[0:i+1])
+        to isolate each message's tokens, then applies the loss mask only to
+        assistant tokens.
+        """
+        full_text = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=False,
+        )
+        full_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+        # Build per-message character offsets: prefix_lengths[i] is the offset
+        # in full_text where message i's rendering starts.
+        prefix_lengths = [0]
+        for i in range(len(messages)):
+            prefix_text = self.tokenizer.apply_chat_template(
+                messages[: i + 1], tokenize=False, add_generation_prompt=False,
+            )
+            prefix_lengths.append(len(prefix_text))
+
+        # Tokenize each segment and assign the loss mask
+        tokens = []
+        loss_mask = []
+        for i in range(len(messages)):
+            segment = full_text[prefix_lengths[i] : prefix_lengths[i + 1]]
+            seg_ids = self.tokenizer.encode(segment, add_special_tokens=False)
+
+            if messages[i]["role"] == "assistant":
+                loss_mask.extend([1] * len(seg_ids))
+            else:
+                loss_mask.extend([0] * len(seg_ids))
+            tokens.extend(seg_ids)
+
+        # Sanity check: the per-segment encodings must concatenate back to the
+        # encoding of the full text; a mismatch means a segment boundary split
+        # a token and the loss mask would be misaligned.
+        assert tokens == full_ids, "segmented tokenization diverged from full text"
+
+        return tokens, loss_mask

     def _tokenize_and_mask_stepwise(self, messages):
         tokens = []
         loss_mask = []
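
Because `_tokenize_and_mask_hf_template` re-encodes each rendered segment independently, the supervised tokens can be recovered directly from the mask. A small sketch of the intended behavior, assuming `dataset` was constructed as in the earlier example with a tokenizer that defines a chat template:

# Hypothetical check: only assistant-turn tokens should carry loss_mask == 1.
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
tokens, loss_mask = dataset._tokenize_and_mask_hf_template(messages)

supervised = [t for t, m in zip(tokens, loss_mask) if m == 1]
print(dataset.tokenizer.decode(supervised))  # renders just the assistant turn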