diff --git a/core/fay_core.py b/core/fay_core.py
index 0a0100b..1165f80 100644
--- a/core/fay_core.py
+++ b/core/fay_core.py
@@ -1,1026 +1,2706 @@
-# -*- coding: utf-8 -*-
-#作用是处理交互逻辑,文字输入,语音、文字及情绪的发送、播放及展示输出
-import math
-from operator import index
-import os
-import time
-import socket
-import requests
-from pydub import AudioSegment
-from queue import Queue
-import re # 添加正则表达式模块用于过滤表情符号
-import uuid
-
-# 适应模型使用
-import numpy as np
-from ai_module import baidu_emotion
-from core import wsa_server
-from core.interact import Interact
-from tts.tts_voice import EnumVoice
-from scheduler.thread_manager import MyThread
-from tts import tts_voice
-from utils import util, config_util
-from core import qa_service
-from utils import config_util as cfg
-from core import content_db
-from ai_module import nlp_cemotion
-from core import stream_manager
-
-from core import member_db
-import threading
-
-#加载配置
-cfg.load_config()
-if cfg.tts_module =='ali':
- from tts.ali_tss import Speech
-elif cfg.tts_module == 'gptsovits':
- from tts.gptsovits import Speech
-elif cfg.tts_module == 'gptsovits_v3':
- from tts.gptsovits_v3 import Speech
-elif cfg.tts_module == 'volcano':
- from tts.volcano_tts import Speech
-else:
- from tts.ms_tts_sdk import Speech
-
-#windows运行推送唇形数据
-import platform
-if platform.system() == "Windows":
- import sys
- sys.path.append("test/ovr_lipsync")
- from test_olipsync import LipSyncGenerator
-
-
-#可以使用自动播报的标记
-can_auto_play = True
-auto_play_lock = threading.RLock()
-
-class FeiFei:
- def __init__(self):
- self.lock = threading.Lock()
- self.nlp_streams = {} # 存储用户ID到句子缓存的映射
- self.nlp_stream_lock = threading.Lock() # 保护nlp_streams字典的锁
- self.mood = 0.0 # 情绪值
- self.old_mood = 0.0
- self.item_index = 0
- self.X = np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, -1) # 适应模型变量矩阵
- # self.W = np.array([0.01577594,1.16119452,0.75828,0.207746,1.25017864,0.1044121,0.4294899,0.2770932]).reshape(-1,1) #适应模型变量矩阵
- self.W = np.array([0.0, 0.6, 0.1, 0.7, 0.3, 0.0, 0.0, 0.0]).reshape(-1, 1) # 适应模型变量矩阵
-
- self.wsParam = None
- self.wss = None
- self.sp = Speech()
- self.speaking = False #声音是否在播放
- self.__running = True
- self.sp.connect() #TODO 预连接
-
- self.timer = None
- self.sound_query = Queue()
- self.think_mode_users = {} # 使用字典存储每个用户的think模式状态
- self.think_time_users = {} #使用字典存储每个用户的think开始时间
- self.think_display_state = {}
- self.think_display_limit = 400
- self.user_conv_map = {} #存储用户对话id及句子流序号,key为(username, conversation_id)
- self.pending_isfirst = {} # 存储因prestart被过滤而延迟的isfirst标记,key为username
-
- def __remove_emojis(self, text):
- """
- 改进的表情包过滤,避免误删除正常Unicode字符
- """
- # 更精确的emoji范围,避免误删除正常字符
- emoji_pattern = re.compile(
- "["
- "\U0001F600-\U0001F64F" # 表情符号 (Emoticons)
- "\U0001F300-\U0001F5FF" # 杂项符号和象形文字 (Miscellaneous Symbols and Pictographs)
- "\U0001F680-\U0001F6FF" # 交通和地图符号 (Transport and Map Symbols)
- "\U0001F1E0-\U0001F1FF" # 区域指示符号 (Regional Indicator Symbols)
- "\U0001F900-\U0001F9FF" # 补充符号和象形文字 (Supplemental Symbols and Pictographs)
- "\U0001FA70-\U0001FAFF" # 扩展A符号和象形文字 (Symbols and Pictographs Extended-A)
- "\U00002600-\U000026FF" # 杂项符号 (Miscellaneous Symbols)
- "\U00002700-\U000027BF" # 装饰符号 (Dingbats)
- "\U0000FE00-\U0000FE0F" # 变体选择器 (Variation Selectors)
- "\U0001F000-\U0001F02F" # 麻将牌 (Mahjong Tiles)
- "\U0001F0A0-\U0001F0FF" # 扑克牌 (Playing Cards)
- "]+",
- flags=re.UNICODE,
- )
-
- # 保护常用的中文标点符号和特殊字符
- protected_chars = ["。", ",", "!", "?", ":", ";", "、", """, """, "'", "'", "(", ")", "【", "】", "《", "》"]
-
- # 先保存保护字符的位置
- protected_positions = {}
- for i, char in enumerate(text):
- if char in protected_chars:
- protected_positions[i] = char
-
- # 执行emoji过滤
- filtered_text = emoji_pattern.sub('', text)
-
- # 如果过滤后文本长度变化太大,可能误删了正常字符,返回原文本
- if len(filtered_text) < len(text) * 0.5: # 如果删除了超过50%的内容
- return text
-
- return filtered_text
-
- def __process_stream_output(self, text, username, session_type="type2_stream", is_qa=False):
- """
- 按流式方式分割和发送 type=2 的文本
- 使用安全的流式文本处理器和状态管理器
- """
- if not text or text.strip() == "":
- return
-
- # 使用安全的流式文本处理器
- from utils.stream_text_processor import get_processor
- from utils.stream_state_manager import get_state_manager
-
- processor = get_processor()
- state_manager = get_state_manager()
-
- # 处理流式文本,is_qa=False表示普通模式
- success = processor.process_stream_text(text, username, is_qa=is_qa, session_type=session_type)
-
- if success:
- # 普通模式结束会话
- state_manager.end_session(username, conversation_id=stream_manager.new_instance().get_conversation_id(username))
- else:
- util.log(1, f"type=2流式处理失败,文本长度: {len(text)}")
- # 失败时也要确保结束会话
- state_manager.force_reset_user_state(username)
-
- #语音消息处理检查是否命中q&a
- def __get_answer(self, interleaver, text):
- answer = None
- # 全局问答
- answer, type = qa_service.QAService().question('qa',text)
- if answer is not None:
- return answer, type
- else:
- return None, None
-
-
- #消息处理
- def __process_interact(self, interact: Interact):
- if self.__running:
- try:
- index = interact.interact_type
- username = interact.data.get("user", "User")
- uid = member_db.new_instance().find_user(username)
no_reply = interact.data.get("no_reply", False)
if isinstance(no_reply, str):
no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on")
else:
no_reply = bool(no_reply)
-
- if index == 1: #语音、文字交互
-
- #记录用户问题,方便obs等调用
- self.write_to_file("./logs", "asr_result.txt", interact.data["msg"])
-
- #同步用户问题到数字人
- if wsa_server.get_instance().is_connected(username):
- content = {'Topic': 'human', 'Data': {'Key': 'question', 'Value': interact.data["msg"]}, 'Username' : interact.data.get("user")}
- wsa_server.get_instance().add_cmd(content)
-
- #记录用户问题
- content_id = content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
- if wsa_server.get_web_instance().is_connected(username):
- wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid, "id":content_id}, "Username" : username})
-
- if no_reply:
return ""
#确定是否命中q&a
- answer, type = self.__get_answer(interact.interleaver, interact.data["msg"])
-
- #大语言模型回复
- text = ''
- if answer is None or type != "qa":
- if wsa_server.get_web_instance().is_connected(username):
- wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'})
- if wsa_server.get_instance().is_connected(username):
- content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}
- wsa_server.get_instance().add_cmd(content)
-
- # 根据配置动态调用不同的NLP模块
- if cfg.config["memory"].get("use_bionic_memory", False):
- from llm import nlp_bionicmemory_stream
- text = nlp_bionicmemory_stream.question(interact.data["msg"], username, interact.data.get("observation", None))
- else:
- from llm import nlp_cognitive_stream
- text = nlp_cognitive_stream.question(interact.data["msg"], username, interact.data.get("observation", None))
-
- else:
- text = answer
- # 使用流式分割处理Q&A答案
- self.__process_stream_output(text, username, session_type="qa", is_qa=True)
-
-
- return text
-
- elif (index == 2):#透传模式:有音频则仅播音频;仅文本则流式+TTS
- audio_url = interact.data.get("audio")
- text = interact.data.get("text")
-
- # 1) 存在音频:忽略文本,仅播放音频
- if audio_url and str(audio_url).strip():
- try:
- audio_interact = Interact(
- "stream", 1,
- {"user": username, "msg": "", "isfirst": True, "isend": True, "audio": audio_url}
- )
- self.say(audio_interact, "")
- except Exception:
- pass
- return 'success'
-
- # 2) 只有文本:执行流式切分并TTS
- if text and str(text).strip():
- # 进行流式处理(用于TTS,流式处理中会记录到数据库)
- self.__process_stream_output(text, username, f"type2_{interact.interleaver}", is_qa=False)
-
- # 不再需要额外记录,因为流式处理已经记录了
- # self.__process_text_output(text, username, uid)
-
- return 'success'
-
- # 没有有效内容
- return 'success'
-
- except BaseException as e:
- print(e)
- return e
- else:
- return "还没有开始运行"
-
- #记录问答到log
- def write_to_file(self, path, filename, content):
- if not os.path.exists(path):
- os.makedirs(path)
- full_path = os.path.join(path, filename)
- with open(full_path, 'w', encoding='utf-8') as file:
- file.write(content)
- file.flush()
- os.fsync(file.fileno())
-
- #触发交互
- def on_interact(self, interact: Interact):
- #创建用户
- username = interact.data.get("user", "User")
- if member_db.new_instance().is_username_exist(username) == "notexists":
- member_db.new_instance().add_user(username)
- no_reply = interact.data.get("no_reply", False)
- if isinstance(no_reply, str):
- no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on")
- else:
- no_reply = bool(no_reply)
-
- if not no_reply:
- try:
- from utils.stream_state_manager import get_state_manager
- import uuid
- if get_state_manager().is_session_active(username):
- stream_manager.new_instance().clear_Stream_with_audio(username)
- conv_id = "conv_" + str(uuid.uuid4())
- stream_manager.new_instance().set_current_conversation(username, conv_id)
- # 将当前会话ID附加到交互数据
- interact.data["conversation_id"] = conv_id
- # 允许新的生成
- stream_manager.new_instance().set_stop_generation(username, stop=False)
- except Exception:
- util.log(3, "开启新会话失败")
-
- if interact.interact_type == 1:
- MyThread(target=self.__process_interact, args=[interact]).start()
- else:
- return self.__process_interact(interact)
-
- #获取不同情绪声音
- def __get_mood_voice(self):
- voice = tts_voice.get_voice_of(config_util.config["attribute"]["voice"])
- if voice is None:
- voice = EnumVoice.XIAO_XIAO
- styleList = voice.value["styleList"]
- sayType = styleList["calm"]
- return sayType
-
- # 合成声音
- def say(self, interact, text, type = ""):
- try:
- uid = member_db.new_instance().find_user(interact.data.get("user"))
- is_end = interact.data.get("isend", False)
- is_first = interact.data.get("isfirst", False)
- username = interact.data.get("user", "User")
-
- # 提前进行会话有效性与中断检查,避免产生多余面板/数字人输出
- try:
- user_for_stop = interact.data.get("user", "User")
- conv_id_for_stop = interact.data.get("conversation_id")
- if not is_end and stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
- return None
- except Exception:
- pass
-
- #无效流式文本提前结束
- if not is_first and not is_end and (text is None or text.strip() == ""):
- return None
-
- # 检查是否是 prestart 内容(不应该影响 thinking 状态)
- is_prestart_content = self.__has_prestart(text)
-
- # 流式文本拼接存库
- content_id = 0
- # 使用 (username, conversation_id) 作为 key,避免并发会话覆盖
- conv = interact.data.get("conversation_id") or ""
- conv_map_key = (username, conv)
-
- if is_first == True:
- # reset any leftover think-mode at the start of a new reply
- # 但如果是 prestart 内容,不重置 thinking 状态
- try:
- if uid is not None and not is_prestart_content:
- self.think_mode_users[uid] = False
- if uid in self.think_time_users:
- del self.think_time_users[uid]
if uid in self.think_display_state:
- del self.think_display_state[uid]
-
- except Exception:
- pass
- # 如果没有 conversation_id,生成一个新的
- if not conv:
- conv = "conv_" + str(uuid.uuid4())
- conv_map_key = (username, conv)
- conv_no = 0
- # 创建第一条数据库记录,获得content_id
- if text and text.strip():
- content_id = content_db.new_instance().add_content('fay', 'speak', text, username, uid)
- else:
- content_id = content_db.new_instance().add_content('fay', 'speak', '', username, uid)
-
- # 保存content_id到会话映射中,使用 (username, conversation_id) 作为 key
- self.user_conv_map[conv_map_key] = {
- "conversation_id": conv,
- "conversation_msg_no": conv_no,
- "content_id": content_id
- }
- util.log(1, f"流式会话开始: key={conv_map_key}, content_id={content_id}")
- else:
- # 获取之前保存的content_id
- conv_info = self.user_conv_map.get(conv_map_key, {})
- content_id = conv_info.get("content_id", 0)
-
- # 如果 conv_map_key 不存在,尝试使用 username 作为备用查找
- if not conv_info and text and text.strip():
- # 查找所有匹配用户名的会话
- for (u, c), info in list(self.user_conv_map.items()):
- if u == username and info.get("content_id", 0) > 0:
- content_id = info.get("content_id", 0)
- conv_info = info
- util.log(1, f"警告:使用备用会话 ({u}, {c}) 的 content_id={content_id},原 key=({username}, {conv})")
- break
-
- if conv_info:
- conv_info["conversation_msg_no"] = conv_info.get("conversation_msg_no", 0) + 1
-
- # 如果有新内容,更新数据库
- if content_id > 0 and text and text.strip():
- # 获取当前已有内容
- existing_content = content_db.new_instance().get_content_by_id(content_id)
- if existing_content:
- # 累积内容
- accumulated_text = existing_content[3] + text
- content_db.new_instance().update_content(content_id, accumulated_text)
- elif content_id == 0 and text and text.strip():
- # content_id 为 0 表示可能会话 key 不匹配,记录警告
- util.log(1, f"警告:content_id=0,无法更新数据库。user={username}, conv={conv}, text片段={text[:50] if len(text) > 50 else text}")
-
- # 会话结束时清理 user_conv_map 中的对应条目,避免内存泄漏
- if is_end and conv_map_key in self.user_conv_map:
- del self.user_conv_map[conv_map_key]
-
- # 推送给前端和数字人
- try:
- user_for_stop = interact.data.get("user", "User")
- conv_id_for_stop = interact.data.get("conversation_id")
- if is_end or not stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
- self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end)
- except Exception:
- self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end)
-
- # 处理think标签
- # 第一步:处理结束标记
- if "" in text:
- # 设置用户退出思考模式
- self.think_mode_users[uid] = False
-
- # 分割文本,提取后面的内容
- # 如果有多个,我们只关心最后一个后面的内容
- parts = text.split("")
- text = parts[-1].strip()
-
- # 如果提取出的文本为空,则不需要继续处理
- if text == "":
- return None
- # 第二步:处理开始标记
- # 注意:这里要检查经过上面处理后的text
- if "" in text:
- self.think_mode_users[uid] = True
- self.think_time_users[uid] = time.time()
-
- #”思考中“的输出
- if self.think_mode_users.get(uid, False):
- try:
- user_for_stop = interact.data.get("user", "User")
- conv_id_for_stop = interact.data.get("conversation_id")
- should_block = stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop)
- except Exception:
- should_block = False
- if not should_block:
- if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
- wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'})
- if wsa_server.get_instance().is_connected(interact.data.get("user")):
- content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}
- wsa_server.get_instance().add_cmd(content)
-
- #”请稍等“的音频输出(不影响文本输出)
- if self.think_mode_users.get(uid, False) == True and time.time() - self.think_time_users[uid] >= 5:
- self.think_time_users[uid] = time.time()
- text = "请稍等..."
- elif self.think_mode_users.get(uid, False) == True and "" not in text:
- return None
-
- result = None
- audio_url = interact.data.get('audio', None)#透传的音频
-
- # 移除 prestart 标签内容,不进行TTS
- tts_text = self.__remove_prestart_tags(text) if text else text
-
- if audio_url is not None:#透传音频下载
- file_name = 'sample-' + str(int(time.time() * 1000)) + audio_url[-4:]
- result = self.download_wav(audio_url, './samples/', file_name)
- elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().get_client_output(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts
- if tts_text != None and tts_text.replace("*", "").strip() != "":
- # 检查是否需要停止TTS处理(按会话)
- if stream_manager.new_instance().should_stop_generation(
- interact.data.get("user", "User"),
- conversation_id=interact.data.get("conversation_id")
- ):
- util.printInfo(1, interact.data.get('user'), 'TTS处理被打断,跳过音频合成')
- return None
-
- # 先过滤表情符号,然后再合成语音
- filtered_text = self.__remove_emojis(tts_text.replace("*", ""))
- if filtered_text is not None and filtered_text.strip() != "":
- util.printInfo(1, interact.data.get('user'), '合成音频...')
- tm = time.time()
- result = self.sp.to_sample(filtered_text, self.__get_mood_voice())
- # 合成完成后再次检查会话是否仍有效,避免继续输出旧会话结果
- try:
- user_for_stop = interact.data.get("user", "User")
- conv_id_for_stop = interact.data.get("conversation_id")
- if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
- return None
- except Exception:
- pass
- util.printInfo(1, interact.data.get("user"), "合成音频完成. 耗时: {} ms 文件:{}".format(math.floor((time.time() - tm) * 1000), result))
- else:
- # prestart 内容不应该触发机器人表情重置
- if is_end and not is_prestart_content and wsa_server.get_web_instance().is_connected(interact.data.get('user')):
- wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'})
-
- if result is not None or is_first or is_end:
- # prestart 内容不需要进入音频处理流程
- if is_prestart_content:
- return result
- if is_end:#TODO 临时方案:如果结束标记,则延迟1秒处理,免得is end比前面的音频tts要快
- time.sleep(1)
- MyThread(target=self.__process_output_audio, args=[result, interact, text]).start()
- return result
-
- except BaseException as e:
- print(e)
- return None
-
- #下载wav
- def download_wav(self, url, save_directory, filename):
- try:
- # 发送HTTP GET请求以获取WAV文件内容
- response = requests.get(url, stream=True)
- response.raise_for_status() # 检查请求是否成功
-
- # 确保保存目录存在
- if not os.path.exists(save_directory):
- os.makedirs(save_directory)
-
- # 构建保存文件的路径
- save_path = os.path.join(save_directory, filename)
-
- # 将WAV文件内容保存到指定文件
- with open(save_path, 'wb') as f:
- for chunk in response.iter_content(chunk_size=1024):
- if chunk:
- f.write(chunk)
-
- return save_path
- except requests.exceptions.RequestException as e:
- print(f"[Error] Failed to download file: {e}")
- return None
-
-
- #面板播放声音
- def __play_sound(self):
- try:
- import pygame
- pygame.mixer.init() # 初始化pygame.mixer,只需要在此处初始化一次, 如果初始化失败,则不播放音频
- except Exception as e:
- util.printInfo(1, "System", "音频播放初始化失败,本机无法播放音频")
- return
-
- while self.__running:
- time.sleep(0.01)
- if not self.sound_query.empty(): # 如果队列不为空则播放音频
- file_url, audio_length, interact = self.sound_query.get()
-
- is_first = interact.data.get('isfirst') is True
- is_end = interact.data.get('isend') is True
-
-
-
- if file_url is not None:
- util.printInfo(1, interact.data.get('user'), '播放音频...')
-
- if is_first:
- self.speaking = True
- elif not is_end:
- self.speaking = True
-
- #自动播报关闭
- global auto_play_lock
- global can_auto_play
- with auto_play_lock:
- if self.timer is not None:
- self.timer.cancel()
- self.timer = None
- can_auto_play = False
-
- if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
- wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'})
-
- if file_url is not None:
- pygame.mixer.music.load(file_url)
- pygame.mixer.music.play()
-
- # 播放过程中计时,直到音频播放完毕
- length = 0
- while length < audio_length:
- try:
- user_for_stop = interact.data.get("user", "User")
- conv_id_for_stop = interact.data.get("conversation_id")
- if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
- try:
- pygame.mixer.music.stop()
- except Exception:
- pass
- break
- except Exception:
- pass
- length += 0.01
- time.sleep(0.01)
-
- if is_end:
- self.play_end(interact)
-
- if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
- wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'})
- # 播放完毕后通知
- if wsa_server.get_web_instance().is_connected(interact.data.get("user")):
- wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username': interact.data.get('user')})
-
- #推送远程音频
- def __send_remote_device_audio(self, file_url, interact):
- if file_url is None:
- return
- delkey = None
- for key, value in fay_booter.DeviceInputListenerDict.items():
- if value.username == interact.data.get("user") and value.isOutput: #按username选择推送,booter.devicelistenerdice按用户名记录
- try:
- value.deviceConnector.send(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08") # 发送音频开始标志,同时也检查设备是否在线
- wavfile = open(os.path.abspath(file_url), "rb")
- data = wavfile.read(102400)
- total = 0
- while data:
- total += len(data)
- value.deviceConnector.send(data)
- data = wavfile.read(102400)
- time.sleep(0.0001)
- value.deviceConnector.send(b'\x08\x07\x06\x05\x04\x03\x02\x01\x00')# 发送音频结束标志
- util.printInfo(1, value.username, "远程音频发送完成:{}".format(total))
- except socket.error as serr:
- util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key))
- value.stop()
- delkey = key
- if delkey:
- value = fay_booter.DeviceInputListenerDict.pop(delkey)
- if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
- wsa_server.get_web_instance().add_cmd({"remote_audio_connect": False, "Username" : interact.data.get('user')})
-
- def __is_send_remote_device_audio(self, interact):
- for key, value in fay_booter.DeviceInputListenerDict.items():
- if value.username == interact.data.get("user") and value.isOutput:
- return True
- return False
-
- #输出音频处理
- def __process_output_audio(self, file_url, interact, text):
- try:
- # 会话有效性与中断检查(最早返回,避免向面板/数字人发送任何旧会话输出)
- try:
- user_for_stop = interact.data.get("user", "User")
- conv_id_for_stop = interact.data.get("conversation_id")
- if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
- return
- except Exception:
- pass
- try:
- if file_url is None:
- audio_length = 0
- elif file_url.endswith('.wav'):
- audio = AudioSegment.from_wav(file_url)
- audio_length = len(audio) / 1000.0 # 时长以秒为单位
- elif file_url.endswith('.mp3'):
- audio = AudioSegment.from_mp3(file_url)
- audio_length = len(audio) / 1000.0 # 时长以秒为单位
- except Exception as e:
- audio_length = 3
-
- #推送远程音频
- if file_url is not None:
- MyThread(target=self.__send_remote_device_audio, args=[file_url, interact]).start()
-
- #发送音频给数字人接口
- if file_url is not None and wsa_server.get_instance().get_client_output(interact.data.get("user")):
- # 使用 (username, conversation_id) 作为 key 获取会话信息
- audio_username = interact.data.get("user", "User")
- audio_conv_id = interact.data.get("conversation_id") or ""
- audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {})
- content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : audio_conv_info.get("conversation_id", ""), 'CONV_MSG_NO' : audio_conv_info.get("conversation_msg_no", 0) }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
- #计算lips
- if platform.system() == "Windows":
- try:
- lip_sync_generator = LipSyncGenerator()
- viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url))
- consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
- content["Data"]["Lips"] = consolidated_visemes
- except Exception as e:
- print(e)
- util.printInfo(1, interact.data.get("user"), "唇型数据生成失败")
- wsa_server.get_instance().add_cmd(content)
- util.printInfo(1, interact.data.get("user"), "数字人接口发送音频数据成功")
-
- #面板播放
- config_util.load_config()
- # 检查是否是 prestart 内容
- is_prestart = self.__has_prestart(text)
- if config_util.config["interact"]["playSound"]:
- # prestart 内容不应该进入播放队列,避免触发 Normal 状态
- if not is_prestart:
- self.sound_query.put((file_url, audio_length, interact))
- else:
- # prestart 内容不应该重置机器人表情
- if not is_prestart and wsa_server.get_web_instance().is_connected(interact.data.get('user')):
- wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'})
-
- except Exception as e:
- print(e)
-
- def play_end(self, interact):
- self.speaking = False
- global can_auto_play
- global auto_play_lock
- with auto_play_lock:
- if self.timer:
- self.timer.cancel()
- self.timer = None
- if interact.interleaver != 'auto_play': #交互后暂停自动播报30秒
- self.timer = threading.Timer(30, self.set_auto_play)
- self.timer.start()
- else:
- can_auto_play = True
-
- #恢复自动播报(如果有)
- def set_auto_play(self):
- global auto_play_lock
- global can_auto_play
- with auto_play_lock:
- can_auto_play = True
- self.timer = None
-
- #启动核心服务
- def start(self):
- MyThread(target=self.__play_sound).start()
-
- #停止核心服务
- def stop(self):
- self.__running = False
- self.speaking = False
- self.sp.close()
- wsa_server.get_web_instance().add_cmd({"panelMsg": ""})
- content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': ""}}
- wsa_server.get_instance().add_cmd(content)
-
- def __record_response(self, text, username, uid):
- """
- 记录AI的回复内容
- :param text: 回复文本
- :param username: 用户名
- :param uid: 用户ID
- :return: content_id
- """
- self.write_to_file("./logs", "answer_result.txt", text)
- return content_db.new_instance().add_content('fay', 'speak', text, username, uid)
-
- def __remove_prestart_tags(self, text):
- """
- 移除文本中的 prestart 标签及其内容
- :param text: 原始文本
- :return: 移除 prestart 标签后的文本
- """
- if not text:
- return text
- import re
- # 移除 ... 标签及其内容(支持属性)
- cleaned = re.sub(r']*>[\s\S]*?', '', text, flags=re.IGNORECASE)
- return cleaned.strip()
-
- def __has_prestart(self, text):
- """
- 判断文本中是否包含 prestart 标签(支持属性)
- """
- if not text:
- return False
- return re.search(r']*>[\s\S]*?', text, flags=re.IGNORECASE) is not None
-
-
- def __truncate_think_for_panel(self, text, uid, username):
-
- if not text or not isinstance(text, str):
-
- return text
-
- key = uid if uid is not None else username
-
- state = self.think_display_state.get(key)
-
- if state is None:
-
- state = {"in_think": False, "in_tool_output": False, "tool_count": 0, "tool_truncated": False}
-
- self.think_display_state[key] = state
-
- if not state["in_think"] and "" not in text and "" not in text:
-
- return text
-
- tool_output_regex = re.compile(r"\[TOOL\]\s*(?:Output|\u8f93\u51fa)[:\uff1a]", re.IGNORECASE)
-
- section_regex = re.compile(r"(?i)(^|[\r\n])(\[(?:TOOL|PLAN)\])")
-
- out = []
-
- i = 0
-
- while i < len(text):
-
- if not state["in_think"]:
-
- idx = text.find("", i)
-
- if idx == -1:
-
- out.append(text[i:])
-
- break
-
- out.append(text[i:idx + len("")])
-
- state["in_think"] = True
-
- i = idx + len("")
-
- continue
-
- if not state["in_tool_output"]:
-
- think_end = text.find("", i)
-
- tool_match = tool_output_regex.search(text, i)
-
- next_pos = None
-
- next_kind = None
-
- if tool_match:
-
- next_pos = tool_match.start()
-
- next_kind = "tool"
-
- if think_end != -1 and (next_pos is None or think_end < next_pos):
-
- next_pos = think_end
-
- next_kind = "think_end"
-
- if next_pos is None:
-
- out.append(text[i:])
-
- break
-
- if next_pos > i:
-
- out.append(text[i:next_pos])
-
- if next_kind == "think_end":
-
- out.append("")
-
- state["in_think"] = False
-
- state["in_tool_output"] = False
-
- state["tool_count"] = 0
-
- state["tool_truncated"] = False
-
- i = next_pos + len("")
-
- else:
-
- marker_end = tool_match.end()
-
- out.append(text[next_pos:marker_end])
-
- state["in_tool_output"] = True
-
- state["tool_count"] = 0
-
- state["tool_truncated"] = False
-
- i = marker_end
-
- continue
-
- think_end = text.find("", i)
-
- section_match = section_regex.search(text, i)
-
- end_pos = None
-
- if section_match:
-
- end_pos = section_match.start(2)
-
- if think_end != -1 and (end_pos is None or think_end < end_pos):
-
- end_pos = think_end
-
- segment = text[i:] if end_pos is None else text[i:end_pos]
-
- if segment:
-
- if state["tool_truncated"]:
-
- pass
-
- else:
-
- remaining = self.think_display_limit - state["tool_count"]
-
- if remaining <= 0:
-
- out.append("...")
-
- state["tool_truncated"] = True
-
- elif len(segment) <= remaining:
-
- out.append(segment)
-
- state["tool_count"] += len(segment)
-
- else:
-
- out.append(segment[:remaining] + "...")
-
- state["tool_count"] += remaining
-
- state["tool_truncated"] = True
-
- if end_pos is None:
-
- break
-
- state["in_tool_output"] = False
-
- state["tool_count"] = 0
-
- state["tool_truncated"] = False
-
- i = end_pos
-
- return "".join(out)
-
- def __send_panel_message(self, text, username, uid, content_id=None, type=None):
- """
- 发送消息到Web面板
- :param text: 消息文本
- :param username: 用户名
- :param uid: 用户ID
- :param content_id: 内容ID
- :param type: 消息类型
- """
- if not wsa_server.get_web_instance().is_connected(username):
- return
-
- # 检查是否是 prestart 内容,prestart 内容不应该更新日志区消息
- # 因为这会覆盖掉"思考中..."的状态显示
- is_prestart = self.__has_prestart(text)
display_text = self.__truncate_think_for_panel(text, uid, username)
+# -*- coding: utf-8 -*-
+
+
+#作用是处理交互逻辑,文字输入,语音、文字及情绪的发送、播放及展示输出
+
+
+import math
+
+
+from operator import index
+
+
+import os
+
+
+import time
+
+
+import socket
+
+
+import requests
+
+
+from pydub import AudioSegment
+
+
+from queue import Queue
+
+
+import re # 添加正则表达式模块用于过滤表情符号
+
+
+import uuid
+
+
+
+
+
+# 适应模型使用
+
+
+import numpy as np
+
+
+from ai_module import baidu_emotion
+
+
+from core import wsa_server
+
+
+from core.interact import Interact
+
+
+from tts.tts_voice import EnumVoice
+
+
+from scheduler.thread_manager import MyThread
+
+
+from tts import tts_voice
+
+
+from utils import util, config_util
+
+
+from core import qa_service
+
+
+from utils import config_util as cfg
+
+
+from core import content_db
+
+
+from ai_module import nlp_cemotion
+
+
+from core import stream_manager
+
+
+
+
+
+from core import member_db
+
+
+import threading
+
+
+
+
+
+#加载配置
+
+
+cfg.load_config()
+
+
+if cfg.tts_module =='ali':
+
+
+ from tts.ali_tss import Speech
+
+
+elif cfg.tts_module == 'gptsovits':
+
+
+ from tts.gptsovits import Speech
+
+
+elif cfg.tts_module == 'gptsovits_v3':
+
+
+ from tts.gptsovits_v3 import Speech
+
+
+elif cfg.tts_module == 'volcano':
+
+
+ from tts.volcano_tts import Speech
+
+
+else:
+
+
+ from tts.ms_tts_sdk import Speech
+
+
+
+
+
+#windows运行推送唇形数据
+
+
+import platform
+
+
+if platform.system() == "Windows":
+
+
+ import sys
+
+
+ sys.path.append("test/ovr_lipsync")
+
+
+ from test_olipsync import LipSyncGenerator
+
+
+
+
+
+
+
+
+#可以使用自动播报的标记
+
+
+can_auto_play = True
+
+
+auto_play_lock = threading.RLock()
+
+
+
+
+
+class FeiFei:
+
+
+ def __init__(self):
+
+
+ self.lock = threading.Lock()
+
+
+ self.nlp_streams = {} # 存储用户ID到句子缓存的映射
+
+
+ self.nlp_stream_lock = threading.Lock() # 保护nlp_streams字典的锁
+
+
+ self.mood = 0.0 # 情绪值
+
+
+ self.old_mood = 0.0
+
+
+ self.item_index = 0
+
+
+ self.X = np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, -1) # 适应模型变量矩阵
+
+
+ # self.W = np.array([0.01577594,1.16119452,0.75828,0.207746,1.25017864,0.1044121,0.4294899,0.2770932]).reshape(-1,1) #适应模型变量矩阵
+
+
+ self.W = np.array([0.0, 0.6, 0.1, 0.7, 0.3, 0.0, 0.0, 0.0]).reshape(-1, 1) # 适应模型变量矩阵
+
+
+
+
+
+ self.wsParam = None
+
+
+ self.wss = None
+
+
+ self.sp = Speech()
+
+
+ self.speaking = False #声音是否在播放
+
+
+ self.__running = True
+
+
+ self.sp.connect() #TODO 预连接
+
+
+
+
+
+ self.timer = None
+
+
+ self.sound_query = Queue()
+
+
+ self.think_mode_users = {} # 使用字典存储每个用户的think模式状态
+
+
+ self.think_time_users = {} #使用字典存储每个用户的think开始时间
+ self.think_display_state = {}
+ self.think_display_limit = 400
+ self.user_conv_map = {} #存储用户对话id及句子流序号,key为(username, conversation_id)
+
+ self.pending_isfirst = {} # 存储因prestart被过滤而延迟的isfirst标记,key为username
+
+
+
+
+ def __remove_emojis(self, text):
+
+
+ """
+
+
+ 改进的表情包过滤,避免误删除正常Unicode字符
+
+
+ """
+
+
+ # 更精确的emoji范围,避免误删除正常字符
+
+
+ emoji_pattern = re.compile(
+
+
+ "["
+
+
+ "\U0001F600-\U0001F64F" # 表情符号 (Emoticons)
+
+
+ "\U0001F300-\U0001F5FF" # 杂项符号和象形文字 (Miscellaneous Symbols and Pictographs)
+
+
+ "\U0001F680-\U0001F6FF" # 交通和地图符号 (Transport and Map Symbols)
+
+
+ "\U0001F1E0-\U0001F1FF" # 区域指示符号 (Regional Indicator Symbols)
+
+
+ "\U0001F900-\U0001F9FF" # 补充符号和象形文字 (Supplemental Symbols and Pictographs)
+
+
+ "\U0001FA70-\U0001FAFF" # 扩展A符号和象形文字 (Symbols and Pictographs Extended-A)
+
+
+ "\U00002600-\U000026FF" # 杂项符号 (Miscellaneous Symbols)
+
+
+ "\U00002700-\U000027BF" # 装饰符号 (Dingbats)
+
+
+ "\U0000FE00-\U0000FE0F" # 变体选择器 (Variation Selectors)
+
+
+ "\U0001F000-\U0001F02F" # 麻将牌 (Mahjong Tiles)
+
+
+ "\U0001F0A0-\U0001F0FF" # 扑克牌 (Playing Cards)
+
+
+ "]+",
+
+
+ flags=re.UNICODE,
+
+
+ )
+
+
+
+
+
+ # 保护常用的中文标点符号和特殊字符
+
+
+ protected_chars = ["。", ",", "!", "?", ":", ";", "、", """, """, "'", "'", "(", ")", "【", "】", "《", "》"]
+
+
+
+
+
+ # 先保存保护字符的位置
+
+
+ protected_positions = {}
+
+
+ for i, char in enumerate(text):
+
+
+ if char in protected_chars:
+
+
+ protected_positions[i] = char
+
+
+
+
+
+ # 执行emoji过滤
+
+
+ filtered_text = emoji_pattern.sub('', text)
+
+
+
+
+
+ # 如果过滤后文本长度变化太大,可能误删了正常字符,返回原文本
+
+
+ if len(filtered_text) < len(text) * 0.5: # 如果删除了超过50%的内容
+
+
+ return text
+
+
+
+
+
+ return filtered_text
+
+
+
+
+
    def __process_stream_output(self, text, username, session_type="type2_stream", is_qa=False):
        """
        Split a complete text and deliver it as a simulated stream (type=2 paths).

        Chunking is delegated to the shared stream text processor; session
        bookkeeping to the stream state manager.

        :param text: full reply text to stream out; blank input is ignored
        :param username: target user
        :param session_type: label recorded with the streaming session
        :param is_qa: True when the text comes from the Q&A store
        """
        if not text or text.strip() == "":
            return

        # Use the safe streaming text processor / state manager singletons.
        from utils.stream_text_processor import get_processor
        from utils.stream_state_manager import get_state_manager

        processor = get_processor()
        state_manager = get_state_manager()

        success = processor.process_stream_text(text, username, is_qa=is_qa, session_type=session_type)

        if success:
            # Close the session normally, bound to the user's current conversation.
            state_manager.end_session(username, conversation_id=stream_manager.new_instance().get_conversation_id(username))
        else:
            util.log(1, f"type=2流式处理失败,文本长度: {len(text)}")
            # Make sure the session state is torn down even on failure.
            state_manager.force_reset_user_state(username)
