From c11623596d6ca0693f616e8d710d85e4ea89cab3 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Thu, 12 Mar 2026 15:25:46 +0800 Subject: [PATCH] fix(memory): prevent context memory loss by improving trim strategy --- agent/protocol/agent_stream.py | 76 ++++++++++++++++++++++++--------- agent/protocol/message_utils.py | 57 +++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 19 deletions(-) diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py index df6b0db..49050a6 100644 --- a/agent/protocol/agent_stream.py +++ b/agent/protocol/agent_stream.py @@ -8,7 +8,7 @@ import time from typing import List, Dict, Any, Optional, Callable, Tuple from agent.protocol.models import LLMRequest, LLMModel -from agent.protocol.message_utils import sanitize_claude_messages +from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only from agent.tools.base_tool import BaseTool, ToolResult from common.log import logger @@ -191,6 +191,11 @@ class AgentStreamExecutor: ] }) + # Trim context ONCE before the agent loop starts, not during tool steps. + # This ensures tool_use/tool_result chains created during the current run + # are never stripped mid-execution (which would cause LLM loops). + self._trim_messages() + self._emit_event("agent_start") final_response = "" @@ -481,14 +486,10 @@ class AgentStreamExecutor: Returns: (response_text, tool_calls) """ - # Validate and fix message history first - self._validate_and_fix_messages() - - # Trim messages if needed (using agent's context management) - self._trim_messages() - - # Re-validate after trimming: trimming may produce new orphaned - # tool_result messages when it removes turns at the boundary. + # Validate and fix message history (e.g. orphaned tool_result blocks). + # Context trimming is done once in run_stream() before the loop starts, + # NOT here — trimming mid-execution would strip the current run's + # tool_use/tool_result chains and cause LLM loops. self._validate_and_fix_messages() # Prepare messages @@ -1165,10 +1166,10 @@ class AgentStreamExecutor: if not turns: return - # Step 2: 轮次限制 - 超出时裁到 max_turns/2,批量 flush 被裁的轮次 + # Step 2: 轮次限制 - 超出时移除前一半,保留后一半 if len(turns) > self.max_context_turns: - keep_count = max(1, self.max_context_turns // 2) - removed_count = len(turns) - keep_count + removed_count = len(turns) // 2 + keep_count = len(turns) - removed_count # Flush discarded turns to daily memory if self.agent.memory_manager: @@ -1223,9 +1224,47 @@ class AgentStreamExecutor: logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息") return - # Token limit exceeded - keep the latest half of turns (same strategy as turn limit) - keep_count = max(1, len(turns) // 2) - removed_count = len(turns) - keep_count + # Token limit exceeded — tiered strategy based on turn count: + # + # Few turns (<5): Compress ALL turns to text-only (strip tool chains, + # keep user query + final reply). Never discard turns + # — losing even one is too painful when context is thin. + # + # Many turns (>=5): Directly discard the first half of turns. + # With enough turns the oldest ones are less + # critical, and keeping the recent half intact + # (with full tool chains) is more useful. + + COMPRESS_THRESHOLD = 5 + + if len(turns) < COMPRESS_THRESHOLD: + # --- Few turns: compress ALL turns to text-only, never discard --- + compressed_turns = [] + for t in turns: + compressed = compress_turn_to_text_only(t) + if compressed["messages"]: + compressed_turns.append(compressed) + + new_messages = [] + for turn in compressed_turns: + new_messages.extend(turn["messages"]) + + new_tokens = sum(self._estimate_turn_tokens(t) for t in compressed_turns) + old_count = len(self.messages) + self.messages = new_messages + + logger.info( + f"📦 上下文tokens超限(轮次<{COMPRESS_THRESHOLD}): " + f"~{current_tokens + system_tokens} > {max_tokens}," + f"压缩全部 {len(turns)} 轮为纯文本 " + f"({old_count} -> {len(self.messages)} 条消息," + f"~{current_tokens + system_tokens} -> ~{new_tokens + system_tokens} tokens)" + ) + return + + # --- Many turns (>=5): discard the older half, keep the newer half --- + removed_count = len(turns) // 2 + keep_count = len(turns) - removed_count kept_turns = turns[-keep_count:] kept_tokens = sum(self._estimate_turn_tokens(t) for t in kept_turns) @@ -1234,7 +1273,6 @@ class AgentStreamExecutor: f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)" ) - # Flush discarded turns to daily memory if self.agent.memory_manager: discarded_messages = [] for turn in turns[:removed_count]: @@ -1245,14 +1283,14 @@ class AgentStreamExecutor: messages=discarded_messages, user_id=user_id, reason="trim", max_messages=0 ) - + new_messages = [] for turn in kept_turns: new_messages.extend(turn['messages']) - + old_count = len(self.messages) self.messages = new_messages - + logger.info( f" 移除了 {removed_count} 轮对话 " f"({old_count} -> {len(self.messages)} 条消息," diff --git a/agent/protocol/message_utils.py b/agent/protocol/message_utils.py index a9606c8..3215ed4 100644 --- a/agent/protocol/message_utils.py +++ b/agent/protocol/message_utils.py @@ -177,3 +177,60 @@ def _has_block_type(content: list, block_type: str) -> bool: isinstance(b, dict) and b.get("type") == block_type for b in content ) + + +def _extract_text_from_content(content) -> str: + """Extract plain text from a message content field (str or list of blocks).""" + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts = [ + b.get("text", "") + for b in content + if isinstance(b, dict) and b.get("type") == "text" + ] + return "\n".join(p for p in parts if p).strip() + return "" + + +def compress_turn_to_text_only(turn: Dict) -> Dict: + """ + Compress a full turn (with tool_use/tool_result chains) into a lightweight + text-only turn that keeps only the first user text and the last assistant text. + + This preserves the conversational context (what the user asked and what the + agent concluded) while stripping out the bulky intermediate tool interactions. + + Returns a new turn dict with a ``messages`` list; the original is not mutated. + """ + user_text = "" + last_assistant_text = "" + + for msg in turn["messages"]: + role = msg.get("role") + content = msg.get("content", []) + + if role == "user": + if isinstance(content, list) and _has_block_type(content, "tool_result"): + continue + if not user_text: + user_text = _extract_text_from_content(content) + + elif role == "assistant": + text = _extract_text_from_content(content) + if text: + last_assistant_text = text + + compressed_messages = [] + if user_text: + compressed_messages.append({ + "role": "user", + "content": [{"type": "text", "text": user_text}] + }) + if last_assistant_text: + compressed_messages.append({ + "role": "assistant", + "content": [{"type": "text", "text": last_assistant_text}] + }) + + return {"messages": compressed_messages}