[v1/v0] message_encode with hf_ChatTemplate #10473

@Kuangdd01

Reminder

  • I have read the above rules and searched the existing issues.

Description

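For the upcoming v1, we plan to use the official Hugging Face chat_template to render conversations.

Here is a proposal.

This is straightforward with the Qwen3.6 template, which preserves the chain-of-thought of earlier turns and does not rewrite the history context.
As a result, the tokens of each new turn can be recovered as cur_input_ids - prev_input_ids, i.e. the suffix that is appended when one more message is rendered.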
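
The diff is only meaningful if the ids rendered for the first i messages are an exact prefix of the ids rendered for the first i + 1 messages. A minimal sketch of that check, using toy token ids rather than a real tokenizer; the helper name split_new_turn is hypothetical and not part of any library:

```python
def split_new_turn(prev_ids: list[int], cur_ids: list[int]) -> list[int]:
    """Return the tokens that cur_ids appends to prev_ids.

    Raises ValueError when prev_ids is not a prefix of cur_ids, i.e. when the
    template re-rendered earlier turns and the diff would be meaningless.
    """
    n = len(prev_ids)
    if cur_ids[:n] != prev_ids:
        raise ValueError("chat template is not prefix-stable for this conversation")
    # Slicing from the front also handles the case where no tokens were added.
    return cur_ids[n:]


# Toy ids standing in for two successive renderings of a conversation.
prev = [1, 2, 3]
cur = [1, 2, 3, 4, 5]
new_turn = split_new_turn(prev, cur)
```

For templates that strip earlier reasoning from the history, this check would be expected to fail, which is why the proposal depends on the history being preserved.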

encode

import copy
import json

from datasets import load_dataset
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.6-35B-A3B")

v1_tool_dataset = load_dataset("llamafactory/reason-tool-use-demo-1500", split="train").select(range(8))

TOOLS = json.loads(v1_tool_dataset[0]["tools"])
v1_messages = v1_tool_dataset[0]["messages"]


def convert_to_hf_format(messages: list[dict]) -> list[dict]:
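    # Convert v1-style messages (a list of typed content items) into the
    # message format expected by the Hugging Face chat template.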
    messages = copy.deepcopy(messages)
    for i in range(len(messages)):
        msg = messages[i]
        contents = msg.get("content") or []
        to_remove: list[int] = []
        reasoning_chunks: list[str] = []
        tool_calls: list[dict] = []

        for j, content_item in enumerate(contents):
            ctype = content_item.get("type")
            if ctype in ("reasoning", "reasoning_content"):
                reasoning_chunks.append(content_item["value"])
                to_remove.append(j)
            elif ctype == "tool_call":
                tc = json.loads(content_item["value"])
                tool_calls.append(
                    {"function": {"name": tc["name"], "arguments": tc.get("arguments", {})}}
                )
                to_remove.append(j)
            elif ctype == "text":
                content_item["text"] = content_item.pop("value", content_item.get("text", ""))

        for j in reversed(to_remove):
            del contents[j]

        # Drop the key once every content item has been moved out of it.
        if to_remove and not contents:
            msg.pop("content", None)

        if reasoning_chunks:
            msg["reasoning_content"] = "\n".join(reasoning_chunks)

        if tool_calls:
            msg["tool_calls"] = tool_calls

    return messages


def _encode(messages: list[dict]):
    labels, input_ids = [], []
    label_turns: list[list[int]] = []
    n = len(messages)
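    kwargs = {
        "tokenize": True,
        # Request a dict explicitly so that ["input_ids"] works even on
        # transformers versions that return a plain list of ids by default.
        "return_dict": True,
        "add_generation_prompt": False,
        "tools": TOOLS,
        "preserve_thinking": True,
        "enable_thinking": True,
    }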

    prefix_index = 0
    # find the first query message and fill the prefix ids
    for i in range(n):
        if messages[i]["role"] == "user":
            prefix_index = i
            break

    prefix_messages = messages[: prefix_index + 1]
    prefix_input_ids = tokenizer.apply_chat_template(prefix_messages, **kwargs)["input_ids"]
    input_ids.extend(prefix_input_ids)

    labels.extend([-100] * len(input_ids))
    prev_input_ids = prefix_input_ids
    for i in range(prefix_index + 1, n):
        cur_input_ids = tokenizer.apply_chat_template(messages[: i + 1], **kwargs)["input_ids"]
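        # The diff is only valid if the previous rendering is an exact prefix.
        assert cur_input_ids[: len(prev_input_ids)] == prev_input_ids
        # Slice from the front: cur_input_ids[-0:] would return the whole list
        # if a message added no tokens.
        turn_ids = cur_input_ids[len(prev_input_ids) :]
        this_turn_len = len(turn_ids)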
        input_ids.extend(turn_ids)
        if messages[i]["role"] in ["user", "tool"]:
            labels.extend([-100] * this_turn_len)
        else:
            labels.extend(turn_ids)
            label_turns.append(turn_ids)

        prev_input_ids = cur_input_ids

    return labels, input_ids, label_turns


if __name__ == "__main__":
    v1_messages = convert_to_hf_format(v1_messages)
    labels, input_ids, label_turns = _encode(v1_messages)
    # visualize the input_ids and labels
    print("###input_ids:")
    print(tokenizer.decode(input_ids))
    print("=" * 100)
    print(f"####labels: ({len(label_turns)} supervised turn(s))")
    for idx, turn_ids in enumerate(label_turns):
        print(f"\n----- turn {idx + 1} / {len(label_turns)}  (len={len(turn_ids)}) -----")
        print(tokenizer.decode(turn_ids))

    assert len(labels) == len(input_ids)

output

###input_ids:
<|im_start|>system
# Tools

You have access to the following functions:

<tools>
{"type": "function", "function": {"name": "calculate_factorial", "description": "Calculates the factorial of a non-negative integer.\n\n    Args:\n        n (int): The non-negative integer.\n\n    Returns:\n        int: The factorial of n.\n\n    Raises:\n        ValueError: If n is negative.", "parameters": {"type": "object", "properties": {"n": {"type": "integer", "description": "The non-negative integer."}}, "required": ["n"]}}}
</tools>

If you choose to call a function ONLY reply in the following format with NO suffix:

<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>

<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT>

You are a methodical and expert assistant. Your primary goal is to solve user requests by leveraging a set of available tools. You must reason for the best course of action in a structured manner before responding.<|im_end|>
<|im_start|>user
Calculate the factorial for 25.<|im_end|>
<|im_start|>assistant
<think>
Alright, the user wants me to calculate the factorial of 25. Since there's a 'calculate_factorial' tool available, I can use that. Factorial of 25 is a large number, but the tool should handle it. Let me just call the function with n=25.
</think>

<tool_call>
<function=calculate_factorial>
<parameter=n>
25
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
[{"tool_call_id": "call_C5l3iwmhActJBZnZsg4PTymF", "name": "calculate_factorial", "content": "{'factorial': 15511210043330985984000000, 'n': 25}"}]
</tool_response><|im_end|>
<|im_start|>assistant
<think>
The assistant was given a factorial calculation for 25. The function calculate_factorial was called with n=25. The result is a very large number, 15511210043330985984000000. There's no need for further processing here, just present the result as is.
</think>

The factorial of 25 is 15511210043330985984000000.<|im_end|>
<|im_start|>user
I need to calculate the factorial for another number. Please determine the factorial of 30 using the provided function.<|im_end|>
<|im_start|>assistant
<think>
The user wants to compute the factorial of 30, which is a larger number than the previous request. I'll use the same 'calculate_factorial' function by providing n=30. Let's see what the result is.
</think>

<tool_call>
<function=calculate_factorial>
<parameter=n>
30
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
[{"tool_call_id": "call_Jd6idzuU1bTksV68EsijgP4m", "name": "calculate_factorial", "content": "{'factorial': 265252859812191058636308480000000, 'n': 30}"}]
</tool_response><|im_end|>
<|im_start|>assistant
<think>
The user has requested the factorial of 30. The calculation was performed using the 'calculate_factorial' function provided, and the result is 265252859812191058636308480000000. This is a very large number, typical for factorials of numbers in the 30s.
</think>

The factorial of 30 is 265252859812191058636308480000000.<|im_end|>

====================================================================================================
####labels: (4 supervised turn(s))

----- turn 1 / 4  (len=99) -----
<|im_start|>assistant
<think>
Alright, the user wants me to calculate the factorial of 25. Since there's a 'calculate_factorial' tool available, I can use that. Factorial of 25 is a large number, but the tool should handle it. Let me just call the function with n=25.
</think>

<tool_call>
<function=calculate_factorial>
<parameter=n>
25
</parameter>
</function>
</tool_call><|im_end|>


----- turn 2 / 4  (len=122) -----
<|im_start|>assistant
<think>
The assistant was given a factorial calculation for 25. The function calculate_factorial was called with n=25. The result is a very large number, 15511210043330985984000000. There's no need for further processing here, just present the result as is.
</think>

The factorial of 25 is 15511210043330985984000000.<|im_end|>


----- turn 3 / 4  (len=85) -----
<|im_start|>assistant
<think>
The user wants to compute the factorial of 30, which is a larger number than the previous request. I'll use the same 'calculate_factorial' function by providing n=30. Let's see what the result is.
</think>

<tool_call>
<function=calculate_factorial>
<parameter=n>
30
</parameter>
</function>
</tool_call><|im_end|>


----- turn 4 / 4  (len=136) -----
<|im_start|>assistant
<think>
The user has requested the factorial of 30. The calculation was performed using the 'calculate_factorial' function provided, and the result is 265252859812191058636308480000000. This is a very large number, typical for factorials of numbers in the 30s.
</think>

The factorial of 30 is 265252859812191058636308480000000.<|im_end|>
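
A quick consistency check on the output above: the four supervised turns contain 99 + 122 + 85 + 136 = 442 tokens, so exactly that many labels should differ from -100, and each of them should equal the input id at the same position. A minimal sketch of such a check, with count_supervised as a hypothetical helper name and toy ids in place of the real sequence:

```python
IGNORE_INDEX = -100  # label value that the loss ignores


def count_supervised(input_ids: list[int], labels: list[int]) -> int:
    """Validate a (input_ids, labels) pair and return the number of supervised tokens."""
    if len(input_ids) != len(labels):
        raise ValueError("input_ids and labels must have the same length")
    supervised = 0
    for token_id, label in zip(input_ids, labels):
        if label == IGNORE_INDEX:
            continue  # masked position (system, user, or tool tokens)
        if label != token_id:
            raise ValueError("a supervised label must equal its input id")
        supervised += 1
    return supervised


# Toy example: two masked prompt tokens followed by two supervised tokens.
toy_input_ids = [10, 11, 12, 13]
toy_labels = [IGNORE_INDEX, IGNORE_INDEX, 12, 13]
n_supervised = count_supervised(toy_input_ids, toy_labels)
```

Applied to the labels and input_ids returned by _encode for this sample, the count would be expected to match the sum of the turn lengths printed above.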

cc @hiyouga @jiaqiw09 for suggestions

Pull Request

No response

Labels: enhancement, pending