+
+
+
+
+
+ #语音消息处理检查是否命中q&a
+
+
+ def __get_answer(self, interleaver, text):
+
+
+ answer = None
+
+
+ # 全局问答
+
+
+ answer, type = qa_service.QAService().question('qa',text)
+
+
+ if answer is not None:
+
+
+ return answer, type
+
+
+ else:
+
+
+ return None, None
+
+
+
+
+
+
+
+
+ #消息处理
+
+
    def __process_interact(self, interact: Interact):
        """
        Core interaction dispatcher.

        interact_type 1: voice/text dialogue — Q&A lookup first, then the LLM.
        interact_type 2: pass-through — play supplied audio, or stream supplied
        text through TTS.

        :param interact: interaction payload (user, msg, optional
            audio/text/observation/no_reply fields)
        :return: the reply text or 'success'; the exception object on failure;
            a notice string when the engine is not running
        """
        if self.__running:
            try:
                index = interact.interact_type
                username = interact.data.get("user", "User")
                uid = member_db.new_instance().find_user(username)
                # no_reply may arrive as a bool or as a string flag; normalise it.
                no_reply = interact.data.get("no_reply", False)
                if isinstance(no_reply, str):
                    no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on")
                else:
                    no_reply = bool(no_reply)

                if index == 1:  # voice / text interaction

                    # Log the question so external tools (e.g. OBS) can read it.
                    self.write_to_file("./logs", "asr_result.txt", interact.data["msg"])

                    # Mirror the question to the digital-human client.
                    if wsa_server.get_instance().is_connected(username):
                        content = {'Topic': 'human', 'Data': {'Key': 'question', 'Value': interact.data["msg"]}, 'Username' : interact.data.get("user")}
                        wsa_server.get_instance().add_cmd(content)

                    # Persist the question (skipped in observe-only mode).
                    if not no_reply:
                        content_id = content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid)
                        if wsa_server.get_web_instance().is_connected(username):
                            wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid, "id":content_id}, "Username" : username})

                    # Record an observation: explicit one, or (when no_reply)
                    # the message itself.
                    observation = interact.data.get("observation", None)
                    obs_text = ""
                    if observation is not None:
                        obs_text = observation.strip() if isinstance(observation, str) else str(observation).strip()
                    if not obs_text and no_reply:
                        msg_text = interact.data.get("msg", "")
                        obs_text = msg_text.strip() if isinstance(msg_text, str) else str(msg_text).strip()
                    if obs_text:
                        from llm import nlp_cognitive_stream
                        nlp_cognitive_stream.record_observation(username, obs_text)
                    if no_reply:
                        # Observe-only interaction: no reply is generated.
                        return ""

                    # Check whether the Q&A store answers this directly.
                    answer, type = self.__get_answer(interact.interleaver, interact.data["msg"])

                    # Fall through to the LLM when Q&A missed.
                    text = ''
                    if answer is None or type != "qa":
                        if wsa_server.get_web_instance().is_connected(username):
                            wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'})
                        if wsa_server.get_instance().is_connected(username):
                            content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}
                            wsa_server.get_instance().add_cmd(content)

                        # Pick the NLP backend from configuration.
                        if cfg.config["memory"].get("use_bionic_memory", False):
                            from llm import nlp_bionicmemory_stream
                            text = nlp_bionicmemory_stream.question(interact.data["msg"], username, interact.data.get("observation", None))
                        else:
                            from llm import nlp_cognitive_stream
                            text = nlp_cognitive_stream.question(interact.data["msg"], username, interact.data.get("observation", None))

                    else:
                        text = answer
                        # Stream the Q&A answer out in chunks.
                        self.__process_stream_output(text, username, session_type="qa", is_qa=True)

                    return text

                elif (index == 2):  # pass-through: audio wins; otherwise stream the text + TTS
                    audio_url = interact.data.get("audio")
                    text = interact.data.get("text")

                    # 1) Audio present: ignore the text and just play the audio.
                    if audio_url and str(audio_url).strip():
                        try:
                            audio_interact = Interact(
                                "stream", 1,
                                {"user": username, "msg": "", "isfirst": True, "isend": True, "audio": audio_url}
                            )
                            self.say(audio_interact, "")
                        except Exception:
                            pass
                        return 'success'

                    # 2) Text only: stream-split and synthesise (streaming also
                    # records the text to the database).
                    if text and str(text).strip():
                        self.__process_stream_output(text, username, f"type2_{interact.interleaver}", is_qa=False)

                        # No extra recording needed; streaming already stored it.
                        # self.__process_text_output(text, username, uid)

                        return 'success'

                    # Nothing usable in the payload.
                    return 'success'

            except BaseException as e:
                print(e)
                return e
        else:
            return "还没有开始运行"
