Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions rllm/agents/swe_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import json
import logging
import re
Expand All @@ -10,8 +11,6 @@
from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
from rllm.agents.system_prompts import SWE_SYSTEM_PROMPT, SWE_SYSTEM_PROMPT_FN_CALL, SWE_USER_PROMPT, SWE_USER_PROMPT_FN_CALL, SWEAGENT_SYSTEM_PROMPT, SWEAGENT_USER_PROMPT

TOKEN_WARNING_THRESHOLD = 28000


def parse_oai_response(response):
thought = response.choices[0].message.content
Expand Down Expand Up @@ -59,7 +58,7 @@ def parse_xml_response(response_text: str) -> tuple[str, SWEAction]:


class SWEAgent(BaseAgent):
def __init__(self, use_fn_calling: bool = False, format_model_response: bool = False, scaffold: str = "r2egym"):
def __init__(self, use_fn_calling: bool = False, format_model_response: bool = False, scaffold: str = "r2egym", **kwargs):
self.use_fn_calling = use_fn_calling
self.format_model_response = format_model_response
self.scaffold = scaffold
Expand All @@ -70,6 +69,9 @@ def __init__(self, use_fn_calling: bool = False, format_model_response: bool = F
self.user_prompt_template = SWE_USER_PROMPT_FN_CALL if use_fn_calling else SWE_USER_PROMPT
if scaffold == "sweagent":
self.user_prompt_template = SWEAGENT_USER_PROMPT
self.token_warning_threshold = 1e9
if kwargs.get("token_warning_threshold") is not None:
self.token_warning_threshold = kwargs["token_warning_threshold"]

self._trajectory = Trajectory()
self.reset()
Expand Down Expand Up @@ -117,7 +119,7 @@ def update_from_env(self, observation, reward, done, info):
observation += "\nYou have reached the maximum number of steps. Please submit your answer NOW."

cur_tokens = info.get("cur_tokens", None)
if cur_tokens is not None and cur_tokens >= TOKEN_WARNING_THRESHOLD:
if cur_tokens is not None and cur_tokens >= self.token_warning_threshold:
observation += "\nYou are running out of tokens. Please submit your answer NOW."

if self._trajectory.steps:
Expand Down Expand Up @@ -162,6 +164,7 @@ def update_from_model(self, response: str, **kwargs):
self.messages.append({"role": "assistant", "content": f"{thought}\n\n{action_str}"})
else:
self.messages.append({"role": "assistant", "content": response})
cur_step.chat_completions = copy.deepcopy(self.chat_completions)
self.step += 1
return Action(action=cur_step.action)

Expand Down
Loading