Core Logic of the Selector in Trae-Agent


Analysis of the Trae Agent Selector's Selection Logic

Overview

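The Selector is the core component of Trae Agent's evaluation module that chooses the best solution from multiple candidate patches. This document analyzes the Selector's core logic and implementation in detail.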


1. Overall Architecture

Patch Selection Flow

[Phase 1: Preparation]
  - Read the candidate patch list (candidates.jsonl)
  - Split the candidates into groups
  - Within each group: filter by regression tests, then deduplicate patches
                    │
                    ▼
[Phase 2: Selection]
  Run SelectorAgent:
  - Analyze each candidate patch
  - Validate in the sandbox (optional)
  - Select the best patch
                    │
                    ▼
[Phase 3: Voting (optional)]
  - Run SelectorAgent multiple times
  - Count how often each patch is selected
  - Take the majority vote as the final result

2. Core Components

2.1 Entry Point: selector.py

File: evaluation/patch_selection/selector.py

import argparse
import json


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--instances_path", required=True)    # list of instances
    parser.add_argument("--candidate_path", required=True)   # candidate patch file (JSONL)
    parser.add_argument("--num_candidate", type=int, default=10)
    parser.add_argument("--group_size", type=int, default=10)  # candidates per group
    parser.add_argument("--majority_voting", action="store_true")  # enable majority voting
    # ... remaining arguments (max_retry, max_turn, log_path, output_path,
    # patches_path, statistics_path, max_workers, ...) omitted
    args = parser.parse_args()

    # Load candidate patches: one JSON object per line, keyed by instance_id
    candidate_dic = {}
    with open(args.candidate_path, "r") as file:
        for line in file:
            candidate = json.loads(line.strip())
            candidate_dic[candidate["instance_id"]] = candidate

    # llm_config, instance_list and tools_path are built elsewhere (omitted)

    # Create the evaluator
    evaluation = SelectorEvaluation(
        llm_config,
        args.num_candidate,
        args.max_retry,
        args.max_turn,
        args.log_path,
        args.output_path,
        args.patches_path,
        instance_list,
        candidate_dic,
        tools_path,
        args.statistics_path,
        args.group_size,
        majority_voting=args.majority_voting,
    )

    # Run all instances
    evaluation.run_all(max_workers=args.max_workers)
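The loading loop in main() is easy to check in isolation. Below is a minimal, self-contained sketch of the same JSONL indexing. The function name load_candidates and the blank-line skip are additions for illustration, not part of selector.py (main() calls json.loads on every line, so a blank line there would raise).

```python
import io
import json
from typing import IO


def load_candidates(fp: IO[str]) -> dict[str, dict]:
    """Index candidate records by instance_id, one JSON object per line.

    Blank lines are skipped (an assumption added for robustness). If an
    instance_id appears twice, the later record wins, as in main().
    """
    candidate_dic: dict[str, dict] = {}
    for line in fp:
        line = line.strip()
        if not line:
            continue  # tolerate empty lines instead of failing in json.loads
        candidate = json.loads(line)
        candidate_dic[candidate["instance_id"]] = candidate
    return candidate_dic


sample = io.StringIO(
    '{"instance_id": "a", "patches": ["p1"]}\n'
    "\n"
    '{"instance_id": "b", "patches": ["p2", "p3"]}\n'
)
candidates = load_candidates(sample)
print(sorted(candidates))  # ['a', 'b']
```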

2.2 The SelectorEvaluation Class

File: evaluation/patch_selection/trae_selector/selector_evaluation.py

2.2.1 Grouping
def run_instance(
    instance,
    candidate_log,  # candidate patch record for this instance
    num_candidate,
    group_size,
    ...
):
    """Split the candidate patches into groups and process each group."""
    # Split N candidates into ceil(N / group_size) groups
    groups = []
    for i in range(0, num_candidate, group_size):
        this_group = {
            "instance_id": candidate_log["instance_id"],
            "issue": candidate_log["issue"],
            "patches": candidate_log["patches"][i:i + group_size],
            "regressions": candidate_log["regressions"][i:i + group_size],
            "success_id": candidate_log["success_id"][i:i + group_size],
        }
        groups.append(this_group)

    # Run selection independently for each group
    for group_id, group in enumerate(groups):
        run_instance_by_group(instance=instance, candidate_log=group, ...)
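A runnable version of the grouping step (split_into_groups is a hypothetical helper name) makes two properties visible: the last group may be smaller than group_size, and the three per-candidate lists are sliced with the same bounds, so position k inside a group still refers to one candidate.

```python
def split_into_groups(candidate_log: dict, num_candidate: int, group_size: int) -> list[dict]:
    """Slice one instance's candidates into consecutive groups of group_size."""
    groups = []
    for start in range(0, num_candidate, group_size):
        end = start + group_size
        groups.append({
            "instance_id": candidate_log["instance_id"],
            "issue": candidate_log["issue"],
            # Same bounds for all three lists keeps them aligned by position.
            "patches": candidate_log["patches"][start:end],
            "regressions": candidate_log["regressions"][start:end],
            "success_id": candidate_log["success_id"][start:end],
        })
    return groups


log = {
    "instance_id": "demo",
    "issue": "...",
    "patches": [f"patch-{i}" for i in range(25)],
    "regressions": [[] for _ in range(25)],
    "success_id": [i % 2 for i in range(25)],
}
groups = split_into_groups(log, num_candidate=25, group_size=10)
print([len(g["patches"]) for g in groups])  # [10, 10, 5]
```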

Grouping example

50 candidate patches, group_size=10:
Group 0: Patch 0-9
Group 1: Patch 10-19
Group 2: Patch 20-29
Group 3: Patch 30-39
Group 4: Patch 40-49
  ↓
Each group selects 1 → 5 candidates in total
  ↓
Final selection: 1 patch
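The example ends with a second round over the group winners. That round is not part of the code excerpts in this section, so the sketch below is an assumption about how the two rounds fit together. select is a hypothetical stand-in for one SelectorAgent run that returns the index of its chosen patch, and the sketch assumes the winners fit into a single final group.

```python
from typing import Callable, Sequence


def tournament_select(
    patches: Sequence[str],
    group_size: int,
    select: Callable[[Sequence[str]], int],
) -> str:
    """Two rounds: pick one patch per group, then pick among the winners."""
    winners = []
    for start in range(0, len(patches), group_size):
        group = patches[start:start + group_size]
        winners.append(group[select(group)])
    # Round 2: a single selection over all group winners (assumed to fit in one group).
    return winners[select(winners)]


patches = [f"p{i:02d}" for i in range(50)]
calls = []


def pick_last(group: Sequence[str]) -> int:
    calls.append(len(group))  # record the size of each group the selector sees
    return len(group) - 1     # toy rule: always choose the last patch


print(tournament_select(patches, 10, pick_last))  # p49
```

With 50 patches and group_size=10, the selector is called five times on groups of 10 and once on the 5 winners.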
2.2.2 Per-Group Processing Flow
import os
from collections import Counter


def run_instance_by_group(instance, candidate_log, ...):
    """Process the candidate patches of a single group."""

    # 1. Skip groups that have already been processed
    file_path = statistics_path + f"/group_{group_id}/{instance['instance_id']}.json"
    if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
        print("Already processed. Skipping...")
        return

    # 2. Edge case: if the ground truth marks every candidate as failed (or every
    #    candidate as successful), no choice can change the outcome
    all_failed = all(success_id == 0 for success_id in candidate_log["success_id"])
    all_success = all(success_id == 1 for success_id in candidate_log["success_id"])
    if all_failed or all_success:
        # Save the result directly and skip selection
        save_patches(...)
        save_selection_success(...)
        return

    # 3. Build the candidate list
    candidate_list = []
    for idx in range(len(candidate_log["patches"])):
        candidate_list.append(CandidatePatch(
            id=idx,
            patch=candidate_log["patches"][idx],
            cleaned_patch=clean_patch(candidate_log["patches"][idx]),
            is_success_regression=len(candidate_log["regressions"][idx]) == 0,
            is_success_patch=candidate_log["success_id"][idx],
        ))

    # 4. Regression filter: keep candidates with no failing regression tests,
    #    but fall back to the full list if none of them pass
    candidate_list_regression = [
        c for c in candidate_list if c.is_success_regression
    ]
    if len(candidate_list_regression) > 0:
        candidate_list = candidate_list_regression

    # 5. Deduplicate: keep the first candidate for each distinct cleaned patch
    candidate_list_deduplication = []
    cleaned_candidate_set = set()
    for candidate in candidate_list:
        if candidate.cleaned_patch not in cleaned_candidate_set:
            cleaned_candidate_set.add(candidate.cleaned_patch)
            candidate_list_deduplication.append(candidate)
    candidate_list = candidate_list_deduplication

    # 6. Create the sandbox
    sandbox = Sandbox(namespace, image_name, tag, instance, tools_path)
    sandbox.start_container()
    project_path = sandbox.get_project_path()

    # 7. Run selection
    if majority_voting:
        # Majority-voting mode: up to num_candidate SelectorAgent runs
        final_id_list, final_patch_list = [], []
        for idx in range(num_candidate):
            select_agent = SelectorAgent(...)
            final_id, final_patch = select_agent.run()
            final_id_list.append(final_id)
            final_patch_list.append(final_patch)

            # Early stop once one id holds a strict majority of num_candidate runs
            if max(Counter(final_id_list).values()) > num_candidate / 2:
                break

        # Tally the votes; on a tie, the id voted for first wins
        counter = Counter(final_id_list)
        max_count = max(counter.values())
        most_common_ids = [elem for elem, count in counter.items() if count == max_count]
        final_id = most_common_ids[0]
        final_patch = final_patch_list[final_id_list.index(final_id)]
    else:
        # Single-run mode
        select_agent = SelectorAgent(...)
        final_id, final_patch = select_agent.run()

    # 8. Save the results
    save_patches(instance_id=instance["instance_id"], patches_path=patches_path,
                 patches=final_patch, group_id=group_id)
    save_selection_success(instance_id=instance["instance_id"], ...)

    sandbox.stop_container()
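The filtering, deduplication and voting rules in run_instance_by_group can be exercised without a sandbox or an LLM. In this sketch, Candidate is a trimmed, hypothetical stand-in for CandidatePatch, and run_once stands in for one SelectorAgent.run() call that returns a candidate id. The rules themselves follow the excerpt above.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Callable


@dataclass
class Candidate:
    id: int
    cleaned_patch: str
    is_success_regression: bool


def filter_and_dedup(candidates: list[Candidate]) -> list[Candidate]:
    # Prefer regression-passing candidates; keep everything if none pass.
    passing = [c for c in candidates if c.is_success_regression]
    if passing:
        candidates = passing
    seen: set[str] = set()
    unique = []
    for c in candidates:
        if c.cleaned_patch not in seen:  # first occurrence of each cleaned patch wins
            seen.add(c.cleaned_patch)
            unique.append(c)
    return unique


def majority_vote(run_once: Callable[[], int], num_runs: int) -> int:
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")
    votes: list[int] = []
    for _ in range(num_runs):
        votes.append(run_once())
        # Stop early once one id holds a strict majority of num_runs.
        if max(Counter(votes).values()) > num_runs / 2:
            break
    counter = Counter(votes)
    top = max(counter.values())
    # Counter preserves insertion order, so ties go to the id voted for first.
    return next(i for i, n in counter.items() if n == top)


cands = [
    Candidate(0, "a", False),  # fails regression tests, dropped
    Candidate(1, "b", True),
    Candidate(2, "b", True),   # duplicate of candidate 1 after cleaning
    Candidate(3, "c", True),
]
kept = filter_and_dedup(cands)
print([c.id for c in kept])  # [1, 3]

votes = iter([3, 1, 3, 3, 1])
print(majority_vote(lambda: next(votes), num_runs=5))  # 3
```

With five runs, the early stop fires at the fourth vote, when id 3 reaches three votes (more than 5 / 2), so the fifth run never happens.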

2.3 The SelectorAgent Class

File: evaluation/patch_selection/trae_selector/selector_agent.py

class SelectorAgent:
    def __init__(
        self,
        llm_config: ModelConfig,
        sandbox: Sandbox,
        project_path: str,
        issue_description: str,
        trajectory_file_name: str,
        candidate_list: list[CandidatePatch],
        max_turn: int = 50,
    ):
        self.llm_config = llm_config
        self.max_turn = max_turn
        self.candidate_list = candidate_list  # used by run() to map Patch-k back to a candidate
        self.sandbox = sandbox
        self.sandbox_session = self.sandbox.get_session()
        # Trajectory recorder setup (using trajectory_file_name) omitted from this excerpt

        # Reset the code to the baseline
        self.sandbox_session.execute("git reset --hard HEAD")

        # Initialize the tools (only bash and str_replace_based_edit_tool)
        self.tools = [
            tools_registry[tool_name](model_provider=llm_config.model_provider.provider)
            for tool_name in ["bash", "str_replace_based_edit_tool"]
        ]

        # Initialize the LLM client
        self.llm_client = LLMClient(llm_config)

        # Build the initial messages
        self.initial_messages = [
            LLMMessage(role="system", content=build_system_prompt(len(candidate_list)))
        ]

        # Append the user prompt
        user_prompt = f"""
[Codebase path]: {project_path}
[Github issue description]:

{issue_description}

[Candidate Patches]:
"""
        for idx, candidate in enumerate(candidate_list):
            user_prompt += f"\nPatch-{idx + 1}:\n```\n{candidate.patch}\n```"

        self.initial_messages.append(LLMMessage(role="user", content=user_prompt))
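To see what the LLM actually receives, the user prompt can be rendered on its own. This sketch reproduces the f-string from __init__ as a standalone function; the name build_user_prompt is hypothetical, and the triple backtick is built with "`" * 3 only so that this listing does not contain a literal fence.

```python
def build_user_prompt(project_path: str, issue_description: str, patches: list[str]) -> str:
    """Render the selector's user prompt for a list of raw patch strings."""
    fence = "`" * 3
    prompt = (
        f"\n[Codebase path]: {project_path}\n"
        "[Github issue description]:\n\n"
        f"{issue_description}\n\n"
        "[Candidate Patches]:\n"
    )
    for idx, patch in enumerate(patches):
        # Labels are 1-based; run() later maps "Patch-k" back to index k - 1.
        prompt += f"\nPatch-{idx + 1}:\n{fence}\n{patch}\n{fence}"
    return prompt


print(build_user_prompt("/repo", "Fix the crash on empty input.", ["diff A", "diff B"]))
```

The 1-based labels are the reason the selection parser in run() subtracts 1 from the number it extracts.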

2.4 System Prompt Design

def build_system_prompt(candidate_length: int) -> str:
    return f"""\
# ROLE: Act as an expert code evaluator. 
Given a codebase, an github issue and **{candidate_length} candidate patches** 
proposed by your colleagues, your responsibility is to **select the correct one** 
to solve the issue.

# WORK PROCESS:
1. Understand the Issue and Codebase
2. Analyze the Candidate Patches
3. Validate Functionality (Optional but Recommended)
4. Select the Best Patch

# FINAL REPORT:
### Status: succeed
### Result: Patch-x
### Analysis: [Explain why Patch-x is correct.]

# IMPORTANT TIPS:
1. Never avoid making a selection.
2. Do not propose new patches.
3. There must be at least one correct patch.
"""

3. Execution Flow

3.1 The SelectorAgent.run() Method

import re


def run(self):
    """Main loop of the Selector Agent."""
    turn = 0
    # Fallback if the LLM never reports a selection: the first candidate
    final_id, final_patch = self.candidate_list[0].id, self.candidate_list[0].patch
    messages = self.initial_messages

    while turn < self.max_turn:
        turn += 1

        # 1. Call the LLM
        llm_response = self.llm_client.chat(messages, self.llm_config, self.tools)

        # 2. Record the trajectory
        self.trajectory_recorder.record_llm_interaction(...)

        # 3. Check whether the LLM has reported its selection
        match = re.search(
            r"Status:\s*(success|succeed).*\n.*Result:\s*Patch-(\d+)",
            llm_response.content,
        )

        if match:
            # Extract the selected patch ("Patch-k" in the prompt is index k - 1)
            selected_idx = int(match.group(2)) - 1
            # Check both bounds: "Patch-0" would otherwise give -1 and
            # silently select the last candidate
            if 0 <= selected_idx < len(self.candidate_list):
                final_id = self.candidate_list[selected_idx].id
                final_patch = self.candidate_list[selected_idx].patch
            break

        # 4. Execute the tool calls
        messages += parse_tool_response(
            llm_response, self.sandbox_session
        )

    # Clean up
    self.trajectory_recorder.finalize_recording(True, final_patch)
    self.sandbox_session.execute("git reset --hard HEAD")
    self.sandbox_session.close()

    return final_id, final_patch
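The completion check in run() is a single regular expression, so its behavior can be pinned down with a few inputs. This sketch wraps the same pattern in a hypothetical helper, parse_selection, that also applies the lower bound (index >= 0). It returns a 0-based index, or None when no valid selection is found.

```python
import re
from typing import Optional

# Same pattern as in run(): "Status: succeed" (or "success"), then "Result: Patch-k" on the next line.
SELECTION_RE = re.compile(r"Status:\s*(success|succeed).*\n.*Result:\s*Patch-(\d+)")


def parse_selection(text: str, num_candidates: int) -> Optional[int]:
    """Map a final report to a 0-based candidate index, or None."""
    match = SELECTION_RE.search(text)
    if not match:
        return None
    idx = int(match.group(2)) - 1  # prompt labels are 1-based
    return idx if 0 <= idx < num_candidates else None


report = "### Status: succeed\n### Result: Patch-2\n### Analysis: handles the empty case."
print(parse_selection(report, 3))  # 1
```

Because `.` does not match a newline, Result must sit on the line directly after Status. A report with a blank line between them is treated as no selection, and the loop keeps going.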

3.2 Parsing Tool Responses

import shlex


def parse_tool_response(answer: LLMResponse, sandbox_session):
    """Parse the LLM's tool calls and execute them in the sandbox."""
    result = []

    # tool_calls may be None when the model replies with text only
    for tool_call in answer.tool_calls or []:
        tool_call_id = tool_call.call_id
        tool_name = tool_call.name

        # 1. Build the command to execute
        if tool_name == "str_replace_based_edit_tool":
            cmd = "cd /home/swe-bench/tools/ && /home/swe-bench/py312/bin/python3 execute_str_replace_editor.py"
        elif tool_name == "bash":
            cmd = "cd /home/swe-bench/tools/ && /home/swe-bench/py312/bin/python3 execute_bash.py"
        else:
            # Unknown tool
            result.append(LLMMessage(
                role="user",
                content="Tool not available...",
                tool_result=ToolResult(success=False, ...)
            ))
            continue

        # 2. Append the arguments, shell-quoted
        for key, value in tool_call.arguments.items():
            cmd += f" --{key} {shlex.quote(str(value))}"

        # 3. Execute in the sandbox: stdout and stderr go to log.out, which is then read back
        cmd += " > /home/swe-bench/tools/log.out 2>&1"
        sandbox_session.execute(cmd)
        sandbox_res = sandbox_session.execute("cat /home/swe-bench/tools/log.out")

        # 4. Parse the execution result
        status = "Tool Call Status: 0"  # assume success unless the tool reports -1
        if "Tool Call Status: -1" in sandbox_res:
            status = "Tool Call Status: -1"

        result.append(LLMMessage(
            role="user",
            content=sandbox_res,
            tool_result=ToolResult(
                success=status == "Tool Call Status: 0",
                ...
            )
        ))

    return result
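Every argument value is passed through shlex.quote before it is appended to the shell command, which is what keeps arbitrary tool arguments from being re-parsed by the shell. The sketch below isolates command construction and the status check. build_tool_command and tool_succeeded are hypothetical names; the paths and script names are copied from the excerpt. It lets you verify that a value containing quotes, && or $ survives as a single argument.

```python
import shlex
from typing import Optional

TOOLS_DIR = "/home/swe-bench/tools"         # paths copied from the excerpt above
PYTHON = "/home/swe-bench/py312/bin/python3"
SCRIPTS = {
    "str_replace_based_edit_tool": "execute_str_replace_editor.py",
    "bash": "execute_bash.py",
}


def build_tool_command(tool_name: str, arguments: dict) -> Optional[str]:
    """Return the shell command for a tool call, or None for an unknown tool."""
    script = SCRIPTS.get(tool_name)
    if script is None:
        return None  # the caller reports "Tool not available"
    cmd = f"cd {TOOLS_DIR}/ && {PYTHON} {script}"
    for key, value in arguments.items():
        # Quoting stops spaces, quotes, && and $ in values from being interpreted by the shell.
        cmd += f" --{key} {shlex.quote(str(value))}"
    return cmd + f" > {TOOLS_DIR}/log.out 2>&1"


def tool_succeeded(output: str) -> bool:
    """Mirror the excerpt: success unless the tool printed status -1."""
    return "Tool Call Status: -1" not in output


cmd = build_tool_command("bash", {"command": "grep -n 'def run' selector_agent.py"})
print(cmd)
```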

4. Key Data Structures

4.1 CandidatePatch

class CandidatePatch:
    def __init__(
        self,
        id,                       # patch ID (index within the group)
        patch,                    # raw patch
        cleaned_patch,            # cleaned patch (used for deduplication)
        is_success_regression,    # whether it passes the regression tests
        is_success_patch,         # whether it is correct (ground truth)
    ):
        self.id = id
        self.patch = patch
        self.cleaned_patch = cleaned_patch
        self.is_success_regression = is_success_regression
        self.is_success_patch = is_success_patch
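Since CandidatePatch only stores five fields, a dataclass can express it more compactly. This is a sketch, not the project's definition. build_candidates is a hypothetical helper that maps one input record (in the format shown in 4.2) to candidates, and clean stands in for clean_patch, whose exact normalization is not shown in this document.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class CandidatePatch:
    id: int                      # index within the group's candidate list
    patch: str                   # raw diff shown to the LLM
    cleaned_patch: str           # normalized diff, used only for deduplication
    is_success_regression: bool  # True when the regression-failure list is empty
    is_success_patch: int        # ground-truth label: 1 = correct, 0 = incorrect


def build_candidates(record: dict, clean: Callable[[str], str]) -> list[CandidatePatch]:
    """Turn one input record into CandidatePatch objects, by position."""
    return [
        CandidatePatch(
            id=idx,
            patch=patch,
            cleaned_patch=clean(patch),
            is_success_regression=len(record["regressions"][idx]) == 0,
            is_success_patch=record["success_id"][idx],
        )
        for idx, patch in enumerate(record["patches"])
    ]


record = {
    "instance_id": "astropy__astropy-14369",
    "issue": "Issue description...",
    "patches": ["patch diff 1", "patch diff 2", "patch diff N"],
    "success_id": [1, 0, 1],
    "regressions": [[], ["test_module.py::test_func"], []],
}
cands = build_candidates(record, clean=str.strip)  # str.strip is a placeholder cleaner
print([c.is_success_regression for c in cands])  # [True, False, True]
```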

4.2 Input Data Format

{
    "instance_id": "astropy__astropy-14369",
    "issue": "Issue description...",
    "patches": [
        "patch diff 1",
        "patch diff 2",
        "patch diff N"
    ],
    "success_id": [1, 0, 1],
    "regressions": [
        [],
        ["test_module.py::test_func"],
        []
    ]
}
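Fields:
- instance_id: instance ID
- issue: issue description
- patches: list of candidate patches
- success_id: whether each patch is correct (1 = correct, 0 = incorrect)
- regressions: for each patch, the list of regression tests that failed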
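The per-group code slices patches, success_id and regressions with the same bounds and indexes them by position, so a record whose lists differ in length would silently misalign labels. A hypothetical pre-flight check such as validate_record below could catch that before selection starts.

```python
def validate_record(record: dict) -> list[str]:
    """Return a list of problems with one input record (empty if it looks valid)."""
    required = ("instance_id", "issue", "patches", "success_id", "regressions")
    missing = [k for k in required if k not in record]
    if missing:
        return [f"missing keys: {missing}"]
    errors = []
    n = len(record["patches"])
    for key in ("success_id", "regressions"):
        if len(record[key]) != n:
            errors.append(f"{key} has {len(record[key])} entries, expected {n}")
    if any(s not in (0, 1) for s in record["success_id"]):
        errors.append("success_id values must be 0 or 1")
    if any(not isinstance(r, list) for r in record["regressions"]):
        errors.append("each regressions entry must be a list of test ids")
    return errors


sample = {
    "instance_id": "astropy__astropy-14369",
    "issue": "Issue description...",
    "patches": ["patch diff 1", "patch diff 2", "patch diff N"],
    "success_id": [1, 0, 1],
    "regressions": [[], ["test_module.py::test_func"], []],
}
print(validate_record(sample))  # []
```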

5. Output Structure

results/
├── log/
│   └── group_0/
│       └── instance_id_voting_0_trail_1.json   # LLM interaction log
├── output/
│   └── group_0/
│       └── instance_id.log                      # standard output
├── patch/
│   └── group_0/
│       └── instance_id_1.patch                  # selected patch
└── statistics/
    └── group_0/
        └── instance_id.json                     # statistics

6. Design Features

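Features:
- Grouping: keeps each prompt's context short, improving selection accuracy
- Regression filtering: prefers patches that pass the regression tests (all candidates are kept if none pass)
- Deduplication: avoids analyzing patches that are identical after cleaning more than once
- Majority voting: makes the selection more stable
- Sandbox validation: validates patches in an isolated environment
- Trajectory recording: records the full selection process for later analysis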

7. Key Files

Files:
- evaluation/patch_selection/selector.py: main entry point
- evaluation/patch_selection/trae_selector/selector_evaluation.py: evaluation orchestration
- evaluation/patch_selection/trae_selector/selector_agent.py: agent implementation
- evaluation/patch_selection/trae_selector/sandbox.py: sandbox environment

Last updated: 2026-03-16