+
+
+
+
+
+ #记录问答到log
+
+
+ def write_to_file(self, path, filename, content):
+
+
+ if not os.path.exists(path):
+
+
+ os.makedirs(path)
+
+
+ full_path = os.path.join(path, filename)
+
+
+ with open(full_path, 'w', encoding='utf-8') as file:
+
+
+ file.write(content)
+
+
+ file.flush()
+
+
+ os.fsync(file.fileno())
+
+
+
+
+
+ #触发交互
+
+
    def on_interact(self, interact: Interact):
        """
        Entry point for a new interaction: ensures the user exists, rotates the
        streaming conversation, then dispatches to __process_interact
        (asynchronously for type 1, synchronously otherwise).

        :param interact: interaction payload
        :return: __process_interact's result for synchronous types, else None
        """
        # Create the user record on first contact.
        username = interact.data.get("user", "User")
        if member_db.new_instance().is_username_exist(username) == "notexists":
            member_db.new_instance().add_user(username)
        # no_reply may arrive as a bool or as a string flag; normalise it.
        no_reply = interact.data.get("no_reply", False)
        if isinstance(no_reply, str):
            no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on")
        else:
            no_reply = bool(no_reply)

        if not no_reply:
            try:
                from utils.stream_state_manager import get_state_manager
                import uuid
                # Interrupt any session still streaming for this user.
                if get_state_manager().is_session_active(username):
                    stream_manager.new_instance().clear_Stream_with_audio(username)
                conv_id = "conv_" + str(uuid.uuid4())
                stream_manager.new_instance().set_current_conversation(username, conv_id)
                # Attach the conversation id to the interaction payload.
                interact.data["conversation_id"] = conv_id
                # Re-enable generation for the new conversation.
                stream_manager.new_instance().set_stop_generation(username, stop=False)
            except Exception:
                util.log(3, "开启新会话失败")

        if interact.interact_type == 1:
            MyThread(target=self.__process_interact, args=[interact]).start()
        else:
            return self.__process_interact(interact)
