mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-12 18:01:30 +08:00
fix(memory): prevent context memory loss by improving trim strategy
This commit is contained in:
@@ -8,7 +8,7 @@ import time
|
||||
from typing import List, Dict, Any, Optional, Callable, Tuple
|
||||
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.message_utils import sanitize_claude_messages
|
||||
from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
@@ -191,6 +191,11 @@ class AgentStreamExecutor:
|
||||
]
|
||||
})
|
||||
|
||||
# Trim context ONCE before the agent loop starts, not during tool steps.
|
||||
# This ensures tool_use/tool_result chains created during the current run
|
||||
# are never stripped mid-execution (which would cause LLM loops).
|
||||
self._trim_messages()
|
||||
|
||||
self._emit_event("agent_start")
|
||||
|
||||
final_response = ""
|
||||
@@ -481,14 +486,10 @@ class AgentStreamExecutor:
|
||||
Returns:
|
||||
(response_text, tool_calls)
|
||||
"""
|
||||
# Validate and fix message history first
|
||||
self._validate_and_fix_messages()
|
||||
|
||||
# Trim messages if needed (using agent's context management)
|
||||
self._trim_messages()
|
||||
|
||||
# Re-validate after trimming: trimming may produce new orphaned
|
||||
# tool_result messages when it removes turns at the boundary.
|
||||
# Validate and fix message history (e.g. orphaned tool_result blocks).
|
||||
# Context trimming is done once in run_stream() before the loop starts,
|
||||
# NOT here — trimming mid-execution would strip the current run's
|
||||
# tool_use/tool_result chains and cause LLM loops.
|
||||
self._validate_and_fix_messages()
|
||||
|
||||
# Prepare messages
|
||||
@@ -1165,10 +1166,10 @@ class AgentStreamExecutor:
|
||||
if not turns:
|
||||
return
|
||||
|
||||
# Step 2: 轮次限制 - 超出时裁到 max_turns/2,批量 flush 被裁的轮次
|
||||
# Step 2: 轮次限制 - 超出时移除前一半,保留后一半
|
||||
if len(turns) > self.max_context_turns:
|
||||
keep_count = max(1, self.max_context_turns // 2)
|
||||
removed_count = len(turns) - keep_count
|
||||
removed_count = len(turns) // 2
|
||||
keep_count = len(turns) - removed_count
|
||||
|
||||
# Flush discarded turns to daily memory
|
||||
if self.agent.memory_manager:
|
||||
@@ -1223,9 +1224,47 @@ class AgentStreamExecutor:
|
||||
logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息")
|
||||
return
|
||||
|
||||
# Token limit exceeded - keep the latest half of turns (same strategy as turn limit)
|
||||
keep_count = max(1, len(turns) // 2)
|
||||
removed_count = len(turns) - keep_count
|
||||
# Token limit exceeded — tiered strategy based on turn count:
|
||||
#
|
||||
# Few turns (<5): Compress ALL turns to text-only (strip tool chains,
|
||||
# keep user query + final reply). Never discard turns
|
||||
# — losing even one is too painful when context is thin.
|
||||
#
|
||||
# Many turns (>=5): Directly discard the first half of turns.
|
||||
# With enough turns the oldest ones are less
|
||||
# critical, and keeping the recent half intact
|
||||
# (with full tool chains) is more useful.
|
||||
|
||||
COMPRESS_THRESHOLD = 5
|
||||
|
||||
if len(turns) < COMPRESS_THRESHOLD:
|
||||
# --- Few turns: compress ALL turns to text-only, never discard ---
|
||||
compressed_turns = []
|
||||
for t in turns:
|
||||
compressed = compress_turn_to_text_only(t)
|
||||
if compressed["messages"]:
|
||||
compressed_turns.append(compressed)
|
||||
|
||||
new_messages = []
|
||||
for turn in compressed_turns:
|
||||
new_messages.extend(turn["messages"])
|
||||
|
||||
new_tokens = sum(self._estimate_turn_tokens(t) for t in compressed_turns)
|
||||
old_count = len(self.messages)
|
||||
self.messages = new_messages
|
||||
|
||||
logger.info(
|
||||
f"📦 上下文tokens超限(轮次<{COMPRESS_THRESHOLD}): "
|
||||
f"~{current_tokens + system_tokens} > {max_tokens},"
|
||||
f"压缩全部 {len(turns)} 轮为纯文本 "
|
||||
f"({old_count} -> {len(self.messages)} 条消息,"
|
||||
f"~{current_tokens + system_tokens} -> ~{new_tokens + system_tokens} tokens)"
|
||||
)
|
||||
return
|
||||
|
||||
# --- Many turns (>=5): discard the older half, keep the newer half ---
|
||||
removed_count = len(turns) // 2
|
||||
keep_count = len(turns) - removed_count
|
||||
kept_turns = turns[-keep_count:]
|
||||
kept_tokens = sum(self._estimate_turn_tokens(t) for t in kept_turns)
|
||||
|
||||
@@ -1234,7 +1273,6 @@ class AgentStreamExecutor:
|
||||
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
|
||||
)
|
||||
|
||||
# Flush discarded turns to daily memory
|
||||
if self.agent.memory_manager:
|
||||
discarded_messages = []
|
||||
for turn in turns[:removed_count]:
|
||||
@@ -1245,14 +1283,14 @@ class AgentStreamExecutor:
|
||||
messages=discarded_messages, user_id=user_id,
|
||||
reason="trim", max_messages=0
|
||||
)
|
||||
|
||||
|
||||
new_messages = []
|
||||
for turn in kept_turns:
|
||||
new_messages.extend(turn['messages'])
|
||||
|
||||
|
||||
old_count = len(self.messages)
|
||||
self.messages = new_messages
|
||||
|
||||
|
||||
logger.info(
|
||||
f" 移除了 {removed_count} 轮对话 "
|
||||
f"({old_count} -> {len(self.messages)} 条消息,"
|
||||
|
||||
@@ -177,3 +177,60 @@ def _has_block_type(content: list, block_type: str) -> bool:
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from a message content field (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def compress_turn_to_text_only(turn: Dict) -> Dict:
|
||||
"""
|
||||
Compress a full turn (with tool_use/tool_result chains) into a lightweight
|
||||
text-only turn that keeps only the first user text and the last assistant text.
|
||||
|
||||
This preserves the conversational context (what the user asked and what the
|
||||
agent concluded) while stripping out the bulky intermediate tool interactions.
|
||||
|
||||
Returns a new turn dict with a ``messages`` list; the original is not mutated.
|
||||
"""
|
||||
user_text = ""
|
||||
last_assistant_text = ""
|
||||
|
||||
for msg in turn["messages"]:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result"):
|
||||
continue
|
||||
if not user_text:
|
||||
user_text = _extract_text_from_content(content)
|
||||
|
||||
elif role == "assistant":
|
||||
text = _extract_text_from_content(content)
|
||||
if text:
|
||||
last_assistant_text = text
|
||||
|
||||
compressed_messages = []
|
||||
if user_text:
|
||||
compressed_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if last_assistant_text:
|
||||
compressed_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": last_assistant_text}]
|
||||
})
|
||||
|
||||
return {"messages": compressed_messages}
|
||||
|
||||
Reference in New Issue
Block a user