mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
记忆模块升级
1、把记忆节点明确区分成对话记忆、观察记忆和反思记忆; 2、修复</think>标签错误输出的问题。
This commit is contained in:
3730
core/fay_core.py
3730
core/fay_core.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -74,9 +74,9 @@ def __get_template():
|
||||
except Exception as e:
|
||||
return f"Error rendering template: {e}", 500
|
||||
|
||||
def __get_device_list():
|
||||
try:
|
||||
if config_util.start_mode == 'common':
|
||||
def __get_device_list():
|
||||
try:
|
||||
if config_util.start_mode == 'common':
|
||||
audio = pyaudio.PyAudio()
|
||||
device_list = []
|
||||
for i in range(audio.get_device_count()):
|
||||
@@ -86,33 +86,33 @@ def __get_device_list():
|
||||
return list(set(device_list))
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Error getting device list: {e}")
|
||||
return []
|
||||
|
||||
def _as_bool(value):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in ("1", "true", "yes", "y", "on")
|
||||
return False
|
||||
|
||||
def _build_llm_url(base_url: str) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
url = base_url.rstrip("/")
|
||||
if url.endswith("/chat/completions"):
|
||||
return url
|
||||
if url.endswith("/v1"):
|
||||
return url + "/chat/completions"
|
||||
return url + "/v1/chat/completions"
|
||||
|
||||
@__app.route('/api/submit', methods=['post'])
|
||||
def api_submit():
|
||||
except Exception as e:
|
||||
print(f"Error getting device list: {e}")
|
||||
return []
|
||||
|
||||
def _as_bool(value):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in ("1", "true", "yes", "y", "on")
|
||||
return False
|
||||
|
||||
def _build_llm_url(base_url: str) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
url = base_url.rstrip("/")
|
||||
if url.endswith("/chat/completions"):
|
||||
return url
|
||||
if url.endswith("/v1"):
|
||||
return url + "/chat/completions"
|
||||
return url + "/v1/chat/completions"
|
||||
|
||||
@__app.route('/api/submit', methods=['post'])
|
||||
def api_submit():
|
||||
data = request.values.get('data')
|
||||
if not data:
|
||||
return jsonify({'result': 'error', 'message': '未提供数据'})
|
||||
@@ -309,27 +309,27 @@ def api_send():
|
||||
|
||||
# 获取指定用户的消息记录(支持分页)
|
||||
@__app.route('/api/get-msg', methods=['post'])
|
||||
def api_get_Msg():
|
||||
try:
|
||||
data = request.form.get('data')
|
||||
if data is None:
|
||||
data = request.get_json(silent=True) or {}
|
||||
else:
|
||||
data = json.loads(data)
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
username = data.get("username")
|
||||
limit = data.get("limit", 30) # 默认每页30条
|
||||
offset = data.get("offset", 0) # 默认从0开始
|
||||
contentdb = content_db.new_instance()
|
||||
uid = 0
|
||||
if username:
|
||||
uid = member_db.new_instance().find_user(username)
|
||||
if uid == 0:
|
||||
return json.dumps({'list': [], 'total': 0, 'hasMore': False})
|
||||
# 获取总数用于判断是否还有更多
|
||||
total = contentdb.get_message_count(uid)
|
||||
list = contentdb.get_list('all', 'desc', limit, uid, offset)
|
||||
def api_get_Msg():
|
||||
try:
|
||||
data = request.form.get('data')
|
||||
if data is None:
|
||||
data = request.get_json(silent=True) or {}
|
||||
else:
|
||||
data = json.loads(data)
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
username = data.get("username")
|
||||
limit = data.get("limit", 30) # 默认每页30条
|
||||
offset = data.get("offset", 0) # 默认从0开始
|
||||
contentdb = content_db.new_instance()
|
||||
uid = 0
|
||||
if username:
|
||||
uid = member_db.new_instance().find_user(username)
|
||||
if uid == 0:
|
||||
return json.dumps({'list': [], 'total': 0, 'hasMore': False})
|
||||
# 获取总数用于判断是否还有更多
|
||||
total = contentdb.get_message_count(uid)
|
||||
list = contentdb.get_list('all', 'desc', limit, uid, offset)
|
||||
relist = []
|
||||
i = len(list) - 1
|
||||
while i >= 0:
|
||||
@@ -354,134 +354,143 @@ def api_send_v1_chat_completions():
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return jsonify({'error': '未提供数据'})
|
||||
try:
|
||||
model = data.get('model', 'fay')
|
||||
if model == 'llm':
|
||||
try:
|
||||
config_util.load_config()
|
||||
llm_url = _build_llm_url(config_util.gpt_base_url)
|
||||
api_key = config_util.key_gpt_api_key
|
||||
model_engine = config_util.gpt_model_engine
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM config load failed: {exc}'}), 500
|
||||
|
||||
if not llm_url:
|
||||
return jsonify({'error': 'LLM base_url is not configured'}), 500
|
||||
|
||||
payload = dict(data)
|
||||
if payload.get('model') == 'llm' and model_engine:
|
||||
payload['model'] = model_engine
|
||||
|
||||
stream_requested = _as_bool(payload.get('stream', False))
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
try:
|
||||
if stream_requested:
|
||||
resp = requests.post(llm_url, headers=headers, json=payload, stream=True)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
if not chunk:
|
||||
continue
|
||||
yield chunk
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
content_type = resp.headers.get("Content-Type", "text/event-stream")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
generate(),
|
||||
status=resp.status_code,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
resp = requests.post(llm_url, headers=headers, json=payload, timeout=60)
|
||||
content_type = resp.headers.get("Content-Type", "application/json")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
resp.content,
|
||||
status=resp.status_code,
|
||||
content_type=content_type,
|
||||
)
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM request failed: {exc}'}), 500
|
||||
|
||||
last_content = ""
|
||||
if 'messages' in data and data['messages']:
|
||||
last_message = data['messages'][-1]
|
||||
username = last_message.get('role', 'User')
|
||||
if username == 'user':
|
||||
username = 'User'
|
||||
last_content = last_message.get('content', 'No content provided')
|
||||
else:
|
||||
last_content = 'No messages found'
|
||||
username = 'User'
|
||||
try:
|
||||
model = data.get('model', 'fay')
|
||||
if model == 'llm':
|
||||
try:
|
||||
config_util.load_config()
|
||||
llm_url = _build_llm_url(config_util.gpt_base_url)
|
||||
api_key = config_util.key_gpt_api_key
|
||||
model_engine = config_util.gpt_model_engine
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM config load failed: {exc}'}), 500
|
||||
|
||||
observation = data.get('observation', '')
|
||||
if not llm_url:
|
||||
return jsonify({'error': 'LLM base_url is not configured'}), 500
|
||||
|
||||
payload = dict(data)
|
||||
if payload.get('model') == 'llm' and model_engine:
|
||||
payload['model'] = model_engine
|
||||
|
||||
stream_requested = _as_bool(payload.get('stream', False))
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
try:
|
||||
if stream_requested:
|
||||
resp = requests.post(llm_url, headers=headers, json=payload, stream=True)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
if not chunk:
|
||||
continue
|
||||
yield chunk
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
content_type = resp.headers.get("Content-Type", "text/event-stream")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
generate(),
|
||||
status=resp.status_code,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
resp = requests.post(llm_url, headers=headers, json=payload, timeout=60)
|
||||
content_type = resp.headers.get("Content-Type", "application/json")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
resp.content,
|
||||
status=resp.status_code,
|
||||
content_type=content_type,
|
||||
)
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM request failed: {exc}'}), 500
|
||||
|
||||
last_content = ""
|
||||
username = "User"
|
||||
messages = data.get("messages")
|
||||
if isinstance(messages, list) and messages:
|
||||
last_message = messages[-1] or {}
|
||||
username = last_message.get("role", "User") or "User"
|
||||
if username == "user":
|
||||
username = "User"
|
||||
last_content = last_message.get("content") or ""
|
||||
elif isinstance(messages, str):
|
||||
last_content = messages
|
||||
|
||||
observation = data.get('observation', '')
|
||||
# 检查请求中是否指定了流式传输
|
||||
stream_requested = data.get('stream', False)
|
||||
stream_requested = data.get('stream', False)
|
||||
no_reply = _as_bool(data.get('no_reply', data.get('noReply', False)))
|
||||
if no_reply:
|
||||
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream': bool(stream_requested), 'no_reply': True})
|
||||
util.printInfo(1, username, '[text chat no_reply]{}'.format(interact.data["msg"]), time.time())
|
||||
fay_booter.feiFei.on_interact(interact)
|
||||
if stream_requested or model == 'fay-streaming':
|
||||
def generate():
|
||||
message = {
|
||||
"id": "faystreaming-" + str(uuid.uuid4()),
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": ""
|
||||
},
|
||||
"index": 0,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": len(last_content),
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": len(last_content)
|
||||
},
|
||||
"system_fingerprint": "",
|
||||
"no_reply": True
|
||||
}
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
yield 'data: [DONE]\n\n'
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
return jsonify({
|
||||
"id": "fay-" + str(uuid.uuid4()),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
},
|
||||
"logprobs": "",
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": len(last_content),
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": len(last_content)
|
||||
},
|
||||
"system_fingerprint": "",
|
||||
"no_reply": True
|
||||
})
|
||||
if stream_requested or model == 'fay-streaming':
|
||||
obs_text = ""
|
||||
if observation is not None:
|
||||
obs_text = observation.strip() if isinstance(observation, str) else str(observation).strip()
|
||||
message_text = last_content.strip() if isinstance(last_content, str) else str(last_content).strip()
|
||||
if not message_text and not obs_text:
|
||||
return jsonify({'error': 'messages and observation are both empty'}), 400
|
||||
if not message_text and obs_text:
|
||||
no_reply = True
|
||||
if no_reply:
|
||||
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream': bool(stream_requested), 'no_reply': True})
|
||||
util.printInfo(1, username, '[text chat no_reply]{}'.format(interact.data["msg"]), time.time())
|
||||
fay_booter.feiFei.on_interact(interact)
|
||||
if stream_requested or model == 'fay-streaming':
|
||||
def generate():
|
||||
message = {
|
||||
"id": "faystreaming-" + str(uuid.uuid4()),
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": ""
|
||||
},
|
||||
"index": 0,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": len(last_content),
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": len(last_content)
|
||||
},
|
||||
"system_fingerprint": "",
|
||||
"no_reply": True
|
||||
}
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
yield 'data: [DONE]\n\n'
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
return jsonify({
|
||||
"id": "fay-" + str(uuid.uuid4()),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
},
|
||||
"logprobs": "",
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": len(last_content),
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": len(last_content)
|
||||
},
|
||||
"system_fingerprint": "",
|
||||
"no_reply": True
|
||||
})
|
||||
if stream_requested or model == 'fay-streaming':
|
||||
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream':True})
|
||||
util.printInfo(1, username, '[文字沟通接口(流式)]{}'.format(interact.data["msg"]), time.time())
|
||||
fay_booter.feiFei.on_interact(interact)
|
||||
@@ -588,7 +597,6 @@ def api_delete_user():
|
||||
|
||||
# 清除缓存的 agent 对象
|
||||
try:
|
||||
from llm import nlp_cognitive_stream
|
||||
if hasattr(nlp_cognitive_stream, 'agents') and username in nlp_cognitive_stream.agents:
|
||||
del nlp_cognitive_stream.agents[username]
|
||||
except Exception:
|
||||
|
||||
@@ -65,6 +65,11 @@ memory_cleared = False # 添加记忆清除标记
|
||||
# 新增: 当前会话用户名及按用户获取memory目录的辅助函数
|
||||
current_username = None # 当前会话用户名
|
||||
|
||||
def _log_prompt(messages: List[SystemMessage | HumanMessage], tag: str = ""):
    """Prompt-logging hook that is currently disabled.

    Accepts the prompt messages and an optional tag so call sites stay in
    place, but performs no work and returns ``None``.
    """
    return None