+
+
+
+
+
+ #获取不同情绪声音
+
+
+ def __get_mood_voice(self):
+
+
+ voice = tts_voice.get_voice_of(config_util.config["attribute"]["voice"])
+
+
+ if voice is None:
+
+
+ voice = EnumVoice.XIAO_XIAO
+
+
+ styleList = voice.value["styleList"]
+
+
+ sayType = styleList["calm"]
+
+
+ return sayType
+
+
+
+
+
+ # 合成声音
+
+
+ def say(self, interact, text, type = ""):
+
+
+ try:
+
+
+ uid = member_db.new_instance().find_user(interact.data.get("user"))
+
+
+ is_end = interact.data.get("isend", False)
+
+
+ is_first = interact.data.get("isfirst", False)
+
+
+ username = interact.data.get("user", "User")
+
+
+
+
+
+ # 提前进行会话有效性与中断检查,避免产生多余面板/数字人输出
+
+
+ try:
+
+
+ user_for_stop = interact.data.get("user", "User")
+
+
+ conv_id_for_stop = interact.data.get("conversation_id")
+
+
+ if not is_end and stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
+
+
+ return None
+
+
+ except Exception:
+
+
+ pass
+
+
+
+
+
+ #无效流式文本提前结束
+
+
+ if not is_first and not is_end and (text is None or text.strip() == ""):
+
+
+ return None
+
+
+
+
+
+ # 检查是否是 prestart 内容(不应该影响 thinking 状态)
+
+
+ is_prestart_content = self.__has_prestart(text)
+
+
+
+
+ # 流式文本拼接存库
+
+
+ content_id = 0
+
+
+ # 使用 (username, conversation_id) 作为 key,避免并发会话覆盖
+
+
+ conv = interact.data.get("conversation_id") or ""
+
+
+ conv_map_key = (username, conv)
+
+
+
+
+
+ if is_first == True:
+
+
+ # reset any leftover think-mode at the start of a new reply
+
+
+ # 但如果是 prestart 内容,不重置 thinking 状态
+
+
+ try:
+
+
+ if uid is not None and not is_prestart_content:
+
+
+ self.think_mode_users[uid] = False
+
+
+ if uid in self.think_time_users:
+
+
+ del self.think_time_users[uid]
+ if uid in self.think_display_state:
+ del self.think_display_state[uid]
+
+
+ except Exception:
+
+
+ pass
+
+
+ # 如果没有 conversation_id,生成一个新的
+
+
+ if not conv:
+
+
+ conv = "conv_" + str(uuid.uuid4())
+
+
+ conv_map_key = (username, conv)
+
+
+ conv_no = 0
+
+
+ # 创建第一条数据库记录,获得content_id
+
+
+ if text and text.strip():
+
+
+ content_id = content_db.new_instance().add_content('fay', 'speak', text, username, uid)
+
+
+ else:
+
+
+ content_id = content_db.new_instance().add_content('fay', 'speak', '', username, uid)
+
+
+
+
+
+ # 保存content_id到会话映射中,使用 (username, conversation_id) 作为 key
+
+
+ self.user_conv_map[conv_map_key] = {
+
+
+ "conversation_id": conv,
+
+
+ "conversation_msg_no": conv_no,
+
+
+ "content_id": content_id
+
+
+ }
+
+
+ util.log(1, f"流式会话开始: key={conv_map_key}, content_id={content_id}")
+
+
+ else:
+
+
+ # 获取之前保存的content_id
+
+
+ conv_info = self.user_conv_map.get(conv_map_key, {})
+
+
+ content_id = conv_info.get("content_id", 0)
+
+
+
+
+
+ # 如果 conv_map_key 不存在,尝试使用 username 作为备用查找
+
+
+ if not conv_info and text and text.strip():
+
+
+ # 查找所有匹配用户名的会话
+
+
+ for (u, c), info in list(self.user_conv_map.items()):
+
+
+ if u == username and info.get("content_id", 0) > 0:
+
+
+ content_id = info.get("content_id", 0)
+
+
+ conv_info = info
+
+
+ util.log(1, f"警告:使用备用会话 ({u}, {c}) 的 content_id={content_id},原 key=({username}, {conv})")
+
+
+ break
+
+
+
+
+
+ if conv_info:
+
+
+ conv_info["conversation_msg_no"] = conv_info.get("conversation_msg_no", 0) + 1
+
+
+
+
+
+ # 如果有新内容,更新数据库
+
+
+ if content_id > 0 and text and text.strip():
+
+
+ # 获取当前已有内容
+
+
+ existing_content = content_db.new_instance().get_content_by_id(content_id)
+
+
+ if existing_content:
+
+
+ # 累积内容
+
+
+ accumulated_text = existing_content[3] + text
+
+
+ content_db.new_instance().update_content(content_id, accumulated_text)
+
+
+ elif content_id == 0 and text and text.strip():
+
+
+ # content_id 为 0 表示可能会话 key 不匹配,记录警告
+
+
+ util.log(1, f"警告:content_id=0,无法更新数据库。user={username}, conv={conv}, text片段={text[:50] if len(text) > 50 else text}")
+
+
+
+
+
+ # 会话结束时清理 user_conv_map 中的对应条目,避免内存泄漏
+
+
+ if is_end and conv_map_key in self.user_conv_map:
+
+
+ del self.user_conv_map[conv_map_key]
+
+
+
+
+
+ # 推送给前端和数字人
+
+
+ try:
+
+
+ user_for_stop = interact.data.get("user", "User")
+
+
+ conv_id_for_stop = interact.data.get("conversation_id")
+
+
+ if is_end or not stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
+
+
+ self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end)
+
+
+ except Exception:
+
+
+ self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end)
+
+
+
+
+
+ # 处理think标签
+
+
+ # 第一步:处理结束标记
+
+
+ if "" in text:
+
+
+ # 设置用户退出思考模式
+
+
+ self.think_mode_users[uid] = False
+
+
+
+
+
+ # 分割文本,提取后面的内容
+
+
+ # 如果有多个,我们只关心最后一个后面的内容
+
+
+ parts = text.split("")
+
+
+ text = parts[-1].strip()
+
+
+
+
+
+ # 如果提取出的文本为空,则不需要继续处理
+
+
+ if text == "":
+
+
+ return None
+
+
+ # 第二步:处理开始标记
+
+
+ # 注意:这里要检查经过上面处理后的text
+
+
+ if "" in text:
+
+
+ self.think_mode_users[uid] = True
+
+
+ self.think_time_users[uid] = time.time()
+
+
+
+
+
+ #”思考中“的输出
+
+
+ if self.think_mode_users.get(uid, False):
+
+
+ try:
+
+
+ user_for_stop = interact.data.get("user", "User")
+
+
+ conv_id_for_stop = interact.data.get("conversation_id")
+
+
+ should_block = stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop)
+
+
+ except Exception:
+
+
+ should_block = False
+
+
+ if not should_block:
+
+
+ if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
+
+
+ wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'})
+
+
+ if wsa_server.get_instance().is_connected(interact.data.get("user")):
+
+
+ content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}
+
+
+ wsa_server.get_instance().add_cmd(content)
+
+
+
+
+
+ #”请稍等“的音频输出(不影响文本输出)
+
+
+ if self.think_mode_users.get(uid, False) == True and time.time() - self.think_time_users[uid] >= 5:
+
+
+ self.think_time_users[uid] = time.time()
+
+
+ text = "请稍等..."
+
+
+ elif self.think_mode_users.get(uid, False) == True and "" not in text:
+
+
+ return None
+
+
+
+
+
+ result = None
+
+
+ audio_url = interact.data.get('audio', None)#透传的音频
+
+
+
+
+
+ # 移除 prestart 标签内容,不进行TTS
+
+
+ tts_text = self.__remove_prestart_tags(text) if text else text
+
+
+
+
+
+ if audio_url is not None:#透传音频下载
+
+
+ file_name = 'sample-' + str(int(time.time() * 1000)) + audio_url[-4:]
+
+
+ result = self.download_wav(audio_url, './samples/', file_name)
+
+
+ elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().get_client_output(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts
+
+
+ if tts_text != None and tts_text.replace("*", "").strip() != "":
+
+
+ # 检查是否需要停止TTS处理(按会话)
+
+
+ if stream_manager.new_instance().should_stop_generation(
+
+
+ interact.data.get("user", "User"),
+
+
+ conversation_id=interact.data.get("conversation_id")
+
+
+ ):
+
+
+ util.printInfo(1, interact.data.get('user'), 'TTS处理被打断,跳过音频合成')
+
+
+ return None
+
+
+
+
+
+ # 先过滤表情符号,然后再合成语音
+
+
+ filtered_text = self.__remove_emojis(tts_text.replace("*", ""))
+
+
+ if filtered_text is not None and filtered_text.strip() != "":
+
+
+ util.printInfo(1, interact.data.get('user'), '合成音频...')
+
+
+ tm = time.time()
+
+
+ result = self.sp.to_sample(filtered_text, self.__get_mood_voice())
+
+
+ # 合成完成后再次检查会话是否仍有效,避免继续输出旧会话结果
+
+
+ try:
+
+
+ user_for_stop = interact.data.get("user", "User")
+
+
+ conv_id_for_stop = interact.data.get("conversation_id")
+
+
+ if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
+
+
+ return None
+
+
+ except Exception:
+
+
+ pass
+
+
+ util.printInfo(1, interact.data.get("user"), "合成音频完成. 耗时: {} ms 文件:{}".format(math.floor((time.time() - tm) * 1000), result))
+
+
+ else:
+
+
+ # prestart 内容不应该触发机器人表情重置
+
+
+ if is_end and not is_prestart_content and wsa_server.get_web_instance().is_connected(interact.data.get('user')):
+
+
+ wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'})
+
+
+
+
+
+ if result is not None or is_first or is_end:
+
+
+ # prestart 内容不需要进入音频处理流程
+
+
+ if is_prestart_content:
+
+
+ return result
+
+
+ if is_end:#TODO 临时方案:如果结束标记,则延迟1秒处理,免得is end比前面的音频tts要快
+
+
+ time.sleep(1)
+
+
+ MyThread(target=self.__process_output_audio, args=[result, interact, text]).start()
+
+
+ return result
+
+
+
+
+
+ except BaseException as e:
+
+
+ print(e)
+
+
+ return None
+
+
+
+
+
+ #下载wav
+
+
+ def download_wav(self, url, save_directory, filename):
+
+
+ try:
+
+
+ # 发送HTTP GET请求以获取WAV文件内容
+
+
+ response = requests.get(url, stream=True)
+
+
+ response.raise_for_status() # 检查请求是否成功
+
+
+
+
+
+ # 确保保存目录存在
+
+
+ if not os.path.exists(save_directory):
+
+
+ os.makedirs(save_directory)
+
+
+
+
+
+ # 构建保存文件的路径
+
+
+ save_path = os.path.join(save_directory, filename)
+
+
+
+
+
+ # 将WAV文件内容保存到指定文件
+
+
+ with open(save_path, 'wb') as f:
+
+
+ for chunk in response.iter_content(chunk_size=1024):
+
+
+ if chunk:
+
+
+ f.write(chunk)
+
+
+
+
+
+ return save_path
+
+
+ except requests.exceptions.RequestException as e:
+
+
+ print(f"[Error] Failed to download file: {e}")
+
+
+ return None
+
+
+
+
+
+
+
+
+ #面板播放声音
+
+
    def __play_sound(self):
        """
        Background loop: pull (file_url, audio_length, interact) tuples off
        sound_query and play them locally through pygame, honouring
        per-conversation interruption and end-of-stream bookkeeping.
        Runs until __running goes False.
        """
        try:
            import pygame
            pygame.mixer.init()  # one-time init; on failure local playback is disabled
        except Exception as e:
            util.printInfo(1, "System", "音频播放初始化失败,本机无法播放音频")
            return

        while self.__running:
            time.sleep(0.01)
            if not self.sound_query.empty():  # play the next queued audio
                file_url, audio_length, interact = self.sound_query.get()

                is_first = interact.data.get('isfirst') is True
                is_end = interact.data.get('isend') is True

                if file_url is not None:
                    util.printInfo(1, interact.data.get('user'), '播放音频...')

                    # Any non-final chunk means we are (still) speaking.
                    if is_first:
                        self.speaking = True
                    elif not is_end:
                        self.speaking = True

                    # Pause auto-broadcast while speaking.
                    global auto_play_lock
                    global can_auto_play
                    with auto_play_lock:
                        if self.timer is not None:
                            self.timer.cancel()
                            self.timer = None
                        can_auto_play = False

                    if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
                        wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'})

                if file_url is not None:
                    pygame.mixer.music.load(file_url)
                    pygame.mixer.music.play()

                    # Wait out the audio duration, aborting if the session is stopped.
                    length = 0
                    while length < audio_length:
                        try:
                            user_for_stop = interact.data.get("user", "User")
                            conv_id_for_stop = interact.data.get("conversation_id")
                            if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
                                try:
                                    pygame.mixer.music.stop()
                                except Exception:
                                    pass
                                break
                        except Exception:
                            pass
                        length += 0.01
                        time.sleep(0.01)

                if is_end:
                    self.play_end(interact)

                    if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
                        wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'})
                # Notify the panel that this chunk finished playing.
                if wsa_server.get_web_instance().is_connected(interact.data.get("user")):
                    wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username': interact.data.get('user')})
+
+
+
+
+
+ #推送远程音频
+
+
+ def __send_remote_device_audio(self, file_url, interact):
+
+
+ if file_url is None:
+
+
+ return
+
+
+ delkey = None
+
+
+ for key, value in fay_booter.DeviceInputListenerDict.items():
+
+
+ if value.username == interact.data.get("user") and value.isOutput: #按username选择推送,booter.devicelistenerdice按用户名记录
+
+
+ try:
+
+
+ value.deviceConnector.send(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08") # 发送音频开始标志,同时也检查设备是否在线
+
+
+ wavfile = open(os.path.abspath(file_url), "rb")
+
+
+ data = wavfile.read(102400)
+
+
+ total = 0
+
+
+ while data:
+
+
+ total += len(data)
+
+
+ value.deviceConnector.send(data)
+
+
+ data = wavfile.read(102400)
+
+
+ time.sleep(0.0001)
+
+
+ value.deviceConnector.send(b'\x08\x07\x06\x05\x04\x03\x02\x01\x00')# 发送音频结束标志
+
+
+ util.printInfo(1, value.username, "远程音频发送完成:{}".format(total))
+
+
+ except socket.error as serr:
+
+
+ util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key))
+
+
+ value.stop()
+
+
+ delkey = key
+
+
+ if delkey:
+
+
+ value = fay_booter.DeviceInputListenerDict.pop(delkey)
+
+
+ if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
+
+
+ wsa_server.get_web_instance().add_cmd({"remote_audio_connect": False, "Username" : interact.data.get('user')})
+
+
+
+
+
+ def __is_send_remote_device_audio(self, interact):
+
+
+ for key, value in fay_booter.DeviceInputListenerDict.items():
+
+
+ if value.username == interact.data.get("user") and value.isOutput:
+
+
+ return True
+
+
+ return False
+
+
+
+
+
+ #输出音频处理
+
+
    def __process_output_audio(self, file_url, interact, text):
        """
        Post-TTS pipeline for one audio chunk: measure its duration, push it to
        any remote device, send it to the digital-human client (with lip-sync
        data on Windows), and queue it for local panel playback.

        :param file_url: audio file path (None for pure first/end markers)
        :param interact: interaction payload
        :param text: text corresponding to this audio chunk
        """
        try:
            # Earliest interruption check, before any output is emitted.
            try:
                user_for_stop = interact.data.get("user", "User")
                conv_id_for_stop = interact.data.get("conversation_id")
                if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
                    return
            except Exception:
                pass
            try:
                if file_url is None:
                    audio_length = 0
                elif file_url.endswith('.wav'):
                    audio = AudioSegment.from_wav(file_url)
                    audio_length = len(audio) / 1000.0  # duration in seconds
                elif file_url.endswith('.mp3'):
                    audio = AudioSegment.from_mp3(file_url)
                    audio_length = len(audio) / 1000.0  # duration in seconds
                # NOTE(review): a file_url with any other extension leaves
                # audio_length unset — the later use then raises NameError,
                # which only the outer except catches. TODO confirm intended.
            except Exception as e:
                audio_length = 3  # fallback duration when probing fails

            # Push to remote output devices.
            if file_url is not None:
                MyThread(target=self.__send_remote_device_audio, args=[file_url, interact]).start()

            # Send the audio to the digital-human client.
            if file_url is not None and wsa_server.get_instance().get_client_output(interact.data.get("user")):
                # Session info is keyed by (username, conversation_id).
                audio_username = interact.data.get("user", "User")
                audio_conv_id = interact.data.get("conversation_id") or ""
                audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {})
                content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : audio_conv_info.get("conversation_id", ""), 'CONV_MSG_NO' : audio_conv_info.get("conversation_msg_no", 0) }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
                # Compute lip-sync visemes (Windows-only OVR generator).
                if platform.system() == "Windows":
                    try:
                        lip_sync_generator = LipSyncGenerator()
                        viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url))
                        consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
                        content["Data"]["Lips"] = consolidated_visemes
                    except Exception as e:
                        print(e)
                        util.printInfo(1, interact.data.get("user"), "唇型数据生成失败")
                wsa_server.get_instance().add_cmd(content)
                util.printInfo(1, interact.data.get("user"), "数字人接口发送音频数据成功")

            # Local panel playback.
            config_util.load_config()
            # prestart content must not reach the playback queue (it would
            # trigger the Normal expression) nor reset the robot face.
            is_prestart = self.__has_prestart(text)
            if config_util.config["interact"]["playSound"]:
                if not is_prestart:
                    self.sound_query.put((file_url, audio_length, interact))
            else:
                if not is_prestart and wsa_server.get_web_instance().is_connected(interact.data.get('user')):
                    wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'})

        except Exception as e:
            print(e)
+
+
+
+
+
+ def play_end(self, interact):
+
+
+ self.speaking = False
+
+
+ global can_auto_play
+
+
+ global auto_play_lock
+
+
+ with auto_play_lock:
+
+
+ if self.timer:
+
+
+ self.timer.cancel()
+
+
+ self.timer = None
+
+
+ if interact.interleaver != 'auto_play': #交互后暂停自动播报30秒
+
+
+ self.timer = threading.Timer(30, self.set_auto_play)
+
+
+ self.timer.start()
+
+
+ else:
+
+
+ can_auto_play = True
+
+
+
+
+
+ #恢复自动播报(如果有)
+
+
    def set_auto_play(self):
        """Timer callback: re-enable auto broadcast and drop the expired timer."""
        global auto_play_lock
        global can_auto_play
        with auto_play_lock:
            can_auto_play = True
            self.timer = None
+
+
+
+
+
+ #启动核心服务
+
+
    def start(self):
        """Start the core service: spawn the local audio playback loop."""
        MyThread(target=self.__play_sound).start()
+
+
+
+
+
+ #停止核心服务
+
+
    def stop(self):
        """Stop the core service: end the playback loop, close TTS, and blank panel/digital-human output."""
        self.__running = False
        self.speaking = False
        self.sp.close()
        wsa_server.get_web_instance().add_cmd({"panelMsg": ""})
        content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': ""}}
        wsa_server.get_instance().add_cmd(content)
+
+
+
+
+
    def __record_response(self, text, username, uid):
        """
        Persist an AI reply: mirror it to logs/answer_result.txt and store it
        in the content database.

        :param text: reply text
        :param username: user name
        :param uid: user id
        :return: content_id of the stored row
        """
        self.write_to_file("./logs", "answer_result.txt", text)
        return content_db.new_instance().add_content('fay', 'speak', text, username, uid)
+
+
+
+
+
+ def __remove_prestart_tags(self, text):
+
+
+ """
+
+
+ 移除文本中的 prestart 标签及其内容
+
+
+ :param text: 原始文本
+
+
+ :return: 移除 prestart 标签后的文本
+
+
+ """
+
+
+ if not text:
+
+
+ return text
+
+
+ import re
+
+
+ # 移除 ... 标签及其内容(支持属性)
+
+ cleaned = re.sub(r']*>[\s\S]*?', '', text, flags=re.IGNORECASE)
+
+ return cleaned.strip()
+
+
+
+ def __has_prestart(self, text):
+
+ """
+
+ 判断文本中是否包含 prestart 标签(支持属性)
+
+ """
+
+ if not text:
+
+ return False
+
+ return re.search(r']*>[\s\S]*?', text, flags=re.IGNORECASE) is not None
+
+
+
+
+
+ def __truncate_think_for_panel(self, text, uid, username):
+
+ if not text or not isinstance(text, str):
+
+ return text
+
+ key = uid if uid is not None else username
+
+ state = self.think_display_state.get(key)
+
+ if state is None:
+
+ state = {"in_think": False, "in_tool_output": False, "tool_count": 0, "tool_truncated": False}
+
+ self.think_display_state[key] = state
+
+ if not state["in_think"] and "" not in text and "" not in text:
+
+ return text
+
+ tool_output_regex = re.compile(r"\[TOOL\]\s*(?:Output|\u8f93\u51fa)[:\uff1a]", re.IGNORECASE)
+
+ section_regex = re.compile(r"(?i)(^|[\r\n])(\[(?:TOOL|PLAN)\])")
+
+ out = []
+
+ i = 0
+
+ while i < len(text):
+
+ if not state["in_think"]:
+
+ idx = text.find("", i)
+
+ if idx == -1:
+
+ out.append(text[i:])
+
+ break
+
+ out.append(text[i:idx + len("")])
+
+ state["in_think"] = True
+
+ i = idx + len("")
+
+ continue
+
+ if not state["in_tool_output"]:
+
+ think_end = text.find("", i)
+
+ tool_match = tool_output_regex.search(text, i)
+
+ next_pos = None
+
+ next_kind = None
+
+ if tool_match:
+
+ next_pos = tool_match.start()
+
+ next_kind = "tool"
+
+ if think_end != -1 and (next_pos is None or think_end < next_pos):
+
+ next_pos = think_end
+
+ next_kind = "think_end"
+
+ if next_pos is None:
+
+ out.append(text[i:])
+
+ break
+
+ if next_pos > i:
+
+ out.append(text[i:next_pos])
+
+ if next_kind == "think_end":
+
+ out.append("")
+
+ state["in_think"] = False
+
+ state["in_tool_output"] = False
+
+ state["tool_count"] = 0
+
+ state["tool_truncated"] = False
+
+ i = next_pos + len("")
+
+ else:
+
+ marker_end = tool_match.end()
+
+ out.append(text[next_pos:marker_end])
+
+ state["in_tool_output"] = True
+
+ state["tool_count"] = 0
+
+ state["tool_truncated"] = False
+
+ i = marker_end
+
+ continue
+
+ think_end = text.find("", i)
+
+ section_match = section_regex.search(text, i)
+
+ end_pos = None
+
+ if section_match:
+
+ end_pos = section_match.start(2)
+
+ if think_end != -1 and (end_pos is None or think_end < end_pos):
+
+ end_pos = think_end
+
+ segment = text[i:] if end_pos is None else text[i:end_pos]
+
+ if segment:
+
+ if state["tool_truncated"]:
+
+ pass
+
+ else:
+
+ remaining = self.think_display_limit - state["tool_count"]
+
+ if remaining <= 0:
+
+ out.append("...")
+
+ state["tool_truncated"] = True
+
+ elif len(segment) <= remaining:
+
+ out.append(segment)
+
+ state["tool_count"] += len(segment)
+
+ else:
+
+ out.append(segment[:remaining] + "...")
+
+ state["tool_count"] += remaining
+
+ state["tool_truncated"] = True
+
+ if end_pos is None:
+
+ break
+
+ state["in_tool_output"] = False
+
+ state["tool_count"] = 0
+
+ state["tool_truncated"] = False
+
+ i = end_pos
+
+ return "".join(out)
+
    def __send_panel_message(self, text, username, uid, content_id=None, type=None):
        """
        Push a reply chunk to the web panel.

        :param text: message text
        :param username: user name
        :param uid: user id
        :param content_id: DB row id; when present the chat window is updated too
        :param type: message type ('qa' marks the reply as adopted)
        """
        if not wsa_server.get_web_instance().is_connected(username):
            return

        # prestart content must not update the log area, or it would clobber
        # the "thinking..." status display.
        is_prestart = self.__has_prestart(text)
        display_text = self.__truncate_think_for_panel(text, uid, username)

        # GUI log-area message (skipped for prestart content).
        if not is_prestart:
            wsa_server.get_web_instance().add_cmd({
                "panelMsg": display_text,
                "Username": username
            })

        # Chat-window message.
        if content_id is not None:
            wsa_server.get_web_instance().add_cmd({
                "panelReply": {
                    "type": "fay",
                    "content": display_text,
                    "username": username,
                    "uid": uid,
                    "id": content_id,
                    "is_adopted": type == 'qa'
                },
                "Username": username
            })
+
+
+
+
+
    def __send_digital_human_message(self, text, username, is_first=False, is_end=False):
        """
        Push a reply text chunk to the digital-human client (the matching
        audio is driven separately from say()).

        :param text: message text
        :param username: user name
        :param is_first: first chunk of the reply
        :param is_end: last chunk of the reply
        """
        # Strip prestart blocks and emoji; the digital human never receives them.
        cleaned_text = self.__remove_prestart_tags(text) if text else ""
        full_text = self.__remove_emojis(cleaned_text.replace("*", "")) if cleaned_text else ""

        # Empty non-final chunk: skip sending, but remember a pending is_first.
        if not full_text and not is_end:
            if is_first:
                self.pending_isfirst[username] = True
            return

        # Apply a deferred is_first left over from a filtered first chunk.
        if self.pending_isfirst.get(username, False):
            is_first = True
            self.pending_isfirst[username] = False

        if wsa_server.get_instance().is_connected(username):
            content = {
                'Topic': 'human',
                'Data': {
                    'Key': 'text',
                    'Value': full_text,
                    'IsFirst': 1 if is_first else 0,
                    'IsEnd': 1 if is_end else 0
                },
                'Username': username
            }
            wsa_server.get_instance().add_cmd(content)
