fix(memory): prevent context memory loss by improving trim strategy

This commit is contained in:
zhayujie
2026-03-12 15:25:46 +08:00
parent e791a77f77
commit c11623596d
2 changed files with 114 additions and 19 deletions

View File

@@ -8,7 +8,7 @@ import time
from typing import List, Dict, Any, Optional, Callable, Tuple
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.message_utils import sanitize_claude_messages
from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
@@ -191,6 +191,11 @@ class AgentStreamExecutor:
]
})
# Trim context ONCE before the agent loop starts, not during tool steps.
# This ensures tool_use/tool_result chains created during the current run
# are never stripped mid-execution (which would cause LLM loops).
self._trim_messages()
self._emit_event("agent_start")
final_response = ""
@@ -481,14 +486,10 @@ class AgentStreamExecutor:
Returns:
(response_text, tool_calls)
"""
# Validate and fix message history first
self._validate_and_fix_messages()
# Trim messages if needed (using agent's context management)
self._trim_messages()
# Re-validate after trimming: trimming may produce new orphaned
# tool_result messages when it removes turns at the boundary.
# Validate and fix message history (e.g. orphaned tool_result blocks).
# Context trimming is done once in run_stream() before the loop starts,
# NOT here — trimming mid-execution would strip the current run's
# tool_use/tool_result chains and cause LLM loops.
self._validate_and_fix_messages()
# Prepare messages
@@ -1165,10 +1166,10 @@ class AgentStreamExecutor:
if not turns:
return
# Step 2: 轮次限制 - 超出时裁到 max_turns/2批量 flush 被裁的轮次
# Step 2: 轮次限制 - 超出时移除前一半,保留后一半
if len(turns) > self.max_context_turns:
keep_count = max(1, self.max_context_turns // 2)
removed_count = len(turns) - keep_count
removed_count = len(turns) // 2
keep_count = len(turns) - removed_count
# Flush discarded turns to daily memory
if self.agent.memory_manager:
@@ -1223,9 +1224,47 @@ class AgentStreamExecutor:
logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息")
return
# Token limit exceeded - keep the latest half of turns (same strategy as turn limit)
keep_count = max(1, len(turns) // 2)
removed_count = len(turns) - keep_count
# Token limit exceeded — tiered strategy based on turn count:
#
# Few turns (<5): Compress ALL turns to text-only (strip tool chains,
# keep user query + final reply). Never discard turns
# — losing even one is too painful when context is thin.
#
# Many turns (>=5): Directly discard the first half of turns.
# With enough turns the oldest ones are less
# critical, and keeping the recent half intact
# (with full tool chains) is more useful.
COMPRESS_THRESHOLD = 5
if len(turns) < COMPRESS_THRESHOLD:
# --- Few turns: compress ALL turns to text-only, never discard ---
compressed_turns = []
for t in turns:
compressed = compress_turn_to_text_only(t)
if compressed["messages"]:
compressed_turns.append(compressed)
new_messages = []
for turn in compressed_turns:
new_messages.extend(turn["messages"])
new_tokens = sum(self._estimate_turn_tokens(t) for t in compressed_turns)
old_count = len(self.messages)
self.messages = new_messages
logger.info(
f"📦 上下文tokens超限(轮次<{COMPRESS_THRESHOLD}): "
f"~{current_tokens + system_tokens} > {max_tokens}"
f"压缩全部 {len(turns)} 轮为纯文本 "
f"({old_count} -> {len(self.messages)} 条消息,"
f"~{current_tokens + system_tokens} -> ~{new_tokens + system_tokens} tokens)"
)
return
# --- Many turns (>=5): discard the older half, keep the newer half ---
removed_count = len(turns) // 2
keep_count = len(turns) - removed_count
kept_turns = turns[-keep_count:]
kept_tokens = sum(self._estimate_turn_tokens(t) for t in kept_turns)
@@ -1234,7 +1273,6 @@ class AgentStreamExecutor:
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
)
# Flush discarded turns to daily memory
if self.agent.memory_manager:
discarded_messages = []
for turn in turns[:removed_count]:
@@ -1245,14 +1283,14 @@ class AgentStreamExecutor:
messages=discarded_messages, user_id=user_id,
reason="trim", max_messages=0
)
new_messages = []
for turn in kept_turns:
new_messages.extend(turn['messages'])
old_count = len(self.messages)
self.messages = new_messages
logger.info(
f" 移除了 {removed_count} 轮对话 "
f"({old_count} -> {len(self.messages)} 条消息,"

View File

@@ -177,3 +177,60 @@ def _has_block_type(content: list, block_type: str) -> bool:
isinstance(b, dict) and b.get("type") == block_type
for b in content
)
def _extract_text_from_content(content) -> str:
"""Extract plain text from a message content field (str or list of blocks)."""
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "\n".join(p for p in parts if p).strip()
return ""
def compress_turn_to_text_only(turn: Dict) -> Dict:
    """
    Compress a full turn (with tool_use/tool_result chains) into a lightweight
    text-only turn that keeps only the first user text and the last assistant
    text.

    This preserves the conversational context (what the user asked and what
    the agent concluded) while stripping out the bulky intermediate tool
    interactions.

    Returns a new turn dict with a ``messages`` list; the original ``turn``
    is not mutated.
    """
    first_user_text = ""
    final_assistant_text = ""

    for message in turn["messages"]:
        role = message.get("role")
        content = message.get("content", [])

        if role == "assistant":
            # Keep overwriting so the LAST non-empty assistant text wins.
            extracted = _extract_text_from_content(content)
            if extracted:
                final_assistant_text = extracted
            continue

        if role != "user":
            continue

        # Tool results arrive as user-role messages; they are exactly the
        # bulk this compression is meant to drop.
        if isinstance(content, list) and _has_block_type(content, "tool_result"):
            continue
        # Only the FIRST real user text is retained.
        if not first_user_text:
            first_user_text = _extract_text_from_content(content)

    compressed = []
    for role, text in (("user", first_user_text),
                       ("assistant", final_assistant_text)):
        if text:
            compressed.append({
                "role": role,
                "content": [{"type": "text", "text": text}],
            })
    return {"messages": compressed}