|
||||
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model=cfg.gpt_model_engine,
|
||||
base_url=cfg.gpt_base_url,
|
||||
@@ -314,14 +319,26 @@ def _remove_prestart_from_text(text: str, keep_marked: bool = True) -> str:
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _remove_think_from_text(text: str) -> str:
|
||||
"""从文本中移除 think 标签及其内容"""
|
||||
if not text:
|
||||
return text
|
||||
import re
|
||||
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
|
||||
cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
|
||||
return cleaned.strip()
|
||||
def _remove_think_from_text(text: str) -> str:
|
||||
"""从文本中移除 think 标签及其内容"""
|
||||
if not text:
|
||||
return text
|
||||
import re
|
||||
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
|
||||
cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _strip_json_code_fence(text: str) -> str:
|
||||
"""Strip ```json ... ``` wrappers if present."""
|
||||
if not text:
|
||||
return text
|
||||
import re
|
||||
trimmed = text.strip()
|
||||
match = re.match(r"^```(?:json)?\\s*(.*?)\\s*```$", trimmed, flags=re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return text
|
||||
|
||||
|
||||
def _format_conversation_block(conversation: List[Dict], username: str = "User") -> str:
|
||||
@@ -440,35 +457,35 @@ def _run_prestart_tools(user_question: str) -> List[Dict[str, Any]]:
|
||||
return results
|
||||
|
||||
|
||||
def _truncate_history(
    history: List[ToolResult],
    limit: Optional[int] = None,
    output_limit: Optional[int] = None,
) -> str:
    """Render tool-call history as a bulleted, human-readable summary.

    Args:
        history: Tool result records. Each record is read with ``.get`` for
            the keys "call", "attempt", "success", "output" and "error".
        limit: When given, only the most recent ``limit`` records are
            rendered. Zero or a negative value selects no records.
        output_limit: When given, output and error texts are passed through
            ``_truncate_text`` with this limit before being appended.

    Returns:
        The summary lines joined with newlines, or the "(暂无)" placeholder
        when there is nothing to render.
    """
    if not history:
        return "(暂无)"

    if limit is None:
        selected = history
    elif limit > 0:
        selected = history[-limit:]
    else:
        # history[-0:] is the WHOLE list, so an unguarded slice would turn
        # "show nothing" into "show everything". Negative limits would drop
        # the oldest entries instead of keeping the newest.
        selected = []
    if not selected:
        return "(暂无)"

    lines: List[str] = []
    for item in selected:
        call = item.get("call", {})
        name = call.get("name", "未知工具")
        attempt = item.get("attempt", 0)
        status = "成功" if item.get("success", False) else "失败"
        lines.append(f"- {name} 第 {attempt} 次 → {status}")

        output = item.get("output")
        if output is not None:
            output_text = str(output)
            if output_limit is not None:
                output_text = _truncate_text(output_text, output_limit)
            lines.append(" 输出:" + output_text)

        error = item.get("error")
        if error is not None:
            error_text = str(error)
            if output_limit is not None:
                error_text = _truncate_text(error_text, output_limit)
            lines.append(" 错误:" + error_text)
    return "\n".join(lines)
|
||||
def _truncate_history(
    history: List[ToolResult],
    limit: Optional[int] = None,
    output_limit: Optional[int] = None,
) -> str:
    """Render tool-call history as a bulleted, human-readable summary.

    Args:
        history: Tool result records. Each record is read with ``.get`` for
            the keys "call", "attempt", "success", "output" and "error".
        limit: When given, only the most recent ``limit`` records are
            rendered. Zero or a negative value selects no records.
        output_limit: When given, output and error texts are passed through
            ``_truncate_text`` with this limit before being appended.

    Returns:
        The summary lines joined with newlines, or the "(暂无)" placeholder
        when there is nothing to render.
    """
    if not history:
        return "(暂无)"

    if limit is None:
        selected = history
    elif limit > 0:
        selected = history[-limit:]
    else:
        # history[-0:] is the WHOLE list, so an unguarded slice would turn
        # "show nothing" into "show everything". Negative limits would drop
        # the oldest entries instead of keeping the newest.
        selected = []
    if not selected:
        return "(暂无)"

    lines: List[str] = []
    for item in selected:
        call = item.get("call", {})
        name = call.get("name", "未知工具")
        attempt = item.get("attempt", 0)
        status = "成功" if item.get("success", False) else "失败"
        lines.append(f"- {name} 第 {attempt} 次 → {status}")

        output = item.get("output")
        if output is not None:
            output_text = str(output)
            if output_limit is not None:
                output_text = _truncate_text(output_text, output_limit)
            lines.append(" 输出:" + output_text)

        error = item.get("error")
        if error is not None:
            error_text = str(error)
            if output_limit is not None:
                error_text = _truncate_text(error_text, output_limit)
            lines.append(" 错误:" + error_text)
    return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_schema_parameters(schema: Dict[str, Any]) -> List[str]:
|
||||
@@ -598,9 +615,9 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess
|
||||
|
||||
# 生成对话文本,使用代码块包裹每条消息
|
||||
convo_text = _format_conversation_block(conversation, username)
|
||||
history_text = _truncate_history(history)
|
||||
tools_text = _format_tools_for_prompt(tool_specs)
|
||||
preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else ""
|
||||
history_text = _truncate_history(history)
|
||||
tools_text = _format_tools_for_prompt(tool_specs)
|
||||
preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else ""
|
||||
|
||||
# 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内
|
||||
if prestart_context and prestart_context.strip():
|
||||
@@ -638,7 +655,6 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess
|
||||
{observation or '(无补充)'}
|
||||
---
|
||||
|
||||
**关联记忆**
|
||||
{memory_context or '(无相关记忆)'}
|
||||
---
|
||||
{prestart_section}
|
||||
@@ -682,9 +698,9 @@ def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessag
|
||||
display_username = "主人" if username == "User" else username
|
||||
|
||||
# 生成对话文本,使用代码块包裹每条消息
|
||||
conversation_block = _format_conversation_block(conversation, username)
|
||||
history_text = _truncate_history(state.get("tool_results", []))
|
||||
preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else ""
|
||||
conversation_block = _format_conversation_block(conversation, username)
|
||||
history_text = _truncate_history(state.get("tool_results", []))
|
||||
preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else ""
|
||||
|
||||
# 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内
|
||||
if prestart_context and prestart_context.strip():
|
||||
@@ -755,6 +771,7 @@ def _call_planner_llm(
|
||||
解析后的决策字典
|
||||
"""
|
||||
messages = _build_planner_messages(state)
|
||||
_log_prompt(messages, tag="planner")
|
||||
|
||||
# 如果有流式回调,使用流式模式检测 finish+message
|
||||
if stream_callback is not None:
|
||||
@@ -826,6 +843,7 @@ def _call_planner_llm(
|
||||
|
||||
# 处理完整响应
|
||||
trimmed = _remove_think_from_text(accumulated.strip())
|
||||
trimmed = _strip_json_code_fence(trimmed)
|
||||
|
||||
if in_message_mode:
|
||||
# 提取完整 message 内容,去掉结尾的 "}
|
||||
@@ -869,6 +887,7 @@ def _call_planner_llm(
|
||||
raise RuntimeError("规划器返回内容异常,未获得字符串。")
|
||||
# 先移除 think 标签,兼容带思考标签的模型(如 DeepSeek R1)
|
||||
trimmed = _remove_think_from_text(content.strip())
|
||||
trimmed = _strip_json_code_fence(trimmed)
|
||||
try:
|
||||
decision = json.loads(trimmed)
|
||||
except json.JSONDecodeError as exc:
|
||||
@@ -929,6 +948,7 @@ def _plan_next_action(state: AgentState) -> AgentState:
|
||||
"规划器:检测到工具重复调用,使用最新结果产出最终回复。"
|
||||
)
|
||||
final_messages = _build_final_messages(state)
|
||||
_log_prompt(final_messages, tag="final")
|
||||
preview = last_entry.get("output")
|
||||
return {
|
||||
"status": "completed",
|
||||
@@ -1669,27 +1689,57 @@ def load_agent_memory(agent, username=None):
|
||||
|
||||
# 记忆对话内容的线程函数
|
||||
def remember_conversation_thread(username, content, response_text):
|
||||
"""
|
||||
在单独线程中记录对话内容到代理记忆
|
||||
|
||||
参数:
|
||||
username: 用户名
|
||||
content: 用户问题内容
|
||||
response_text: 代理回答内容
|
||||
"""
|
||||
global agents
|
||||
"""Background task to store a conversation memory node."""
|
||||
try:
|
||||
ag = create_agent(username)
|
||||
if ag is None:
|
||||
return
|
||||
questioner = username
|
||||
if isinstance(questioner, str) and questioner.lower() == "user":
|
||||
questioner = "主人"
|
||||
answerer = ag.scratch.get("first_name", "Fay")
|
||||
question_text = content if content is not None else ""
|
||||
answer_text = response_text if response_text is not None else ""
|
||||
memory_content = f"{questioner}:{question_text}\n{answerer}:{answer_text}"
|
||||
with agent_lock:
|
||||
ag = agents.get(username)
|
||||
if ag is None:
|
||||
return
|
||||
time_step = get_current_time_step(username)
|
||||
name = "主人" if username == "User" else username
|
||||
# 记录对话内容
|
||||
memory_content = f"在对话中,我回答了{name}的问题:{content}\n,我的回答是:{response_text}"
|
||||
ag.remember(memory_content, time_step)
|
||||
if ag.memory_stream and hasattr(ag.memory_stream, "remember_conversation"):
|
||||
ag.memory_stream.remember_conversation(memory_content, time_step)
|
||||
else:
|
||||
ag.remember(memory_content, time_step)
|
||||
except Exception as e:
|
||||
util.log(1, f"记忆对话内容出错: {str(e)}")
|
||||
util.log(1, f"记录对话记忆失败: {str(e)}")
|
||||
|
||||
def remember_observation_thread(username, observation_text):
    """Store one observation in the memory of the agent for *username*.

    The agent is obtained through ``create_agent``; nothing is stored when
    that returns ``None``. A ``None`` observation is stored as an empty
    string. While holding ``agent_lock`` the current time step is fetched
    and the text is written via the memory stream's ``remember`` when the
    stream is truthy and exposes it, otherwise via the agent's own
    ``remember``. Any exception is logged and swallowed.
    """
    try:
        agent = create_agent(username)
        if agent is None:
            return
        memory_content = "" if observation_text is None else observation_text
        with agent_lock:
            time_step = get_current_time_step(username)
            stream = agent.memory_stream
            if stream and hasattr(stream, "remember"):
                stream.remember(memory_content, time_step)
            else:
                agent.remember(memory_content, time_step)
    except Exception as e:
        util.log(1, f"记录观察记忆失败: {str(e)}")
|
||||
|
||||
def record_observation(username, observation_text):
    """Validate an observation and hand it to a background worker.

    Non-string input is converted with ``str``; the text is then stripped.
    ``None`` or blank text is rejected. Otherwise a ``MyThread`` targeting
    ``remember_observation_thread`` is started with the cleaned text.

    Returns:
        A ``(ok, message)`` tuple: ``(True, "observation recorded")`` once
        the worker has been started, or ``(False, reason)`` when the text
        is missing/blank or starting the worker raised.
    """
    rejected = (False, "observation text is required")
    if observation_text is None:
        return rejected

    raw = observation_text if isinstance(observation_text, str) else str(observation_text)
    cleaned = raw.strip()
    if not cleaned:
        return rejected

    try:
        worker = MyThread(target=remember_observation_thread, args=(username, cleaned))
        worker.start()
    except Exception as exc:
        util.log(1, f"记录观察记忆失败: {exc}")
        return False, f"observation record failed: {exc}"
    return True, "observation recorded"
|
||||
|
||||
def question(content, username, observation=None):
|
||||
"""处理用户提问并返回回复。"""
|
||||
@@ -1725,24 +1775,49 @@ def question(content, username, observation=None):
|
||||
),
|
||||
}
|
||||
|
||||
memory_sections = [
|
||||
("观察记忆", "observation"),
|
||||
("对话记忆", "conversation"),
|
||||
("反思记忆", "reflection"),
|
||||
]
|
||||
memory_context = ""
|
||||
if agent.memory_stream and len(agent.memory_stream.seq_nodes) > 0:
|
||||
if agent.memory_stream and len(agent.memory_stream.seq_nodes) > 0 and content:
|
||||
current_time_step = get_current_time_step(username)
|
||||
query = content.strip() if isinstance(content, str) else str(content)
|
||||
max_per_type = 10
|
||||
section_texts = []
|
||||
try:
|
||||
query = f"{'主人' if username == 'User' else username}提出了问题:{content}"
|
||||
related_memories = agent.memory_stream.retrieve(
|
||||
combined = agent.memory_stream.retrieve(
|
||||
[query],
|
||||
current_time_step,
|
||||
n_count=30,
|
||||
n_count=max_per_type * len(memory_sections),
|
||||
curr_filter="all",
|
||||
hp=[0.8, 0.5, 0.5],
|
||||
stateless=False,
|
||||
)
|
||||
if related_memories and query in related_memories:
|
||||
memory_nodes = related_memories[query]
|
||||
memory_context = "\n".join(f"- {node.content}" for node in memory_nodes)
|
||||
all_nodes = combined.get(query, []) if combined else []
|
||||
except Exception as exc:
|
||||
util.log(1, f"获取相关记忆时出错: {exc}")
|
||||
util.log(1, f"获取关联记忆时出错: {exc}")
|
||||
all_nodes = []
|
||||
|
||||
for label, node_type in memory_sections:
|
||||
section_lines = "(无)"
|
||||
try:
|
||||
memory_nodes = [n for n in all_nodes if getattr(n, "node_type", "") == node_type][:max_per_type]
|
||||
if memory_nodes:
|
||||
formatted = []
|
||||
for node in memory_nodes:
|
||||
ts = (getattr(node, "datetime", "") or "").strip()
|
||||
prefix = f"[{ts}] " if ts else ""
|
||||
formatted.append(f"- {prefix}{node.content}")
|
||||
section_lines = "\n".join(formatted)
|
||||
except Exception as exc:
|
||||
util.log(1, f"获取{label}时出错: {exc}")
|
||||
section_lines = "(获取失败)"
|
||||
section_texts.append(f"**{label}**\n{section_lines}")
|
||||
memory_context = "\n".join(section_texts)
|
||||
else:
|
||||
memory_context = "\n".join(f"**{label}**\n(无)" for label, _ in memory_sections)
|
||||
|
||||
prestart_context = ""
|
||||
prestart_stream_text = ""
|
||||
@@ -1866,22 +1941,22 @@ def question(content, username, observation=None):
|
||||
or messages_buffer[-1]['content'] != content
|
||||
):
|
||||
messages_buffer.append({"role": "user", "content": content})
|
||||
else:
|
||||
# 不隔离:按独立消息存储,保留用户名信息
|
||||
def append_to_buffer_multi(role: str, text_value: str, msg_username: str = "") -> None:
|
||||
if not text_value:
|
||||
return
|
||||
messages_buffer.append({"role": role, "content": text_value, "username": msg_username})
|
||||
if len(messages_buffer) > 60:
|
||||
del messages_buffer[:-60]
|
||||
|
||||
def append_to_buffer(role: str, text_value: str) -> None:
|
||||
append_to_buffer_multi(role, text_value, "")
|
||||
|
||||
for record in history_records:
|
||||
msg_type, msg_text, msg_username = record
|
||||
if not msg_text:
|
||||
continue
|
||||
else:
|
||||
# 不隔离:按独立消息存储,保留用户名信息
|
||||
def append_to_buffer_multi(role: str, text_value: str, msg_username: str = "") -> None:
|
||||
if not text_value:
|
||||
return
|
||||
messages_buffer.append({"role": role, "content": text_value, "username": msg_username})
|
||||
if len(messages_buffer) > 60:
|
||||
del messages_buffer[:-60]
|
||||
|
||||
def append_to_buffer(role: str, text_value: str) -> None:
|
||||
append_to_buffer_multi(role, text_value, "")
|
||||
|
||||
for record in history_records:
|
||||
msg_type, msg_text, msg_username = record
|
||||
if not msg_text:
|
||||
continue
|
||||
if msg_type and msg_type.lower() in ('member', 'user'):
|
||||
append_to_buffer_multi("user", msg_text, msg_username)
|
||||
else:
|
||||
@@ -2020,22 +2095,22 @@ def question(content, username, observation=None):
|
||||
# 创建规划器流式回调,用于实时输出 finish+message 响应
|
||||
planner_stream_buffer = {"text": "", "first_chunk": True}
|
||||
|
||||
def planner_stream_callback(chunk_text: str) -> None:
|
||||
"""规划器流式回调,将 message 内容实时输出"""
|
||||
nonlocal accumulated_text, full_response_text, is_first_sentence, is_agent_think_start
|
||||
if not chunk_text:
|
||||
return
|
||||
planner_stream_buffer["text"] += chunk_text
|
||||
if planner_stream_buffer["first_chunk"]:
|
||||
planner_stream_buffer["first_chunk"] = False
|
||||
if is_agent_think_start:
|
||||
closing = "</think>"
|
||||
accumulated_text += closing
|
||||
full_response_text += closing
|
||||
is_agent_think_start = False
|
||||
# 使用 stream_response_chunks 的逻辑进行分句流式输出
|
||||
accumulated_text += chunk_text
|
||||
full_response_text += chunk_text
|
||||
def planner_stream_callback(chunk_text: str) -> None:
|
||||
"""规划器流式回调,将 message 内容实时输出"""
|
||||
nonlocal accumulated_text, full_response_text, is_first_sentence, is_agent_think_start
|
||||
if not chunk_text:
|
||||
return
|
||||
planner_stream_buffer["text"] += chunk_text
|
||||
if planner_stream_buffer["first_chunk"]:
|
||||
planner_stream_buffer["first_chunk"] = False
|
||||
if is_agent_think_start:
|
||||
closing = "</think>"
|
||||
accumulated_text += closing
|
||||
full_response_text += closing
|
||||
is_agent_think_start = False
|
||||
# 使用 stream_response_chunks 的逻辑进行分句流式输出
|
||||
accumulated_text += chunk_text
|
||||
full_response_text += chunk_text
|
||||
# 检查是否有完整句子可以输出
|
||||
if len(accumulated_text) >= 20:
|
||||
while True:
|
||||
@@ -2507,46 +2582,46 @@ def save_agent_memory():
|
||||
agent.scratch = {}
|
||||
|
||||
# 保存记忆前进行完整性检查
|
||||
try:
|
||||
# 检查seq_nodes中的每个节点是否有效
|
||||
valid_nodes = []
|
||||
for node in agent.memory_stream.seq_nodes:
|
||||
if node is None:
|
||||
util.log(1, "发现无效节点(None),跳过")
|
||||
continue
|
||||
|
||||
if not hasattr(node, 'node_id') or not hasattr(node, 'content'):
|
||||
util.log(1, f"发现无效节点(缺少必要属性),跳过")
|
||||
continue
|
||||
raw_content = node.content if isinstance(node.content, str) else str(node.content)
|
||||
cleaned_content = _remove_think_from_text(raw_content)
|
||||
if cleaned_content != raw_content:
|
||||
old_content = raw_content
|
||||
node.content = cleaned_content
|
||||
if (
|
||||
agent.memory_stream.embeddings is not None
|
||||
and old_content in agent.memory_stream.embeddings
|
||||
and cleaned_content not in agent.memory_stream.embeddings
|
||||
):
|
||||
agent.memory_stream.embeddings[cleaned_content] = agent.memory_stream.embeddings[old_content]
|
||||
else:
|
||||
node.content = raw_content
|
||||
valid_nodes.append(node)
|
||||
|
||||
# 更新seq_nodes为有效节点列表
|
||||
agent.memory_stream.seq_nodes = valid_nodes
|
||||
|
||||
# 重建id_to_node字典
|
||||
agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')}
|
||||
if agent.memory_stream.embeddings is not None:
|
||||
kept_contents = {node.content for node in valid_nodes if hasattr(node, 'content')}
|
||||
agent.memory_stream.embeddings = {
|
||||
key: value
|
||||
for key, value in agent.memory_stream.embeddings.items()
|
||||
if key in kept_contents
|
||||
}
|
||||
except Exception as e:
|
||||
util.log(1, f"检查记忆完整性时出错: {str(e)}")
|
||||
try:
|
||||
# 检查seq_nodes中的每个节点是否有效
|
||||
valid_nodes = []
|
||||
for node in agent.memory_stream.seq_nodes:
|
||||
if node is None:
|
||||
util.log(1, "发现无效节点(None),跳过")
|
||||
continue
|
||||
|
||||
if not hasattr(node, 'node_id') or not hasattr(node, 'content'):
|
||||
util.log(1, f"发现无效节点(缺少必要属性),跳过")
|
||||
continue
|
||||
raw_content = node.content if isinstance(node.content, str) else str(node.content)
|
||||
cleaned_content = _remove_think_from_text(raw_content)
|
||||
if cleaned_content != raw_content:
|
||||
old_content = raw_content
|
||||
node.content = cleaned_content
|
||||
if (
|
||||
agent.memory_stream.embeddings is not None
|
||||
and old_content in agent.memory_stream.embeddings
|
||||
and cleaned_content not in agent.memory_stream.embeddings
|
||||
):
|
||||
agent.memory_stream.embeddings[cleaned_content] = agent.memory_stream.embeddings[old_content]
|
||||
else:
|
||||
node.content = raw_content
|
||||
valid_nodes.append(node)
|
||||
|
||||
# 更新seq_nodes为有效节点列表
|
||||
agent.memory_stream.seq_nodes = valid_nodes
|
||||
|
||||
# 重建id_to_node字典
|
||||
agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')}
|
||||
if agent.memory_stream.embeddings is not None:
|
||||
kept_contents = {node.content for node in valid_nodes if hasattr(node, 'content')}
|
||||
agent.memory_stream.embeddings = {
|
||||
key: value
|
||||
for key, value in agent.memory_stream.embeddings.items()
|
||||
if key in kept_contents
|
||||
}
|
||||
except Exception as e:
|
||||
util.log(1, f"检查记忆完整性时出错: {str(e)}")
|
||||
|
||||
# 保存记忆
|
||||
try:
|
||||
|
||||
@@ -91,7 +91,7 @@ if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("示例1:张三的对话(带观察数据)")
|
||||
print("=" * 60)
|
||||
test_gpt("你好,今天天气不错啊", username="user", observation=OBSERVATION_SAMPLES["张三"])
|
||||
test_gpt(prompt="", username="user", observation=OBSERVATION_SAMPLES["张三"], no_reply=False)
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user