mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-12 18:01:30 +08:00
feat: add llm retry
This commit is contained in:
@@ -354,7 +354,7 @@ def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"",
|
||||
"**路径使用规则** (非常重要):",
|
||||
"",
|
||||
"- **工作空间内的文件**: 使用相对路径(如 `SOUL.md`、`memory/daily.md`)",
|
||||
"- **工作空间内的文件**: 可以使用相对路径(如 `SOUL.md`、`MEMORY.md`)",
|
||||
"- **工作空间外的文件**: 必须使用绝对路径(如 `~/project/code.py`、`/etc/config`)",
|
||||
"- **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||
"",
|
||||
|
||||
@@ -102,7 +102,7 @@ class AgentStreamExecutor:
|
||||
try:
|
||||
while turn < self.max_turns:
|
||||
turn += 1
|
||||
logger.info(f"\n{'='*50} 第 {turn} 轮 {'='*50}")
|
||||
logger.info(f"\n🔄 第 {turn} 轮")
|
||||
self._emit_event("turn_start", {"turn": turn})
|
||||
|
||||
# Check if memory flush is needed (before calling LLM)
|
||||
@@ -156,9 +156,15 @@ class AgentStreamExecutor:
|
||||
})
|
||||
break
|
||||
|
||||
# Log tool calls in compact format
|
||||
tool_names = [tc['name'] for tc in tool_calls]
|
||||
logger.info(f"🔧 调用工具: {', '.join(tool_names)}")
|
||||
# Log tool calls with arguments
|
||||
tool_calls_str = []
|
||||
for tc in tool_calls:
|
||||
args_str = ', '.join([f"{k}={v}" for k, v in tc['arguments'].items()])
|
||||
if args_str:
|
||||
tool_calls_str.append(f"{tc['name']}({args_str})")
|
||||
else:
|
||||
tool_calls_str.append(tc['name'])
|
||||
logger.info(f"🔧 {', '.join(tool_calls_str)}")
|
||||
|
||||
# Execute tools
|
||||
tool_results = []
|
||||
@@ -179,13 +185,33 @@ class AgentStreamExecutor:
|
||||
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
|
||||
|
||||
# Build tool result block (Claude format)
|
||||
# Content should be a string representation of the result
|
||||
result_content = json.dumps(result, ensure_ascii=False) if not isinstance(result, str) else result
|
||||
tool_result_blocks.append({
|
||||
# Format content in a way that's easy for LLM to understand
|
||||
is_error = result.get("status") == "error"
|
||||
|
||||
if is_error:
|
||||
# For errors, provide clear error message
|
||||
result_content = f"Error: {result.get('result', 'Unknown error')}"
|
||||
elif isinstance(result.get('result'), dict):
|
||||
# For dict results, use JSON format
|
||||
result_content = json.dumps(result.get('result'), ensure_ascii=False)
|
||||
elif isinstance(result.get('result'), str):
|
||||
# For string results, use directly
|
||||
result_content = result.get('result')
|
||||
else:
|
||||
# Fallback to full JSON
|
||||
result_content = json.dumps(result, ensure_ascii=False)
|
||||
|
||||
tool_result_block = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call["id"],
|
||||
"content": result_content
|
||||
})
|
||||
}
|
||||
|
||||
# Add is_error field for Claude API (helps model understand failures)
|
||||
if is_error:
|
||||
tool_result_block["is_error"] = True
|
||||
|
||||
tool_result_blocks.append(tool_result_block)
|
||||
|
||||
# Add tool results to message history as user message (Claude format)
|
||||
self.messages.append({
|
||||
@@ -201,6 +227,11 @@ class AgentStreamExecutor:
|
||||
|
||||
if turn >= self.max_turns:
|
||||
logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
|
||||
if not final_response:
|
||||
final_response = (
|
||||
"抱歉,我在处理你的请求时遇到了一些困难,尝试了多次仍未能完成。"
|
||||
"请尝试简化你的问题,或换一种方式描述。"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Agent执行错误: {e}")
|
||||
@@ -208,14 +239,19 @@ class AgentStreamExecutor:
|
||||
raise
|
||||
|
||||
finally:
|
||||
logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
|
||||
logger.info(f"🏁 完成({turn}轮)\n")
|
||||
self._emit_event("agent_end", {"final_response": final_response})
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
|
||||
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
|
||||
"""
|
||||
Call LLM with streaming
|
||||
Call LLM with streaming and automatic retry on errors
|
||||
|
||||
Args:
|
||||
retry_on_empty: Whether to retry once if empty response is received
|
||||
retry_count: Current retry attempt (internal use)
|
||||
max_retries: Maximum number of retries for API errors
|
||||
|
||||
Returns:
|
||||
(response_text, tool_calls)
|
||||
@@ -309,8 +345,29 @@ class AgentStreamExecutor:
|
||||
tool_calls_buffer[index]["arguments"] += func["arguments"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM call error: {e}")
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
# Check if error is retryable (timeout, connection, rate limit, etc.)
|
||||
is_retryable = any(keyword in error_str for keyword in [
|
||||
'timeout', 'timed out', 'connection', 'network',
|
||||
'rate limit', 'overloaded', 'unavailable', '429', '500', '502', '503', '504'
|
||||
])
|
||||
|
||||
if is_retryable and retry_count < max_retries:
|
||||
wait_time = (retry_count + 1) * 2 # Exponential backoff: 2s, 4s, 6s
|
||||
logger.warning(f"⚠️ LLM API error (attempt {retry_count + 1}/{max_retries}): {e}")
|
||||
logger.info(f"🔄 Retrying in {wait_time}s...")
|
||||
time.sleep(wait_time)
|
||||
return self._call_llm_stream(
|
||||
retry_on_empty=retry_on_empty,
|
||||
retry_count=retry_count + 1,
|
||||
max_retries=max_retries
|
||||
)
|
||||
else:
|
||||
if retry_count >= max_retries:
|
||||
logger.error(f"❌ LLM API error after {max_retries} retries: {e}")
|
||||
else:
|
||||
logger.error(f"❌ LLM call error (non-retryable): {e}")
|
||||
raise
|
||||
|
||||
# Parse tool calls
|
||||
tool_calls = []
|
||||
@@ -328,6 +385,21 @@ class AgentStreamExecutor:
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
# Check for empty response and retry once if enabled
|
||||
if retry_on_empty and not full_content and not tool_calls:
|
||||
logger.warning(f"⚠️ LLM returned empty response, retrying once...")
|
||||
self._emit_event("message_end", {
|
||||
"content": "",
|
||||
"tool_calls": [],
|
||||
"empty_retry": True
|
||||
})
|
||||
# Retry without retry flag to avoid infinite loop
|
||||
return self._call_llm_stream(
|
||||
retry_on_empty=False,
|
||||
retry_count=retry_count,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# Add assistant message to history (Claude format uses content blocks)
|
||||
assistant_msg = {"role": "assistant", "content": []}
|
||||
|
||||
|
||||
@@ -255,7 +255,7 @@ class GoogleGeminiBot(Bot):
|
||||
gemini_tools = self._convert_tools_to_gemini_rest_format(tools)
|
||||
if gemini_tools:
|
||||
payload["tools"] = gemini_tools
|
||||
logger.info(f"[Gemini] Added {len(tools)} tools to request")
|
||||
logger.debug(f"[Gemini] Added {len(tools)} tools to request")
|
||||
|
||||
# Make REST API call
|
||||
base_url = f"{self.api_base}/v1beta"
|
||||
@@ -445,6 +445,9 @@ class GoogleGeminiBot(Bot):
|
||||
all_tool_calls = []
|
||||
has_sent_tool_calls = False
|
||||
has_content = False # Track if any content was sent
|
||||
chunk_count = 0
|
||||
last_finish_reason = None
|
||||
last_safety_ratings = None
|
||||
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
@@ -461,6 +464,7 @@ class GoogleGeminiBot(Bot):
|
||||
|
||||
try:
|
||||
chunk_data = json.loads(line)
|
||||
chunk_count += 1
|
||||
logger.debug(f"[Gemini] Stream chunk: {json.dumps(chunk_data, ensure_ascii=False)[:200]}")
|
||||
|
||||
candidates = chunk_data.get("candidates", [])
|
||||
@@ -469,6 +473,13 @@ class GoogleGeminiBot(Bot):
|
||||
continue
|
||||
|
||||
candidate = candidates[0]
|
||||
|
||||
# 记录 finish_reason 和 safety_ratings
|
||||
if "finishReason" in candidate:
|
||||
last_finish_reason = candidate["finishReason"]
|
||||
if "safetyRatings" in candidate:
|
||||
last_safety_ratings = candidate["safetyRatings"]
|
||||
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
@@ -512,7 +523,7 @@ class GoogleGeminiBot(Bot):
|
||||
|
||||
# Send tool calls if any were collected
|
||||
if all_tool_calls and not has_sent_tool_calls:
|
||||
logger.info(f"[Gemini] Stream detected {len(all_tool_calls)} tool calls")
|
||||
logger.debug(f"[Gemini] Stream detected {len(all_tool_calls)} tool calls")
|
||||
yield {
|
||||
"id": f"chatcmpl-{time.time()}",
|
||||
"object": "chat.completion.chunk",
|
||||
@@ -526,17 +537,17 @@ class GoogleGeminiBot(Bot):
|
||||
}
|
||||
has_sent_tool_calls = True
|
||||
|
||||
# Log summary
|
||||
logger.info(f"[Gemini] Stream complete: has_content={has_content}, tool_calls={len(all_tool_calls)}")
|
||||
# Log summary (only if there's something interesting)
|
||||
if not has_content and not all_tool_calls:
|
||||
logger.debug(f"[Gemini] Stream complete: has_content={has_content}, tool_calls={len(all_tool_calls)}")
|
||||
elif all_tool_calls:
|
||||
logger.debug(f"[Gemini] Stream complete: {len(all_tool_calls)} tool calls")
|
||||
else:
|
||||
logger.debug(f"[Gemini] Stream complete: text response")
|
||||
|
||||
# 如果返回空响应,记录详细警告
|
||||
if not has_content and not all_tool_calls:
|
||||
logger.warning(f"[Gemini] ⚠️ Empty response detected!")
|
||||
logger.warning(f"[Gemini] Possible reasons:")
|
||||
logger.warning(f" 1. Model couldn't generate response based on context")
|
||||
logger.warning(f" 2. Content blocked by safety filters")
|
||||
logger.warning(f" 3. All previous tool calls failed")
|
||||
logger.warning(f" 4. API error not properly caught")
|
||||
|
||||
# Final chunk
|
||||
yield {
|
||||
|
||||
Reference in New Issue
Block a user