+
+
+
+
+
    def __process_text_output(self, text, username, uid, content_id, type, is_first=False, is_end=False):
        """
        Deliver one reply chunk to every text endpoint (web panel + digital
        human) and log it.

        :param text: reply text chunk
        :param username: user name
        :param uid: user id
        :param content_id: DB row id of the accumulated reply
        :param type: message type ('qa' = adopted Q&A answer)
        :param is_first: first chunk of the reply
        :param is_end: last chunk of the reply
        """
        if text:
            text = text.strip()

        # DB recording is handled by the streaming pipeline in say().
        # content_id = self.__record_response(text, username, uid)

        # Push to the panel and the digital human.
        self.__send_panel_message(text, username, uid, content_id, type)
        self.__send_digital_human_message(text, username, is_first, is_end)

        # Console log.
        util.printInfo(1, username, '({}) {}'.format("llm", text))
+
+
+
+
+
# Loaded at module bottom via importlib — presumably to avoid a circular
# import between fay_booter and this module; confirm before moving to the top.
import importlib
fay_booter = importlib.import_module('fay_booter')
+
+
+
+
-
- # gui日志区消息(prestart 内容跳过,保持"思考中..."状态)
- if not is_prestart:
- wsa_server.get_web_instance().add_cmd({
- "panelMsg": display_text,
- "Username": username
- })
-
- # 聊天窗消息
- if content_id is not None:
- wsa_server.get_web_instance().add_cmd({
- "panelReply": {
- "type": "fay",
- "content": display_text,
- "username": username,
- "uid": uid,
- "id": content_id,
- "is_adopted": type == 'qa'
- },
- "Username": username
- })
-
- def __send_digital_human_message(self, text, username, is_first=False, is_end=False):
- """
- 发送消息到数字人(语音应该在say方法驱动数字人输出)
- :param text: 消息文本
- :param username: 用户名
- :param is_first: 是否是第一段文本
- :param is_end: 是否是最后一段文本
- """
- # 移除 prestart 标签内容,不发送给数字人
- cleaned_text = self.__remove_prestart_tags(text) if text else ""
- full_text = self.__remove_emojis(cleaned_text.replace("*", "")) if cleaned_text else ""
-
- # 如果文本为空且不是结束标记,则不发送,但需保留 is_first
- if not full_text and not is_end:
- if is_first:
- self.pending_isfirst[username] = True
- return
-
- # 检查是否有延迟的 is_first 需要应用
- if self.pending_isfirst.get(username, False):
- is_first = True
- self.pending_isfirst[username] = False
-
- if wsa_server.get_instance().is_connected(username):
- content = {
- 'Topic': 'human',
- 'Data': {
- 'Key': 'text',
- 'Value': full_text,
- 'IsFirst': 1 if is_first else 0,
- 'IsEnd': 1 if is_end else 0
- },
- 'Username': username
- }
- wsa_server.get_instance().add_cmd(content)
-
- def __process_text_output(self, text, username, uid, content_id, type, is_first=False, is_end=False):
- """
- 完整文本输出到各个终端
- :param text: 主要回复文本
- :param textlist: 额外回复列表
- :param username: 用户名
- :param uid: 用户ID
- :param type: 消息类型
- :param is_first: 是否是第一段文本
- :param is_end: 是否是最后一段文本
- """
- if text:
- text = text.strip()
-
- # 记录主回复
- # content_id = self.__record_response(text, username, uid)
-
- # 发送主回复到面板和数字人
- self.__send_panel_message(text, username, uid, content_id, type)
- self.__send_digital_human_message(text, username, is_first, is_end)
-
- # 打印日志
- util.printInfo(1, username, '({}) {}'.format("llm", text))
-
-import importlib
-fay_booter = importlib.import_module('fay_booter')
-
diff --git a/genagents/modules/memory_stream.py b/genagents/modules/memory_stream.py
index 360c0b6..27ca8bd 100644
--- a/genagents/modules/memory_stream.py
+++ b/genagents/modules/memory_stream.py
@@ -1,368 +1,381 @@
-import math
-import sys
-import datetime
-import random
-import string
-import re
-
-from numpy import dot
-from numpy.linalg import norm
-
-from simulation_engine.settings import *
-from simulation_engine.global_methods import *
-from simulation_engine.gpt_structure import *
-from simulation_engine.llm_json_parser import *
-
-
-def run_gpt_generate_importance(
- records,
- prompt_version="1",
- gpt_version="GPT4o",
- verbose=False):
-
- def create_prompt_input(records):
- records_str = ""
- for count, r in enumerate(records):
- records_str += f"Item {str(count+1)}:\n"
- records_str += f"{r}\n"
- return [records_str]
-
- def _func_clean_up(gpt_response, prompt=""):
- gpt_response = extract_first_json_dict(gpt_response)
- # 处理gpt_response为None的情况
- if gpt_response is None:
- print("警告: extract_first_json_dict返回None,使用默认值")
- return [50] # 返回默认重要性分数
- return list(gpt_response.values())
-
- def _get_fail_safe():
- return 25
-
- if len(records) > 1:
- prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/batch_v1.txt"
- else:
- prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/singular_v1.txt"
-
- prompt_input = create_prompt_input(records)
- fail_safe = _get_fail_safe()
-
- output, prompt, prompt_input, fail_safe = chat_safe_generate(
- prompt_input, prompt_lib_file, gpt_version, 1, fail_safe,
- _func_clean_up, verbose)
-
- return output, [output, prompt, prompt_input, fail_safe]
-
-
-def generate_importance_score(records):
- return run_gpt_generate_importance(records, "1", LLM_VERS)[0]
-
-
-def run_gpt_generate_reflection(
- records,
- anchor,
- reflection_count,
- prompt_version="1",
- gpt_version="GPT4o",
- verbose=False):
-
- def create_prompt_input(records, anchor, reflection_count):
- records_str = ""
- for count, r in enumerate(records):
- records_str += f"Item {str(count+1)}:\n"
- records_str += f"{r}\n"
- return [records_str, reflection_count, anchor]
-
- def _func_clean_up(gpt_response, prompt=""):
- return extract_first_json_dict(gpt_response)["reflection"]
-
- def _get_fail_safe():
- return []
-
- if reflection_count > 1:
- prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/batch_v1.txt"
- else:
- prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/singular_v1.txt"
-
- prompt_input = create_prompt_input(records, anchor, reflection_count)
- fail_safe = _get_fail_safe()
-
- output, prompt, prompt_input, fail_safe = chat_safe_generate(
- prompt_input, prompt_lib_file, gpt_version, 1, fail_safe,
- _func_clean_up, verbose)
-
- return output, [output, prompt, prompt_input, fail_safe]
-
-
-def generate_reflection(records, anchor, reflection_count):
- records = [i.content for i in records]
- return run_gpt_generate_reflection(records, anchor, reflection_count, "1",
- LLM_VERS)[0]
-
-
-# ##############################################################################
-# ### HELPER FUNCTIONS FOR GENERATIVE AGENTS ###
-# ##############################################################################
-
-def get_random_str(length):
- """
- Generates a random string of alphanumeric characters with the specified
- length. This function creates a random string by selecting characters from
- the set of uppercase letters, lowercase letters, and digits. The length of
- the random string is determined by the 'length' parameter.
-
- Parameters:
- length (int): The desired length of the random string.
- Returns:
- random_string: A randomly generated string of the specified length.
-
- Example:
- >>> get_random_str(8)
- 'aB3R7tQ2'
- """
- characters = string.ascii_letters + string.digits
- random_string = ''.join(random.choice(characters) for _ in range(length))
- return random_string
-
-
-def cos_sim(a, b):
- """
- This function calculates the cosine similarity between two input vectors
- 'a' and 'b'. Cosine similarity is a measure of similarity between two
- non-zero vectors of an inner product space that measures the cosine
- of the angle between them.
-
- Parameters:
- a: 1-D array object
- b: 1-D array object
- Returns:
- A scalar value representing the cosine similarity between the input
- vectors 'a' and 'b'.
-
- Example:
- >>> a = [0.3, 0.2, 0.5]
- >>> b = [0.2, 0.2, 0.5]
- >>> cos_sim(a, b)
- """
- return dot(a, b)/(norm(a)*norm(b))
-
-
-def normalize_dict_floats(d, target_min, target_max):
- """
- This function normalizes the float values of a given dictionary 'd' between
- a target minimum and maximum value. The normalization is done by scaling the
- values to the target range while maintaining the same relative proportions
- between the original values.
-
- Parameters:
- d: Dictionary. The input dictionary whose float values need to be
- normalized.
- target_min: Integer or float. The minimum value to which the original
- values should be scaled.
- target_max: Integer or float. The maximum value to which the original
- values should be scaled.
- Returns:
- d: A new dictionary with the same keys as the input but with the float
- values normalized between the target_min and target_max.
-
- Example:
- >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8}
- >>> target_min = -5
- >>> target_max = 5
- >>> normalize_dict_floats(d, target_min, target_max)
- """
- # 检查字典是否为None或为空
- if d is None:
- print("警告: normalize_dict_floats接收到None字典")
- return {}
-
- if not d:
- print("警告: normalize_dict_floats接收到空字典")
- return {}
-
- try:
- min_val = min(val for val in d.values())
- max_val = max(val for val in d.values())
- range_val = max_val - min_val
-
- if range_val == 0:
- for key, val in d.items():
- d[key] = (target_max - target_min)/2
- else:
- for key, val in d.items():
- d[key] = ((val - min_val) * (target_max - target_min)
- / range_val + target_min)
- return d
- except Exception as e:
- print(f"normalize_dict_floats处理字典时出错: {str(e)}")
- # 返回原始字典,避免处理失败
- return d
-
-
-def top_highest_x_values(d, x):
- """
- This function takes a dictionary 'd' and an integer 'x' as input, and
- returns a new dictionary containing the top 'x' key-value pairs from the
- input dictionary 'd' with the highest values.
-
- Parameters:
- d: Dictionary. The input dictionary from which the top 'x' key-value pairs
- with the highest values are to be extracted.
- x: Integer. The number of top key-value pairs with the highest values to
- be extracted from the input dictionary.
- Returns:
- A new dictionary containing the top 'x' key-value pairs from the input
- dictionary 'd' with the highest values.
-
- Example:
- >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8}
- >>> x = 3
- >>> top_highest_x_values(d, x)
- """
- top_v = dict(sorted(d.items(),
- key=lambda item: item[1],
- reverse=True)[:x])
- return top_v
-
-
-def extract_recency(seq_nodes):
- """
- Gets the current Persona object and a list of nodes that are in a
- chronological order, and outputs a dictionary that has the recency score
- calculated.
-
- Parameters:
- nodes: A list of Node object in a chronological order.
- Returns:
- recency_out: A dictionary whose keys are the node.node_id and whose values
- are the float that represents the recency score.
- """
- # 检查seq_nodes是否为None或为空
- if seq_nodes is None:
- print("警告: extract_recency接收到None节点列表")
- return {}
-
- if not seq_nodes:
- print("警告: extract_recency接收到空节点列表")
- return {}
-
- try:
- # 确保所有的last_retrieved都是整数类型
- normalized_timestamps = []
- for node in seq_nodes:
- if node is None:
- print("警告: 节点为None,跳过")
- continue
-
- if not hasattr(node, 'last_retrieved'):
- print(f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0")
- normalized_timestamps.append(0)
- continue
-
- if isinstance(node.last_retrieved, str):
- try:
- normalized_timestamps.append(int(node.last_retrieved))
- except ValueError:
- # 如果无法转换为整数,使用0作为默认值
- normalized_timestamps.append(0)
- else:
- normalized_timestamps.append(node.last_retrieved)
-
- if not normalized_timestamps:
- return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
-
- max_timestep = max(normalized_timestamps)
-
- recency_decay = 0.99
- recency_out = dict()
- for count, node in enumerate(seq_nodes):
- if node is None or not hasattr(node, 'node_id') or not hasattr(node, 'last_retrieved'):
- continue
-
- # 获取标准化后的时间戳
- try:
- last_retrieved = normalized_timestamps[count]
- recency_out[node.node_id] = (recency_decay
- ** (max_timestep - last_retrieved))
- except Exception as e:
- print(f"计算节点 {node.node_id} 的recency时出错: {str(e)}")
- # 使用默认值
- recency_out[node.node_id] = 1.0
-
- return recency_out
- except Exception as e:
- print(f"extract_recency处理节点列表时出错: {str(e)}")
- # 返回一个默认字典
- return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
-
-
+import math
+import sys
+import datetime
+import random
+import string
+import re
+
+from numpy import dot
+from numpy.linalg import norm
+
+from simulation_engine.settings import *
+from simulation_engine.global_methods import *
+from simulation_engine.gpt_structure import *
+from simulation_engine.llm_json_parser import *
+
+
+def run_gpt_generate_importance(
+ records,
+ prompt_version="1",
+ gpt_version="GPT4o",
+ verbose=False):
+
+ def create_prompt_input(records):
+ records_str = ""
+ for count, r in enumerate(records):
+ records_str += f"Item {str(count+1)}:\n"
+ records_str += f"{r}\n"
+ return [records_str]
+
+ def _func_clean_up(gpt_response, prompt=""):
+ gpt_response = extract_first_json_dict(gpt_response)
+ # 处理gpt_response为None的情况
+ if gpt_response is None:
+ print("警告: extract_first_json_dict返回None,使用默认值")
+ return [50] # 返回默认重要性分数
+ return list(gpt_response.values())
+
+ def _get_fail_safe():
+ return 25
+
+ if len(records) > 1:
+ prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/batch_v1.txt"
+ else:
+ prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/singular_v1.txt"
+
+ prompt_input = create_prompt_input(records)
+ fail_safe = _get_fail_safe()
+
+ output, prompt, prompt_input, fail_safe = chat_safe_generate(
+ prompt_input, prompt_lib_file, gpt_version, 1, fail_safe,
+ _func_clean_up, verbose)
+
+ return output, [output, prompt, prompt_input, fail_safe]
+
+
+def generate_importance_score(records):
+ return run_gpt_generate_importance(records, "1", LLM_VERS)[0]
+
+
+def run_gpt_generate_reflection(
+ records,
+ anchor,
+ reflection_count,
+ prompt_version="1",
+ gpt_version="GPT4o",
+ verbose=False):
+
+ def create_prompt_input(records, anchor, reflection_count):
+ records_str = ""
+ for count, r in enumerate(records):
+ records_str += f"Item {str(count+1)}:\n"
+ records_str += f"{r}\n"
+ return [records_str, reflection_count, anchor]
+
+ def _func_clean_up(gpt_response, prompt=""):
+ return extract_first_json_dict(gpt_response)["reflection"]
+
+ def _get_fail_safe():
+ return []
+
+ if reflection_count > 1:
+ prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/batch_v1.txt"
+ else:
+ prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/singular_v1.txt"
+
+ prompt_input = create_prompt_input(records, anchor, reflection_count)
+ fail_safe = _get_fail_safe()
+
+ output, prompt, prompt_input, fail_safe = chat_safe_generate(
+ prompt_input, prompt_lib_file, gpt_version, 1, fail_safe,
+ _func_clean_up, verbose)
+
+ return output, [output, prompt, prompt_input, fail_safe]
+
+
+def generate_reflection(records, anchor, reflection_count):
+ records = [i.content for i in records]
+ return run_gpt_generate_reflection(records, anchor, reflection_count, "1",
+ LLM_VERS)[0]
+
+
+# ##############################################################################
+# ### HELPER FUNCTIONS FOR GENERATIVE AGENTS ###
+# ##############################################################################
+
+def get_random_str(length):
+ """
+ Generates a random string of alphanumeric characters with the specified
+ length. This function creates a random string by selecting characters from
+ the set of uppercase letters, lowercase letters, and digits. The length of
+ the random string is determined by the 'length' parameter.
+
+ Parameters:
+ length (int): The desired length of the random string.
+ Returns:
+ random_string: A randomly generated string of the specified length.
+
+ Example:
+ >>> get_random_str(8)
+ 'aB3R7tQ2'
+ """
+ characters = string.ascii_letters + string.digits
+ random_string = ''.join(random.choice(characters) for _ in range(length))
+ return random_string
+
+
+def cos_sim(a, b):
+ """
+ This function calculates the cosine similarity between two input vectors
+ 'a' and 'b'. Cosine similarity is a measure of similarity between two
+ non-zero vectors of an inner product space that measures the cosine
+ of the angle between them.
+
+ Parameters:
+ a: 1-D array object
+ b: 1-D array object
+ Returns:
+ A scalar value representing the cosine similarity between the input
+ vectors 'a' and 'b'.
+
+ Example:
+ >>> a = [0.3, 0.2, 0.5]
+ >>> b = [0.2, 0.2, 0.5]
+ >>> cos_sim(a, b)
+ """
+ return dot(a, b)/(norm(a)*norm(b))
+
+
+def normalize_dict_floats(d, target_min, target_max):
+ """
+ This function normalizes the float values of a given dictionary 'd' between
+ a target minimum and maximum value. The normalization is done by scaling the
+ values to the target range while maintaining the same relative proportions
+ between the original values.
+
+ Parameters:
+ d: Dictionary. The input dictionary whose float values need to be
+ normalized.
+ target_min: Integer or float. The minimum value to which the original
+ values should be scaled.
+ target_max: Integer or float. The maximum value to which the original
+ values should be scaled.
+ Returns:
+ d: A new dictionary with the same keys as the input but with the float
+ values normalized between the target_min and target_max.
+
+ Example:
+ >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8}
+ >>> target_min = -5
+ >>> target_max = 5
+ >>> normalize_dict_floats(d, target_min, target_max)
+ """
+ # 检查字典是否为None或为空
+ if d is None:
+ print("警告: normalize_dict_floats接收到None字典")
+ return {}
+
+ if not d:
+ print("警告: normalize_dict_floats接收到空字典")
+ return {}
+
+ try:
+ min_val = min(val for val in d.values())
+ max_val = max(val for val in d.values())
+ range_val = max_val - min_val
+
+ if range_val == 0:
+ for key, val in d.items():
+ d[key] = (target_max - target_min)/2
+ else:
+ for key, val in d.items():
+ d[key] = ((val - min_val) * (target_max - target_min)
+ / range_val + target_min)
+ return d
+ except Exception as e:
+ print(f"normalize_dict_floats处理字典时出错: {str(e)}")
+ # 返回原始字典,避免处理失败
+ return d
+
+
+def top_highest_x_values(d, x):
+ """
+ This function takes a dictionary 'd' and an integer 'x' as input, and
+ returns a new dictionary containing the top 'x' key-value pairs from the
+ input dictionary 'd' with the highest values.
+
+ Parameters:
+ d: Dictionary. The input dictionary from which the top 'x' key-value pairs
+ with the highest values are to be extracted.
+ x: Integer. The number of top key-value pairs with the highest values to
+ be extracted from the input dictionary.
+ Returns:
+ A new dictionary containing the top 'x' key-value pairs from the input
+ dictionary 'd' with the highest values.
+
+ Example:
+ >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8}
+ >>> x = 3
+ >>> top_highest_x_values(d, x)
+ """
+ top_v = dict(sorted(d.items(),
+ key=lambda item: item[1],
+ reverse=True)[:x])
+ return top_v
+
+
+def extract_recency(seq_nodes):
+ """
+ Gets the current Persona object and a list of nodes that are in a
+ chronological order, and outputs a dictionary that has the recency score
+ calculated.
+
+ Parameters:
+    seq_nodes: A list of Node object in a chronological order.
+ Returns:
+ recency_out: A dictionary whose keys are the node.node_id and whose values
+ are the float that represents the recency score.
+ """
+ # 检查seq_nodes是否为None或为空
+ if seq_nodes is None:
+ print("警告: extract_recency接收到None节点列表")
+ return {}
+
+ if not seq_nodes:
+ print("警告: extract_recency接收到空节点列表")
+ return {}
+
+ try:
+ # 确保所有的last_retrieved都是整数类型
+ normalized_timestamps = []
+ for node in seq_nodes:
+ if node is None:
+ print("警告: 节点为None,跳过")
+ continue
+
+ if not hasattr(node, 'last_retrieved'):
+ print(f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0")
+ normalized_timestamps.append(0)
+ continue
+
+ if isinstance(node.last_retrieved, str):
+ try:
+ normalized_timestamps.append(int(node.last_retrieved))
+ except ValueError:
+ # 如果无法转换为整数,使用0作为默认值
+ normalized_timestamps.append(0)
+ else:
+ normalized_timestamps.append(node.last_retrieved)
+
+ if not normalized_timestamps:
+ return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
+
+ max_timestep = max(normalized_timestamps)
+
+ recency_decay = 0.99
+ recency_out = dict()
+ for count, node in enumerate(seq_nodes):
+ if node is None or not hasattr(node, 'node_id') or not hasattr(node, 'last_retrieved'):
+ continue
+
+ # 获取标准化后的时间戳
+ try:
+ last_retrieved = normalized_timestamps[count]
+ recency_out[node.node_id] = (recency_decay
+ ** (max_timestep - last_retrieved))
+ except Exception as e:
+ print(f"计算节点 {node.node_id} 的recency时出错: {str(e)}")
+ # 使用默认值
+ recency_out[node.node_id] = 1.0
+
+ return recency_out
+ except Exception as e:
+ print(f"extract_recency处理节点列表时出错: {str(e)}")
+ # 返回一个默认字典
+ return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
+
+
def extract_importance(seq_nodes):
- """
- Gets the current Persona object and a list of nodes that are in a
- chronological order, and outputs a dictionary that has the importance score
- calculated.
-
- Parameters:
- seq_nodes: A list of Node object in a chronological order.
- Returns:
- importance_out: A dictionary whose keys are the node.node_id and whose
- values are the float that represents the importance score.
- """
- # 检查seq_nodes是否为None或为空
- if seq_nodes is None:
- print("警告: extract_importance接收到None节点列表")
- return {}
-
- if not seq_nodes:
- print("警告: extract_importance接收到空节点列表")
- return {}
-
- try:
- importance_out = dict()
- for count, node in enumerate(seq_nodes):
- if node is None:
- print("警告: 节点为None,跳过")
- continue
-
- if not hasattr(node, 'node_id') or not hasattr(node, 'importance'):
- print(f"警告: 节点缺少必要属性,跳过")
- continue
-
- # 确保importance是数值类型
- if isinstance(node.importance, str):
- try:
- importance_out[node.node_id] = float(node.importance)
- except ValueError:
- # 如果无法转换为数值,使用默认值
- print(f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值")
- importance_out[node.node_id] = 50.0
- else:
- importance_out[node.node_id] = node.importance
-
- return importance_out
+ """
+ Gets the current Persona object and a list of nodes that are in a
+ chronological order, and outputs a dictionary that has the importance score
+ calculated.
+
+ Parameters:
+ seq_nodes: A list of Node object in a chronological order.
+ Returns:
+ importance_out: A dictionary whose keys are the node.node_id and whose
+ values are the float that represents the importance score.
+ """
+ # 检查seq_nodes是否为None或为空
+ if seq_nodes is None:
+ print("警告: extract_importance接收到None节点列表")
+ return {}
+
+ if not seq_nodes:
+ print("警告: extract_importance接收到空节点列表")
+ return {}
+
+ try:
+ importance_out = dict()
+ for count, node in enumerate(seq_nodes):
+ if node is None:
+ print("警告: 节点为None,跳过")
+ continue
+
+ if not hasattr(node, 'node_id') or not hasattr(node, 'importance'):
+ print(f"警告: 节点缺少必要属性,跳过")
+ continue
+
+ # 确保importance是数值类型
+ if isinstance(node.importance, str):
+ try:
+ importance_out[node.node_id] = float(node.importance)
+ except ValueError:
+ # 如果无法转换为数值,使用默认值
+ print(f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值")
+ importance_out[node.node_id] = 50.0
+ else:
+ importance_out[node.node_id] = node.importance
+
+ return importance_out
except Exception as e:
print(f"extract_importance处理节点列表时出错: {str(e)}")
# 返回一个默认字典
return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
-def extract_relevance(seq_nodes, embeddings, focal_pt):
- """
- Gets the current Persona object, a list of seq_nodes that are in a
- chronological order, and the focal_pt string and outputs a dictionary
- that has the relevance score calculated.
+def _is_valid_embedding(vec, expected_dim):
+ if vec is None:
+ return False
+ if not isinstance(vec, (list, tuple)):
+ return False
+ if expected_dim is not None and len(vec) != expected_dim:
+ return False
+ for val in vec:
+ if not isinstance(val, (int, float)):
+ return False
+ return True
- Parameters:
- seq_nodes: A list of Node object in a chronological order.
- focal_pt: A string describing the current thought of revent of focus.
- Returns:
- relevance_out: A dictionary whose keys are the node.node_id and whose
- values are the float that represents the relevance score.
- """
- # 确保embeddings不为None
- if embeddings is None:
- print("警告: embeddings为None,使用空字典代替")
- embeddings = {}
-
+
+def extract_relevance(seq_nodes, embeddings, focal_pt):
+ """
+ Gets the current Persona object, a list of seq_nodes that are in a
+ chronological order, and the focal_pt string and outputs a dictionary
+ that has the relevance score calculated.
+
+ Parameters:
+ seq_nodes: A list of Node object in a chronological order.
+    focal_pt: A string describing the current thought or event of focus.
+ Returns:
+ relevance_out: A dictionary whose keys are the node.node_id and whose
+ values are the float that represents the relevance score.
+ """
+ # 确保embeddings不为None
+ if embeddings is None:
+ print("警告: embeddings为None,使用空字典代替")
+ embeddings = {}
+
try:
focal_embedding = get_text_embedding(focal_pt)
except Exception as e:
@@ -370,237 +383,260 @@ def extract_relevance(seq_nodes, embeddings, focal_pt):
# 如果无法获取嵌入向量,返回默认值
return {node.node_id: 0.5 for node in seq_nodes}
+ expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None
relevance_out = dict()
for count, node in enumerate(seq_nodes):
try:
# 检查节点内容是否在embeddings中
if node.content in embeddings:
node_embedding = embeddings[node.content]
+ if not _is_valid_embedding(node_embedding, expected_dim):
+ try:
+ regenerated = get_text_embedding(node.content)
+ if _is_valid_embedding(regenerated, expected_dim):
+ embeddings[node.content] = regenerated
+ node_embedding = regenerated
+ else:
+ print("Warning: regenerated embedding has unexpected dimension, using default relevance")
+ node_embedding = None
+ except Exception as e:
+ print(f"Regenerate embedding failed: {str(e)}")
+ node_embedding = None
# 计算余弦相似度
- relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding)
+ if node_embedding is None:
+ relevance_out[node.node_id] = 0.5
+ else:
+ relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding)
else:
# 如果没有对应的嵌入向量,使用默认值
relevance_out[node.node_id] = 0.5
- except Exception as e:
- print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}")
- # 如果计算过程中出错,使用默认值
- relevance_out[node.node_id] = 0.5
-
- return relevance_out
-
-
-# ##############################################################################
-# ### CONCEPT NODE ###
-# ##############################################################################
-
-class ConceptNode:
- def __init__(self, node_dict):
- # Loading the content of a memory node in the memory stream.
+ except Exception as e:
+ print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}")
+ # 如果计算过程中出错,使用默认值
+ relevance_out[node.node_id] = 0.5
+
+ return relevance_out
+
+
+# ##############################################################################
+# ### CONCEPT NODE ###
+# ##############################################################################
+
+class ConceptNode:
+ def __init__(self, node_dict):
+ # Loading the content of a memory node in the memory stream.
self.node_id = node_dict["node_id"]
self.node_type = node_dict["node_type"]
self.content = node_dict["content"]
self.importance = node_dict["importance"]
+ self.datetime = node_dict.get("datetime", "")
# 确保created是整数类型
- self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0
- # 确保last_retrieved是整数类型
- self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0
- self.pointer_id = node_dict["pointer_id"]
-
-
- def package(self):
- """
- Packaging the ConceptNode
-
- Parameters:
- None
- Returns:
- packaged dictionary
- """
- curr_package = {}
- curr_package["node_id"] = self.node_id
- curr_package["node_type"] = self.node_type
+ self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0
+ # 确保last_retrieved是整数类型
+ self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0
+ self.pointer_id = node_dict["pointer_id"]
+
+
+ def package(self):
+ """
+ Packaging the ConceptNode
+
+ Parameters:
+ None
+ Returns:
+ packaged dictionary
+ """
+ curr_package = {}
+ curr_package["node_id"] = self.node_id
+ curr_package["node_type"] = self.node_type
curr_package["content"] = self.content
curr_package["importance"] = self.importance
+ curr_package["datetime"] = self.datetime
curr_package["created"] = self.created
- curr_package["last_retrieved"] = self.last_retrieved
- curr_package["pointer_id"] = self.pointer_id
-
- return curr_package
-
-
-# ##############################################################################
-# ### MEMORY STREAM ###
-# ##############################################################################
-
-class MemoryStream:
- def __init__(self, nodes, embeddings):
- # Loading the memory stream for the agent.
- self.seq_nodes = []
- self.id_to_node = dict()
- for node in nodes:
- new_node = ConceptNode(node)
- self.seq_nodes += [new_node]
- self.id_to_node[new_node.node_id] = new_node
-
- self.embeddings = embeddings
-
-
- def count_observations(self):
- """
- Counting the number of observations (basically, the number of all nodes in
- memory stream except for the reflections)
-
- Parameters:
- None
- Returns:
- Count
- """
- count = 0
- for i in self.seq_nodes:
- if i.node_type == "observation":
- count += 1
- return count
-
-
- def retrieve(self, focal_points, time_step, n_count=120, curr_filter="all",
- hp=[0, 1, 0.5], stateless=False, verbose=False):
- """
- Retrieve elements from the memory stream.
-
- Parameters:
- focal_points: This is the query sentence. It is in a list form where
- the elemnts of the list are the query sentences.
- time_step: Current time_step
- n_count: The number of nodes that we want to retrieve.
+ curr_package["last_retrieved"] = self.last_retrieved
+ curr_package["pointer_id"] = self.pointer_id
+
+ return curr_package
+
+
+# ##############################################################################
+# ### MEMORY STREAM ###
+# ##############################################################################
+
+class MemoryStream:
+ def __init__(self, nodes, embeddings):
+ # Loading the memory stream for the agent.
+ self.seq_nodes = []
+ self.id_to_node = dict()
+ for node in nodes:
+ new_node = ConceptNode(node)
+ self.seq_nodes += [new_node]
+ self.id_to_node[new_node.node_id] = new_node
+
+ self.embeddings = embeddings
+
+
+ def count_observations(self):
+ """
+ Counting the number of observations (basically, the number of all nodes in
+ memory stream except for the reflections)
+
+ Parameters:
+ None
+ Returns:
+ Count
+ """
+ count = 0
+ for i in self.seq_nodes:
+ if i.node_type == "observation":
+ count += 1
+ return count
+
+
+ def retrieve(self, focal_points, time_step, n_count=120, curr_filter="all",
+ hp=[0, 1, 0.5], stateless=False, verbose=False):
+ """
+ Retrieve elements from the memory stream.
+
+ Parameters:
+ focal_points: This is the query sentence. It is in a list form where
+        the elements of the list are the query sentences.
+ time_step: Current time_step
+ n_count: The number of nodes that we want to retrieve.
curr_filter: Filtering the node.type that we want to retrieve.
- Acceptable values are 'all', 'reflection', 'observation'
- hp: Hyperparameter for [recency_w, relevance_w, importance_w]
- verbose: verbose
- Returns:
- retrieved: A dictionary whose keys are a focal_pt query str, and whose
- values are a list of nodes that are retrieved for that query str.
- """
- curr_nodes = []
-
- # If the memory stream is empty, we return an empty dictionary.
- if len(self.seq_nodes) == 0:
- return dict()
-
- # Filtering for the desired node type. curr_filter can be one of the three
- # elements: 'all', 'reflection', 'observation'
- if curr_filter == "all":
- curr_nodes = self.seq_nodes
- else:
- for curr_node in self.seq_nodes:
- if curr_node.node_type == curr_filter:
- curr_nodes += [curr_node]
-
- # 确保embeddings不为None
- if self.embeddings is None:
- print("警告: 在retrieve方法中,embeddings为None,初始化为空字典")
- self.embeddings = {}
-
- # is the main dictionary that we are returning
- retrieved = dict()
- for focal_pt in focal_points:
- # Calculating the component dictionaries and normalizing them.
- x = extract_recency(curr_nodes)
- recency_out = normalize_dict_floats(x, 0, 1)
- x = extract_importance(curr_nodes)
- importance_out = normalize_dict_floats(x, 0, 1)
- x = extract_relevance(curr_nodes, self.embeddings, focal_pt)
- relevance_out = normalize_dict_floats(x, 0, 1)
-
- # Computing the final scores that combines the component values.
- master_out = dict()
- for key in recency_out.keys():
- recency_w = hp[0]
- relevance_w = hp[1]
- importance_w = hp[2]
- master_out[key] = (recency_w * recency_out[key]
- + relevance_w * relevance_out[key]
- + importance_w * importance_out[key])
-
- if verbose:
- master_out = top_highest_x_values(master_out, len(master_out.keys()))
- for key, val in master_out.items():
- print (self.id_to_node[key].content, val)
- print (recency_w*recency_out[key]*1,
- relevance_w*relevance_out[key]*1,
- importance_w*importance_out[key]*1)
-
- # Extracting the highest x values.
- # has the key of node.id and value of float. Once we get
- # the highest x values, we want to translate the node.id into nodes
- # and return the list of nodes.
- master_out = top_highest_x_values(master_out, n_count)
- master_nodes = [self.id_to_node[key] for key in list(master_out.keys())]
-
- # **Sort the master_nodes list by last_retrieved in descending order**
- master_nodes = sorted(master_nodes, key=lambda node: node.created, reverse=False)
-
- # We do not want to update the last retrieved time_step for these nodes
- # if we are in a stateless mode.
- if not stateless:
- for n in master_nodes:
- n.last_retrieved = time_step
-
- retrieved[focal_pt] = master_nodes
-
- return retrieved
-
-
- def _add_node(self, time_step, node_type, content, importance, pointer_id):
- """
- Adding a new node to the memory stream.
-
- Parameters:
- time_step: Current time_step
- node_type: type of node -- it's either reflection, observation
- content: the str content of the memory record
- importance: int score of the importance score
- pointer_id: the str of the parent node
- Returns:
- retrieved: A dictionary whose keys are a focal_pt query str, and whose
- values are a list of nodes that are retrieved for that query str.
- """
- node_dict = dict()
- node_dict["node_id"] = len(self.seq_nodes)
+ Acceptable values are 'all', 'reflection', 'observation', 'conversation'
+ hp: Hyperparameter for [recency_w, relevance_w, importance_w]
+ verbose: verbose
+ Returns:
+ retrieved: A dictionary whose keys are a focal_pt query str, and whose
+ values are a list of nodes that are retrieved for that query str.
+ """
+ curr_nodes = []
+
+ # If the memory stream is empty, we return an empty dictionary.
+ if len(self.seq_nodes) == 0:
+ return dict()
+
+ # Filtering for the desired node type. curr_filter can be one of the
+ # elements: 'all', 'reflection', 'observation', 'conversation'
+ if curr_filter == "all":
+ curr_nodes = self.seq_nodes
+ else:
+ for curr_node in self.seq_nodes:
+ if curr_node.node_type == curr_filter:
+ curr_nodes += [curr_node]
+
+ # 确保embeddings不为None
+ if self.embeddings is None:
+ print("警告: 在retrieve方法中,embeddings为None,初始化为空字典")
+ self.embeddings = {}
+
+ # is the main dictionary that we are returning
+ retrieved = dict()
+ for focal_pt in focal_points:
+ # Calculating the component dictionaries and normalizing them.
+ x = extract_recency(curr_nodes)
+ recency_out = normalize_dict_floats(x, 0, 1)
+ x = extract_importance(curr_nodes)
+ importance_out = normalize_dict_floats(x, 0, 1)
+ x = extract_relevance(curr_nodes, self.embeddings, focal_pt)
+ relevance_out = normalize_dict_floats(x, 0, 1)
+
+ # Computing the final scores that combines the component values.
+ master_out = dict()
+ for key in recency_out.keys():
+ recency_w = hp[0]
+ relevance_w = hp[1]
+ importance_w = hp[2]
+ master_out[key] = (recency_w * recency_out[key]
+ + relevance_w * relevance_out[key]
+ + importance_w * importance_out[key])
+
+ if verbose:
+ master_out = top_highest_x_values(master_out, len(master_out.keys()))
+ for key, val in master_out.items():
+ print (self.id_to_node[key].content, val)
+ print (recency_w*recency_out[key]*1,
+ relevance_w*relevance_out[key]*1,
+ importance_w*importance_out[key]*1)
+
+ # Extracting the highest x values.
+ # has the key of node.id and value of float. Once we get
+ # the highest x values, we want to translate the node.id into nodes
+ # and return the list of nodes.
+ master_out = top_highest_x_values(master_out, n_count)
+ master_nodes = [self.id_to_node[key] for key in list(master_out.keys())]
+
+      # Sort master_nodes chronologically by creation time (ascending), not by last_retrieved.
+ master_nodes = sorted(master_nodes, key=lambda node: node.created, reverse=False)
+
+ # We do not want to update the last retrieved time_step for these nodes
+ # if we are in a stateless mode.
+ if not stateless:
+ for n in master_nodes:
+ n.last_retrieved = time_step
+
+ retrieved[focal_pt] = master_nodes
+
+ return retrieved
+
+
+ def _add_node(self, time_step, node_type, content, importance, pointer_id):
+ """
+ Adding a new node to the memory stream.
+
+ Parameters:
+ time_step: Current time_step
+ node_type: type of node -- it's either reflection, observation, conversation
+ content: the str content of the memory record
+ importance: int score of the importance score
+ pointer_id: the str of the parent node
+ Returns:
+ retrieved: A dictionary whose keys are a focal_pt query str, and whose
+ values are a list of nodes that are retrieved for that query str.
+ """
+ node_dict = dict()
+ node_dict["node_id"] = len(self.seq_nodes)
node_dict["node_type"] = node_type
node_dict["content"] = content
node_dict["importance"] = importance
+ node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")
node_dict["created"] = time_step
- node_dict["last_retrieved"] = time_step
- node_dict["pointer_id"] = pointer_id
- new_node = ConceptNode(node_dict)
-
- self.seq_nodes += [new_node]
- self.id_to_node[new_node.node_id] = new_node
-
- # 确保embeddings不为None
- if self.embeddings is None:
- self.embeddings = {}
-
- try:
- self.embeddings[content] = get_text_embedding(content)
- except Exception as e:
- print(f"获取文本嵌入时出错: {str(e)}")
- # 如果获取嵌入失败,使用空列表代替
- self.embeddings[content] = []
-
-
+ node_dict["last_retrieved"] = time_step
+ node_dict["pointer_id"] = pointer_id
+ new_node = ConceptNode(node_dict)
+
+ self.seq_nodes += [new_node]
+ self.id_to_node[new_node.node_id] = new_node
+
+ # 确保embeddings不为None
+ if self.embeddings is None:
+ self.embeddings = {}
+
+ try:
+ self.embeddings[content] = get_text_embedding(content)
+ except Exception as e:
+ print(f"获取文本嵌入时出错: {str(e)}")
+ # 如果获取嵌入失败,使用空列表代替
+ self.embeddings[content] = []
+
+
def remember(self, content, time_step=0):
score = generate_importance_score([content])[0]
self._add_node(time_step, "observation", content, score, None)
+ def remember_conversation(self, content, time_step=0):
+ score = generate_importance_score([content])[0]
+ self._add_node(time_step, "conversation", content, score, None)
- def reflect(self, anchor, reflection_count=5,
- retrieval_count=120, time_step=0):
- records = self.retrieve([anchor], time_step, retrieval_count)[anchor]
- record_ids = [i.node_id for i in records]
- reflections = generate_reflection(records, anchor, reflection_count)
- scores = generate_importance_score(reflections)
-
- for count, reflection in enumerate(reflections):
- self._add_node(time_step, "reflection", reflections[count],
- scores[count], record_ids)
+
+ def reflect(self, anchor, reflection_count=5,
+ retrieval_count=120, time_step=0):
+ records = self.retrieve([anchor], time_step, retrieval_count)[anchor]
+ record_ids = [i.node_id for i in records]
+ reflections = generate_reflection(records, anchor, reflection_count)
+ scores = generate_importance_score(reflections)
+
+ for count, reflection in enumerate(reflections):
+ self._add_node(time_step, "reflection", reflections[count],
+ scores[count], record_ids)
diff --git a/gui/flask_server.py b/gui/flask_server.py
index d22e0e9..c64873a 100644
--- a/gui/flask_server.py
+++ b/gui/flask_server.py
@@ -74,9 +74,9 @@ def __get_template():
except Exception as e:
return f"Error rendering template: {e}", 500
-def __get_device_list():
- try:
- if config_util.start_mode == 'common':
+def __get_device_list():
+ try:
+ if config_util.start_mode == 'common':
audio = pyaudio.PyAudio()
device_list = []
for i in range(audio.get_device_count()):
@@ -86,33 +86,33 @@ def __get_device_list():
return list(set(device_list))
else:
return []
- except Exception as e:
- print(f"Error getting device list: {e}")
- return []
-
-def _as_bool(value):
- if isinstance(value, bool):
- return value
- if value is None:
- return False
- if isinstance(value, (int, float)):
- return value != 0
- if isinstance(value, str):
- return value.strip().lower() in ("1", "true", "yes", "y", "on")
- return False
-
-def _build_llm_url(base_url: str) -> str:
- if not base_url:
- return ""
- url = base_url.rstrip("/")
- if url.endswith("/chat/completions"):
- return url
- if url.endswith("/v1"):
- return url + "/chat/completions"
- return url + "/v1/chat/completions"
-
-@__app.route('/api/submit', methods=['post'])
-def api_submit():
+ except Exception as e:
+ print(f"Error getting device list: {e}")
+ return []
+
+def _as_bool(value):
+ if isinstance(value, bool):
+ return value
+ if value is None:
+ return False
+ if isinstance(value, (int, float)):
+ return value != 0
+ if isinstance(value, str):
+ return value.strip().lower() in ("1", "true", "yes", "y", "on")
+ return False
+
+def _build_llm_url(base_url: str) -> str:
+ if not base_url:
+ return ""
+ url = base_url.rstrip("/")
+ if url.endswith("/chat/completions"):
+ return url
+ if url.endswith("/v1"):
+ return url + "/chat/completions"
+ return url + "/v1/chat/completions"
+
+@__app.route('/api/submit', methods=['post'])
+def api_submit():
data = request.values.get('data')
if not data:
return jsonify({'result': 'error', 'message': '未提供数据'})
@@ -309,27 +309,27 @@ def api_send():
# 获取指定用户的消息记录(支持分页)
@__app.route('/api/get-msg', methods=['post'])
-def api_get_Msg():
- try:
- data = request.form.get('data')
- if data is None:
- data = request.get_json(silent=True) or {}
- else:
- data = json.loads(data)
- if not isinstance(data, dict):
- data = {}
- username = data.get("username")
- limit = data.get("limit", 30) # 默认每页30条
- offset = data.get("offset", 0) # 默认从0开始
- contentdb = content_db.new_instance()
- uid = 0
- if username:
- uid = member_db.new_instance().find_user(username)
- if uid == 0:
- return json.dumps({'list': [], 'total': 0, 'hasMore': False})
- # 获取总数用于判断是否还有更多
- total = contentdb.get_message_count(uid)
- list = contentdb.get_list('all', 'desc', limit, uid, offset)
+def api_get_Msg():
+ try:
+ data = request.form.get('data')
+ if data is None:
+ data = request.get_json(silent=True) or {}
+ else:
+ data = json.loads(data)
+ if not isinstance(data, dict):
+ data = {}
+ username = data.get("username")
+ limit = data.get("limit", 30) # 默认每页30条
+ offset = data.get("offset", 0) # 默认从0开始
+ contentdb = content_db.new_instance()
+ uid = 0
+ if username:
+ uid = member_db.new_instance().find_user(username)
+ if uid == 0:
+ return json.dumps({'list': [], 'total': 0, 'hasMore': False})
+ # 获取总数用于判断是否还有更多
+ total = contentdb.get_message_count(uid)
+ list = contentdb.get_list('all', 'desc', limit, uid, offset)
relist = []
i = len(list) - 1
while i >= 0:
@@ -354,134 +354,143 @@ def api_send_v1_chat_completions():
data = request.get_json()
if not data:
return jsonify({'error': '未提供数据'})
- try:
- model = data.get('model', 'fay')
- if model == 'llm':
- try:
- config_util.load_config()
- llm_url = _build_llm_url(config_util.gpt_base_url)
- api_key = config_util.key_gpt_api_key
- model_engine = config_util.gpt_model_engine
- except Exception as exc:
- return jsonify({'error': f'LLM config load failed: {exc}'}), 500
-
- if not llm_url:
- return jsonify({'error': 'LLM base_url is not configured'}), 500
-
- payload = dict(data)
- if payload.get('model') == 'llm' and model_engine:
- payload['model'] = model_engine
-
- stream_requested = _as_bool(payload.get('stream', False))
- headers = {'Content-Type': 'application/json'}
- if api_key:
- headers['Authorization'] = f'Bearer {api_key}'
-
- try:
- if stream_requested:
- resp = requests.post(llm_url, headers=headers, json=payload, stream=True)
-
- def generate():
- try:
- for chunk in resp.iter_content(chunk_size=8192):
- if not chunk:
- continue
- yield chunk
- finally:
- resp.close()
-
- content_type = resp.headers.get("Content-Type", "text/event-stream")
- if "charset=" not in content_type.lower():
- content_type = f"{content_type}; charset=utf-8"
- return Response(
- generate(),
- status=resp.status_code,
- content_type=content_type,
- )
-
- resp = requests.post(llm_url, headers=headers, json=payload, timeout=60)
- content_type = resp.headers.get("Content-Type", "application/json")
- if "charset=" not in content_type.lower():
- content_type = f"{content_type}; charset=utf-8"
- return Response(
- resp.content,
- status=resp.status_code,
- content_type=content_type,
- )
- except Exception as exc:
- return jsonify({'error': f'LLM request failed: {exc}'}), 500
-
- last_content = ""
- if 'messages' in data and data['messages']:
- last_message = data['messages'][-1]
- username = last_message.get('role', 'User')
- if username == 'user':
- username = 'User'
- last_content = last_message.get('content', 'No content provided')
- else:
- last_content = 'No messages found'
- username = 'User'
+ try:
+ model = data.get('model', 'fay')
+ if model == 'llm':
+ try:
+ config_util.load_config()
+ llm_url = _build_llm_url(config_util.gpt_base_url)
+ api_key = config_util.key_gpt_api_key
+ model_engine = config_util.gpt_model_engine
+ except Exception as exc:
+ return jsonify({'error': f'LLM config load failed: {exc}'}), 500
- observation = data.get('observation', '')
+ if not llm_url:
+ return jsonify({'error': 'LLM base_url is not configured'}), 500
+
+ payload = dict(data)
+ if payload.get('model') == 'llm' and model_engine:
+ payload['model'] = model_engine
+
+ stream_requested = _as_bool(payload.get('stream', False))
+ headers = {'Content-Type': 'application/json'}
+ if api_key:
+ headers['Authorization'] = f'Bearer {api_key}'
+
+ try:
+ if stream_requested:
+ resp = requests.post(llm_url, headers=headers, json=payload, stream=True)
+
+ def generate():
+ try:
+ for chunk in resp.iter_content(chunk_size=8192):
+ if not chunk:
+ continue
+ yield chunk
+ finally:
+ resp.close()
+
+ content_type = resp.headers.get("Content-Type", "text/event-stream")
+ if "charset=" not in content_type.lower():
+ content_type = f"{content_type}; charset=utf-8"
+ return Response(
+ generate(),
+ status=resp.status_code,
+ content_type=content_type,
+ )
+
+ resp = requests.post(llm_url, headers=headers, json=payload, timeout=60)
+ content_type = resp.headers.get("Content-Type", "application/json")
+ if "charset=" not in content_type.lower():
+ content_type = f"{content_type}; charset=utf-8"
+ return Response(
+ resp.content,
+ status=resp.status_code,
+ content_type=content_type,
+ )
+ except Exception as exc:
+ return jsonify({'error': f'LLM request failed: {exc}'}), 500
+
+ last_content = ""
+ username = "User"
+ messages = data.get("messages")
+ if isinstance(messages, list) and messages:
+ last_message = messages[-1] or {}
+ username = last_message.get("role", "User") or "User"
+ if username == "user":
+ username = "User"
+ last_content = last_message.get("content") or ""
+ elif isinstance(messages, str):
+ last_content = messages
+
+ observation = data.get('observation', '')
# 检查请求中是否指定了流式传输
- stream_requested = data.get('stream', False)
+ stream_requested = data.get('stream', False)
no_reply = _as_bool(data.get('no_reply', data.get('noReply', False)))
- if no_reply:
- interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream': bool(stream_requested), 'no_reply': True})
- util.printInfo(1, username, '[text chat no_reply]{}'.format(interact.data["msg"]), time.time())
- fay_booter.feiFei.on_interact(interact)
- if stream_requested or model == 'fay-streaming':
- def generate():
- message = {
- "id": "faystreaming-" + str(uuid.uuid4()),
- "object": "chat.completion.chunk",
- "created": int(time.time()),
- "model": model,
- "choices": [
- {
- "delta": {
- "content": ""
- },
- "index": 0,
- "finish_reason": "stop"
- }
- ],
- "usage": {
- "prompt_tokens": len(last_content),
- "completion_tokens": 0,
- "total_tokens": len(last_content)
- },
- "system_fingerprint": "",
- "no_reply": True
- }
- yield f"data: {json.dumps(message)}\n\n"
- yield 'data: [DONE]\n\n'
- return Response(generate(), mimetype='text/event-stream')
- return jsonify({
- "id": "fay-" + str(uuid.uuid4()),
- "object": "chat.completion",
- "created": int(time.time()),
- "model": model,
- "choices": [
- {
- "index": 0,
- "message": {
- "role": "assistant",
- "content": ""
- },
- "logprobs": "",
- "finish_reason": "stop"
- }
- ],
- "usage": {
- "prompt_tokens": len(last_content),
- "completion_tokens": 0,
- "total_tokens": len(last_content)
- },
- "system_fingerprint": "",
- "no_reply": True
- })
- if stream_requested or model == 'fay-streaming':
+ obs_text = ""
+ if observation is not None:
+ obs_text = observation.strip() if isinstance(observation, str) else str(observation).strip()
+ message_text = last_content.strip() if isinstance(last_content, str) else str(last_content).strip()
+ if not message_text and not obs_text:
+ return jsonify({'error': 'messages and observation are both empty'}), 400
+ if not message_text and obs_text:
+ no_reply = True
+ if no_reply:
+ interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream': bool(stream_requested), 'no_reply': True})
+ util.printInfo(1, username, '[text chat no_reply]{}'.format(interact.data["msg"]), time.time())
+ fay_booter.feiFei.on_interact(interact)
+ if stream_requested or model == 'fay-streaming':
+ def generate():
+ message = {
+ "id": "faystreaming-" + str(uuid.uuid4()),
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": model,
+ "choices": [
+ {
+ "delta": {
+ "content": ""
+ },
+ "index": 0,
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": len(last_content),
+ "completion_tokens": 0,
+ "total_tokens": len(last_content)
+ },
+ "system_fingerprint": "",
+ "no_reply": True
+ }
+ yield f"data: {json.dumps(message)}\n\n"
+ yield 'data: [DONE]\n\n'
+ return Response(generate(), mimetype='text/event-stream')
+ return jsonify({
+ "id": "fay-" + str(uuid.uuid4()),
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": model,
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": ""
+ },
+ "logprobs": "",
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": len(last_content),
+ "completion_tokens": 0,
+ "total_tokens": len(last_content)
+ },
+ "system_fingerprint": "",
+ "no_reply": True
+ })
+ if stream_requested or model == 'fay-streaming':
interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream':True})
util.printInfo(1, username, '[文字沟通接口(流式)]{}'.format(interact.data["msg"]), time.time())
fay_booter.feiFei.on_interact(interact)
@@ -588,7 +597,6 @@ def api_delete_user():
# 清除缓存的 agent 对象
try:
- from llm import nlp_cognitive_stream
if hasattr(nlp_cognitive_stream, 'agents') and username in nlp_cognitive_stream.agents:
del nlp_cognitive_stream.agents[username]
except Exception:
diff --git a/llm/nlp_cognitive_stream.py b/llm/nlp_cognitive_stream.py
index 776c0cc..c41aade 100644
--- a/llm/nlp_cognitive_stream.py
+++ b/llm/nlp_cognitive_stream.py
@@ -65,6 +65,11 @@ memory_cleared = False # 添加记忆清除标记
# 新增: 当前会话用户名及按用户获取memory目录的辅助函数
current_username = None # 当前会话用户名
+def _log_prompt(messages: List[SystemMessage | HumanMessage], tag: str = ""):
+ """No-op placeholder for prompt logging (disabled)."""
+ return
+
+
llm = ChatOpenAI(
model=cfg.gpt_model_engine,
base_url=cfg.gpt_base_url,
@@ -314,14 +319,26 @@ def _remove_prestart_from_text(text: str, keep_marked: bool = True) -> str:
return text.strip()
-def _remove_think_from_text(text: str) -> str:
- """从文本中移除 think 标签及其内容"""
- if not text:
- return text
- import re
- cleaned = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
- cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
- return cleaned.strip()
+def _remove_think_from_text(text: str) -> str:
+ """从文本中移除 think 标签及其内容"""
+ if not text:
+ return text
+ import re
+ cleaned = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
+ cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
+ return cleaned.strip()
+
+
+def _strip_json_code_fence(text: str) -> str:
+ """Strip ```json ... ``` wrappers if present."""
+ if not text:
+ return text
+ import re
+ trimmed = text.strip()
+ match = re.match(r"^```(?:json)?\s*(.*?)\s*```$", trimmed, flags=re.IGNORECASE | re.DOTALL)
+ if match:
+ return match.group(1).strip()
+ return text
def _format_conversation_block(conversation: List[Dict], username: str = "User") -> str:
@@ -440,35 +457,35 @@ def _run_prestart_tools(user_question: str) -> List[Dict[str, Any]]:
return results
-def _truncate_history(
- history: List[ToolResult],
- limit: Optional[int] = None,
- output_limit: Optional[int] = None,
-) -> str:
- if not history:
- return "(暂无)"
- lines: List[str] = []
- selected = history if limit is None else history[-limit:]
- for item in selected:
- call = item.get("call", {})
- name = call.get("name", "未知工具")
- attempt = item.get("attempt", 0)
- success = item.get("success", False)
- status = "成功" if success else "失败"
- lines.append(f"- {name} 第 {attempt} 次 → {status}")
- output = item.get("output")
- if output is not None:
- output_text = str(output)
- if output_limit is not None:
- output_text = _truncate_text(output_text, output_limit)
- lines.append(" 输出:" + output_text)
- error = item.get("error")
- if error is not None:
- error_text = str(error)
- if output_limit is not None:
- error_text = _truncate_text(error_text, output_limit)
- lines.append(" 错误:" + error_text)
- return "\n".join(lines)
+def _truncate_history(
+ history: List[ToolResult],
+ limit: Optional[int] = None,
+ output_limit: Optional[int] = None,
+) -> str:
+ if not history:
+ return "(暂无)"
+ lines: List[str] = []
+ selected = history if limit is None else history[-limit:]
+ for item in selected:
+ call = item.get("call", {})
+ name = call.get("name", "未知工具")
+ attempt = item.get("attempt", 0)
+ success = item.get("success", False)
+ status = "成功" if success else "失败"
+ lines.append(f"- {name} 第 {attempt} 次 → {status}")
+ output = item.get("output")
+ if output is not None:
+ output_text = str(output)
+ if output_limit is not None:
+ output_text = _truncate_text(output_text, output_limit)
+ lines.append(" 输出:" + output_text)
+ error = item.get("error")
+ if error is not None:
+ error_text = str(error)
+ if output_limit is not None:
+ error_text = _truncate_text(error_text, output_limit)
+ lines.append(" 错误:" + error_text)
+ return "\n".join(lines)
def _format_schema_parameters(schema: Dict[str, Any]) -> List[str]:
@@ -598,9 +615,9 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess
# 生成对话文本,使用代码块包裹每条消息
convo_text = _format_conversation_block(conversation, username)
- history_text = _truncate_history(history)
- tools_text = _format_tools_for_prompt(tool_specs)
- preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else ""
+ history_text = _truncate_history(history)
+ tools_text = _format_tools_for_prompt(tool_specs)
+ preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else ""
# 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内
if prestart_context and prestart_context.strip():
@@ -638,7 +655,6 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess
{observation or '(无补充)'}
---
-**关联记忆**
{memory_context or '(无相关记忆)'}
---
{prestart_section}
@@ -682,9 +698,9 @@ def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessag
display_username = "主人" if username == "User" else username
# 生成对话文本,使用代码块包裹每条消息
- conversation_block = _format_conversation_block(conversation, username)
- history_text = _truncate_history(state.get("tool_results", []))
- preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else ""
+ conversation_block = _format_conversation_block(conversation, username)
+ history_text = _truncate_history(state.get("tool_results", []))
+ preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else ""
# 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内
if prestart_context and prestart_context.strip():
@@ -755,6 +771,7 @@ def _call_planner_llm(
解析后的决策字典
"""
messages = _build_planner_messages(state)
+ _log_prompt(messages, tag="planner")
# 如果有流式回调,使用流式模式检测 finish+message
if stream_callback is not None:
@@ -826,6 +843,7 @@ def _call_planner_llm(
# 处理完整响应
trimmed = _remove_think_from_text(accumulated.strip())
+ trimmed = _strip_json_code_fence(trimmed)
if in_message_mode:
# 提取完整 message 内容,去掉结尾的 "}
@@ -869,6 +887,7 @@ def _call_planner_llm(
raise RuntimeError("规划器返回内容异常,未获得字符串。")
# 先移除 think 标签,兼容带思考标签的模型(如 DeepSeek R1)
trimmed = _remove_think_from_text(content.strip())
+ trimmed = _strip_json_code_fence(trimmed)
try:
decision = json.loads(trimmed)
except json.JSONDecodeError as exc:
@@ -929,6 +948,7 @@ def _plan_next_action(state: AgentState) -> AgentState:
"规划器:检测到工具重复调用,使用最新结果产出最终回复。"
)
final_messages = _build_final_messages(state)
+ _log_prompt(final_messages, tag="final")
preview = last_entry.get("output")
return {
"status": "completed",
@@ -1669,27 +1689,57 @@ def load_agent_memory(agent, username=None):
# 记忆对话内容的线程函数
def remember_conversation_thread(username, content, response_text):
- """
- 在单独线程中记录对话内容到代理记忆
-
- 参数:
- username: 用户名
- content: 用户问题内容
- response_text: 代理回答内容
- """
- global agents
+ """Background task to store a conversation memory node."""
try:
+ ag = create_agent(username)
+ if ag is None:
+ return
+ questioner = username
+ if isinstance(questioner, str) and questioner.lower() == "user":
+ questioner = "主人"
+ answerer = ag.scratch.get("first_name", "Fay")
+ question_text = content if content is not None else ""
+ answer_text = response_text if response_text is not None else ""
+ memory_content = f"{questioner}:{question_text}\n{answerer}:{answer_text}"
with agent_lock:
- ag = agents.get(username)
- if ag is None:
- return
time_step = get_current_time_step(username)
- name = "主人" if username == "User" else username
- # 记录对话内容
- memory_content = f"在对话中,我回答了{name}的问题:{content}\n,我的回答是:{response_text}"
- ag.remember(memory_content, time_step)
+ if ag.memory_stream and hasattr(ag.memory_stream, "remember_conversation"):
+ ag.memory_stream.remember_conversation(memory_content, time_step)
+ else:
+ ag.remember(memory_content, time_step)
except Exception as e:
- util.log(1, f"记忆对话内容出错: {str(e)}")
+ util.log(1, f"记录对话记忆失败: {str(e)}")
+
+def remember_observation_thread(username, observation_text):
+ """Background task to store an observation memory node."""
+ try:
+ ag = create_agent(username)
+ if ag is None:
+ return
+ text = observation_text if observation_text is not None else ""
+ memory_content = text
+ with agent_lock:
+ time_step = get_current_time_step(username)
+ if ag.memory_stream and hasattr(ag.memory_stream, "remember"):
+ ag.memory_stream.remember(memory_content, time_step)
+ else:
+ ag.remember(memory_content, time_step)
+ except Exception as e:
+ util.log(1, f"记录观察记忆失败: {str(e)}")
+
+def record_observation(username, observation_text):
+ """Persist an observation memory node asynchronously."""
+ if observation_text is None:
+ return False, "observation text is required"
+ text = observation_text.strip() if isinstance(observation_text, str) else str(observation_text).strip()
+ if not text:
+ return False, "observation text is required"
+ try:
+ MyThread(target=remember_observation_thread, args=(username, text)).start()
+ return True, "observation recorded"
+ except Exception as exc:
+ util.log(1, f"记录观察记忆失败: {exc}")
+ return False, f"observation record failed: {exc}"
def question(content, username, observation=None):
"""处理用户提问并返回回复。"""
@@ -1725,24 +1775,49 @@ def question(content, username, observation=None):
),
}
+ memory_sections = [
+ ("观察记忆", "observation"),
+ ("对话记忆", "conversation"),
+ ("反思记忆", "reflection"),
+ ]
memory_context = ""
- if agent.memory_stream and len(agent.memory_stream.seq_nodes) > 0:
+ if agent.memory_stream and len(agent.memory_stream.seq_nodes) > 0 and content:
current_time_step = get_current_time_step(username)
+ query = content.strip() if isinstance(content, str) else str(content)
+ max_per_type = 10
+ section_texts = []
try:
- query = f"{'主人' if username == 'User' else username}提出了问题:{content}"
- related_memories = agent.memory_stream.retrieve(
+ combined = agent.memory_stream.retrieve(
[query],
current_time_step,
- n_count=30,
+ n_count=max_per_type * len(memory_sections),
curr_filter="all",
hp=[0.8, 0.5, 0.5],
stateless=False,
)
- if related_memories and query in related_memories:
- memory_nodes = related_memories[query]
- memory_context = "\n".join(f"- {node.content}" for node in memory_nodes)
+ all_nodes = combined.get(query, []) if combined else []
except Exception as exc:
- util.log(1, f"获取相关记忆时出错: {exc}")
+ util.log(1, f"获取关联记忆时出错: {exc}")
+ all_nodes = []
+
+ for label, node_type in memory_sections:
+ section_lines = "(无)"
+ try:
+ memory_nodes = [n for n in all_nodes if getattr(n, "node_type", "") == node_type][:max_per_type]
+ if memory_nodes:
+ formatted = []
+ for node in memory_nodes:
+ ts = (getattr(node, "datetime", "") or "").strip()
+ prefix = f"[{ts}] " if ts else ""
+ formatted.append(f"- {prefix}{node.content}")
+ section_lines = "\n".join(formatted)
+ except Exception as exc:
+ util.log(1, f"获取{label}时出错: {exc}")
+ section_lines = "(获取失败)"
+ section_texts.append(f"**{label}**\n{section_lines}")
+ memory_context = "\n".join(section_texts)
+ else:
+ memory_context = "\n".join(f"**{label}**\n(无)" for label, _ in memory_sections)
prestart_context = ""
prestart_stream_text = ""
@@ -1866,22 +1941,22 @@ def question(content, username, observation=None):
or messages_buffer[-1]['content'] != content
):
messages_buffer.append({"role": "user", "content": content})
- else:
- # 不隔离:按独立消息存储,保留用户名信息
- def append_to_buffer_multi(role: str, text_value: str, msg_username: str = "") -> None:
- if not text_value:
- return
- messages_buffer.append({"role": role, "content": text_value, "username": msg_username})
- if len(messages_buffer) > 60:
- del messages_buffer[:-60]
-
- def append_to_buffer(role: str, text_value: str) -> None:
- append_to_buffer_multi(role, text_value, "")
-
- for record in history_records:
- msg_type, msg_text, msg_username = record
- if not msg_text:
- continue
+ else:
+ # 不隔离:按独立消息存储,保留用户名信息
+ def append_to_buffer_multi(role: str, text_value: str, msg_username: str = "") -> None:
+ if not text_value:
+ return
+ messages_buffer.append({"role": role, "content": text_value, "username": msg_username})
+ if len(messages_buffer) > 60:
+ del messages_buffer[:-60]
+
+ def append_to_buffer(role: str, text_value: str) -> None:
+ append_to_buffer_multi(role, text_value, "")
+
+ for record in history_records:
+ msg_type, msg_text, msg_username = record
+ if not msg_text:
+ continue
if msg_type and msg_type.lower() in ('member', 'user'):
append_to_buffer_multi("user", msg_text, msg_username)
else:
@@ -2020,22 +2095,22 @@ def question(content, username, observation=None):
# 创建规划器流式回调,用于实时输出 finish+message 响应
planner_stream_buffer = {"text": "", "first_chunk": True}
- def planner_stream_callback(chunk_text: str) -> None:
- """规划器流式回调,将 message 内容实时输出"""
- nonlocal accumulated_text, full_response_text, is_first_sentence, is_agent_think_start
- if not chunk_text:
- return
- planner_stream_buffer["text"] += chunk_text
- if planner_stream_buffer["first_chunk"]:
- planner_stream_buffer["first_chunk"] = False
- if is_agent_think_start:
- closing = "</think>"
- accumulated_text += closing
- full_response_text += closing
- is_agent_think_start = False
- # 使用 stream_response_chunks 的逻辑进行分句流式输出
- accumulated_text += chunk_text
- full_response_text += chunk_text
+ def planner_stream_callback(chunk_text: str) -> None:
+ """规划器流式回调,将 message 内容实时输出"""
+ nonlocal accumulated_text, full_response_text, is_first_sentence, is_agent_think_start
+ if not chunk_text:
+ return
+ planner_stream_buffer["text"] += chunk_text
+ if planner_stream_buffer["first_chunk"]:
+ planner_stream_buffer["first_chunk"] = False
+ if is_agent_think_start:
+ closing = "</think>"
+ accumulated_text += closing
+ full_response_text += closing
+ is_agent_think_start = False
+ # 使用 stream_response_chunks 的逻辑进行分句流式输出
+ accumulated_text += chunk_text
+ full_response_text += chunk_text
# 检查是否有完整句子可以输出
if len(accumulated_text) >= 20:
while True:
@@ -2507,46 +2582,46 @@ def save_agent_memory():
agent.scratch = {}
# 保存记忆前进行完整性检查
- try:
- # 检查seq_nodes中的每个节点是否有效
- valid_nodes = []
- for node in agent.memory_stream.seq_nodes:
- if node is None:
- util.log(1, "发现无效节点(None),跳过")
- continue
-
- if not hasattr(node, 'node_id') or not hasattr(node, 'content'):
- util.log(1, f"发现无效节点(缺少必要属性),跳过")
- continue
- raw_content = node.content if isinstance(node.content, str) else str(node.content)
- cleaned_content = _remove_think_from_text(raw_content)
- if cleaned_content != raw_content:
- old_content = raw_content
- node.content = cleaned_content
- if (
- agent.memory_stream.embeddings is not None
- and old_content in agent.memory_stream.embeddings
- and cleaned_content not in agent.memory_stream.embeddings
- ):
- agent.memory_stream.embeddings[cleaned_content] = agent.memory_stream.embeddings[old_content]
- else:
- node.content = raw_content
- valid_nodes.append(node)
-
- # 更新seq_nodes为有效节点列表
- agent.memory_stream.seq_nodes = valid_nodes
-
- # 重建id_to_node字典
- agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')}
- if agent.memory_stream.embeddings is not None:
- kept_contents = {node.content for node in valid_nodes if hasattr(node, 'content')}
- agent.memory_stream.embeddings = {
- key: value
- for key, value in agent.memory_stream.embeddings.items()
- if key in kept_contents
- }
- except Exception as e:
- util.log(1, f"检查记忆完整性时出错: {str(e)}")
+ try:
+ # 检查seq_nodes中的每个节点是否有效
+ valid_nodes = []
+ for node in agent.memory_stream.seq_nodes:
+ if node is None:
+ util.log(1, "发现无效节点(None),跳过")
+ continue
+
+ if not hasattr(node, 'node_id') or not hasattr(node, 'content'):
+ util.log(1, f"发现无效节点(缺少必要属性),跳过")
+ continue
+ raw_content = node.content if isinstance(node.content, str) else str(node.content)
+ cleaned_content = _remove_think_from_text(raw_content)
+ if cleaned_content != raw_content:
+ old_content = raw_content
+ node.content = cleaned_content
+ if (
+ agent.memory_stream.embeddings is not None
+ and old_content in agent.memory_stream.embeddings
+ and cleaned_content not in agent.memory_stream.embeddings
+ ):
+ agent.memory_stream.embeddings[cleaned_content] = agent.memory_stream.embeddings[old_content]
+ else:
+ node.content = raw_content
+ valid_nodes.append(node)
+
+ # 更新seq_nodes为有效节点列表
+ agent.memory_stream.seq_nodes = valid_nodes
+
+ # 重建id_to_node字典
+ agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')}
+ if agent.memory_stream.embeddings is not None:
+ kept_contents = {node.content for node in valid_nodes if hasattr(node, 'content')}
+ agent.memory_stream.embeddings = {
+ key: value
+ for key, value in agent.memory_stream.embeddings.items()
+ if key in kept_contents
+ }
+ except Exception as e:
+ util.log(1, f"检查记忆完整性时出错: {str(e)}")
# 保存记忆
try:
diff --git a/test/test_fay_gpt_stream.py b/test/test_fay_gpt_stream.py
index 2be854d..6574361 100644
--- a/test/test_fay_gpt_stream.py
+++ b/test/test_fay_gpt_stream.py
@@ -91,7 +91,7 @@ if __name__ == "__main__":
print("=" * 60)
print("示例1:张三的对话(带观察数据)")
print("=" * 60)
- test_gpt("你好,今天天气不错啊", username="user", observation=OBSERVATION_SAMPLES["张三"])
+ test_gpt(prompt="", username="user", observation=OBSERVATION_SAMPLES["张三"], no_reply=False)
print("\n")