diff --git a/core/fay_core.py b/core/fay_core.py index 0a0100b..1165f80 100644 --- a/core/fay_core.py +++ b/core/fay_core.py @@ -1,1026 +1,2706 @@ -# -*- coding: utf-8 -*- -#作用是处理交互逻辑,文字输入,语音、文字及情绪的发送、播放及展示输出 -import math -from operator import index -import os -import time -import socket -import requests -from pydub import AudioSegment -from queue import Queue -import re # 添加正则表达式模块用于过滤表情符号 -import uuid - -# 适应模型使用 -import numpy as np -from ai_module import baidu_emotion -from core import wsa_server -from core.interact import Interact -from tts.tts_voice import EnumVoice -from scheduler.thread_manager import MyThread -from tts import tts_voice -from utils import util, config_util -from core import qa_service -from utils import config_util as cfg -from core import content_db -from ai_module import nlp_cemotion -from core import stream_manager - -from core import member_db -import threading - -#加载配置 -cfg.load_config() -if cfg.tts_module =='ali': - from tts.ali_tss import Speech -elif cfg.tts_module == 'gptsovits': - from tts.gptsovits import Speech -elif cfg.tts_module == 'gptsovits_v3': - from tts.gptsovits_v3 import Speech -elif cfg.tts_module == 'volcano': - from tts.volcano_tts import Speech -else: - from tts.ms_tts_sdk import Speech - -#windows运行推送唇形数据 -import platform -if platform.system() == "Windows": - import sys - sys.path.append("test/ovr_lipsync") - from test_olipsync import LipSyncGenerator - - -#可以使用自动播报的标记 -can_auto_play = True -auto_play_lock = threading.RLock() - -class FeiFei: - def __init__(self): - self.lock = threading.Lock() - self.nlp_streams = {} # 存储用户ID到句子缓存的映射 - self.nlp_stream_lock = threading.Lock() # 保护nlp_streams字典的锁 - self.mood = 0.0 # 情绪值 - self.old_mood = 0.0 - self.item_index = 0 - self.X = np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, -1) # 适应模型变量矩阵 - # self.W = np.array([0.01577594,1.16119452,0.75828,0.207746,1.25017864,0.1044121,0.4294899,0.2770932]).reshape(-1,1) #适应模型变量矩阵 - self.W = np.array([0.0, 0.6, 0.1, 0.7, 0.3, 0.0, 0.0, 0.0]).reshape(-1, 1) # 适应模型变量矩阵 - - self.wsParam = None - self.wss = None - self.sp = Speech() - self.speaking = False #声音是否在播放 - self.__running = True - self.sp.connect() #TODO 预连接 - - self.timer = None - self.sound_query = Queue() - self.think_mode_users = {} # 使用字典存储每个用户的think模式状态 - self.think_time_users = {} #使用字典存储每个用户的think开始时间 - self.think_display_state = {} - self.think_display_limit = 400 - self.user_conv_map = {} #存储用户对话id及句子流序号,key为(username, conversation_id) - self.pending_isfirst = {} # 存储因prestart被过滤而延迟的isfirst标记,key为username - - def __remove_emojis(self, text): - """ - 改进的表情包过滤,避免误删除正常Unicode字符 - """ - # 更精确的emoji范围,避免误删除正常字符 - emoji_pattern = re.compile( - "[" - "\U0001F600-\U0001F64F" # 表情符号 (Emoticons) - "\U0001F300-\U0001F5FF" # 杂项符号和象形文字 (Miscellaneous Symbols and Pictographs) - "\U0001F680-\U0001F6FF" # 交通和地图符号 (Transport and Map Symbols) - "\U0001F1E0-\U0001F1FF" # 区域指示符号 (Regional Indicator Symbols) - "\U0001F900-\U0001F9FF" # 补充符号和象形文字 (Supplemental Symbols and Pictographs) - "\U0001FA70-\U0001FAFF" # 扩展A符号和象形文字 (Symbols and Pictographs Extended-A) - "\U00002600-\U000026FF" # 杂项符号 (Miscellaneous Symbols) - "\U00002700-\U000027BF" # 装饰符号 (Dingbats) - "\U0000FE00-\U0000FE0F" # 变体选择器 (Variation Selectors) - "\U0001F000-\U0001F02F" # 麻将牌 (Mahjong Tiles) - "\U0001F0A0-\U0001F0FF" # 扑克牌 (Playing Cards) - "]+", - flags=re.UNICODE, - ) - - # 保护常用的中文标点符号和特殊字符 - protected_chars = ["。", ",", "!", "?", ":", ";", "、", """, """, "'", "'", "(", ")", "【", "】", "《", "》"] - - # 先保存保护字符的位置 - protected_positions = {} - for i, char in enumerate(text): - if char in protected_chars: - protected_positions[i] = char - - # 执行emoji过滤 - filtered_text = emoji_pattern.sub('', text) - - # 如果过滤后文本长度变化太大,可能误删了正常字符,返回原文本 - if len(filtered_text) < len(text) * 0.5: # 如果删除了超过50%的内容 - return text - - return filtered_text - - def __process_stream_output(self, text, username, session_type="type2_stream", is_qa=False): - """ - 按流式方式分割和发送 type=2 的文本 - 使用安全的流式文本处理器和状态管理器 - """ - if not text or text.strip() == "": - return - - # 使用安全的流式文本处理器 - from utils.stream_text_processor import get_processor - from utils.stream_state_manager import get_state_manager - - processor = get_processor() - state_manager = get_state_manager() - - # 处理流式文本,is_qa=False表示普通模式 - success = processor.process_stream_text(text, username, is_qa=is_qa, session_type=session_type) - - if success: - # 普通模式结束会话 - state_manager.end_session(username, conversation_id=stream_manager.new_instance().get_conversation_id(username)) - else: - util.log(1, f"type=2流式处理失败,文本长度: {len(text)}") - # 失败时也要确保结束会话 - state_manager.force_reset_user_state(username) - - #语音消息处理检查是否命中q&a - def __get_answer(self, interleaver, text): - answer = None - # 全局问答 - answer, type = qa_service.QAService().question('qa',text) - if answer is not None: - return answer, type - else: - return None, None - - - #消息处理 - def __process_interact(self, interact: Interact): - if self.__running: - try: - index = interact.interact_type - username = interact.data.get("user", "User") - uid = member_db.new_instance().find_user(username) no_reply = interact.data.get("no_reply", False) if isinstance(no_reply, str): no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on") else: no_reply = bool(no_reply) - - if index == 1: #语音、文字交互 - - #记录用户问题,方便obs等调用 - self.write_to_file("./logs", "asr_result.txt", interact.data["msg"]) - - #同步用户问题到数字人 - if wsa_server.get_instance().is_connected(username): - content = {'Topic': 'human', 'Data': {'Key': 'question', 'Value': interact.data["msg"]}, 'Username' : interact.data.get("user")} - wsa_server.get_instance().add_cmd(content) - - #记录用户问题 - content_id = content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid) - if wsa_server.get_web_instance().is_connected(username): - wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid, "id":content_id}, "Username" : username}) - - if no_reply: return "" #确定是否命中q&a - answer, type = self.__get_answer(interact.interleaver, interact.data["msg"]) - - #大语言模型回复 - text = '' - if answer is None or type != "qa": - if wsa_server.get_web_instance().is_connected(username): - wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}) - if wsa_server.get_instance().is_connected(username): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'} - wsa_server.get_instance().add_cmd(content) - - # 根据配置动态调用不同的NLP模块 - if cfg.config["memory"].get("use_bionic_memory", False): - from llm import nlp_bionicmemory_stream - text = nlp_bionicmemory_stream.question(interact.data["msg"], username, interact.data.get("observation", None)) - else: - from llm import nlp_cognitive_stream - text = nlp_cognitive_stream.question(interact.data["msg"], username, interact.data.get("observation", None)) - - else: - text = answer - # 使用流式分割处理Q&A答案 - self.__process_stream_output(text, username, session_type="qa", is_qa=True) - - - return text - - elif (index == 2):#透传模式:有音频则仅播音频;仅文本则流式+TTS - audio_url = interact.data.get("audio") - text = interact.data.get("text") - - # 1) 存在音频:忽略文本,仅播放音频 - if audio_url and str(audio_url).strip(): - try: - audio_interact = Interact( - "stream", 1, - {"user": username, "msg": "", "isfirst": True, "isend": True, "audio": audio_url} - ) - self.say(audio_interact, "") - except Exception: - pass - return 'success' - - # 2) 只有文本:执行流式切分并TTS - if text and str(text).strip(): - # 进行流式处理(用于TTS,流式处理中会记录到数据库) - self.__process_stream_output(text, username, f"type2_{interact.interleaver}", is_qa=False) - - # 不再需要额外记录,因为流式处理已经记录了 - # self.__process_text_output(text, username, uid) - - return 'success' - - # 没有有效内容 - return 'success' - - except BaseException as e: - print(e) - return e - else: - return "还没有开始运行" - - #记录问答到log - def write_to_file(self, path, filename, content): - if not os.path.exists(path): - os.makedirs(path) - full_path = os.path.join(path, filename) - with open(full_path, 'w', encoding='utf-8') as file: - file.write(content) - file.flush() - os.fsync(file.fileno()) - - #触发交互 - def on_interact(self, interact: Interact): - #创建用户 - username = interact.data.get("user", "User") - if member_db.new_instance().is_username_exist(username) == "notexists": - member_db.new_instance().add_user(username) - no_reply = interact.data.get("no_reply", False) - if isinstance(no_reply, str): - no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on") - else: - no_reply = bool(no_reply) - - if not no_reply: - try: - from utils.stream_state_manager import get_state_manager - import uuid - if get_state_manager().is_session_active(username): - stream_manager.new_instance().clear_Stream_with_audio(username) - conv_id = "conv_" + str(uuid.uuid4()) - stream_manager.new_instance().set_current_conversation(username, conv_id) - # 将当前会话ID附加到交互数据 - interact.data["conversation_id"] = conv_id - # 允许新的生成 - stream_manager.new_instance().set_stop_generation(username, stop=False) - except Exception: - util.log(3, "开启新会话失败") - - if interact.interact_type == 1: - MyThread(target=self.__process_interact, args=[interact]).start() - else: - return self.__process_interact(interact) - - #获取不同情绪声音 - def __get_mood_voice(self): - voice = tts_voice.get_voice_of(config_util.config["attribute"]["voice"]) - if voice is None: - voice = EnumVoice.XIAO_XIAO - styleList = voice.value["styleList"] - sayType = styleList["calm"] - return sayType - - # 合成声音 - def say(self, interact, text, type = ""): - try: - uid = member_db.new_instance().find_user(interact.data.get("user")) - is_end = interact.data.get("isend", False) - is_first = interact.data.get("isfirst", False) - username = interact.data.get("user", "User") - - # 提前进行会话有效性与中断检查,避免产生多余面板/数字人输出 - try: - user_for_stop = interact.data.get("user", "User") - conv_id_for_stop = interact.data.get("conversation_id") - if not is_end and stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): - return None - except Exception: - pass - - #无效流式文本提前结束 - if not is_first and not is_end and (text is None or text.strip() == ""): - return None - - # 检查是否是 prestart 内容(不应该影响 thinking 状态) - is_prestart_content = self.__has_prestart(text) - - # 流式文本拼接存库 - content_id = 0 - # 使用 (username, conversation_id) 作为 key,避免并发会话覆盖 - conv = interact.data.get("conversation_id") or "" - conv_map_key = (username, conv) - - if is_first == True: - # reset any leftover think-mode at the start of a new reply - # 但如果是 prestart 内容,不重置 thinking 状态 - try: - if uid is not None and not is_prestart_content: - self.think_mode_users[uid] = False - if uid in self.think_time_users: - del self.think_time_users[uid] if uid in self.think_display_state: - del self.think_display_state[uid] - - except Exception: - pass - # 如果没有 conversation_id,生成一个新的 - if not conv: - conv = "conv_" + str(uuid.uuid4()) - conv_map_key = (username, conv) - conv_no = 0 - # 创建第一条数据库记录,获得content_id - if text and text.strip(): - content_id = content_db.new_instance().add_content('fay', 'speak', text, username, uid) - else: - content_id = content_db.new_instance().add_content('fay', 'speak', '', username, uid) - - # 保存content_id到会话映射中,使用 (username, conversation_id) 作为 key - self.user_conv_map[conv_map_key] = { - "conversation_id": conv, - "conversation_msg_no": conv_no, - "content_id": content_id - } - util.log(1, f"流式会话开始: key={conv_map_key}, content_id={content_id}") - else: - # 获取之前保存的content_id - conv_info = self.user_conv_map.get(conv_map_key, {}) - content_id = conv_info.get("content_id", 0) - - # 如果 conv_map_key 不存在,尝试使用 username 作为备用查找 - if not conv_info and text and text.strip(): - # 查找所有匹配用户名的会话 - for (u, c), info in list(self.user_conv_map.items()): - if u == username and info.get("content_id", 0) > 0: - content_id = info.get("content_id", 0) - conv_info = info - util.log(1, f"警告:使用备用会话 ({u}, {c}) 的 content_id={content_id},原 key=({username}, {conv})") - break - - if conv_info: - conv_info["conversation_msg_no"] = conv_info.get("conversation_msg_no", 0) + 1 - - # 如果有新内容,更新数据库 - if content_id > 0 and text and text.strip(): - # 获取当前已有内容 - existing_content = content_db.new_instance().get_content_by_id(content_id) - if existing_content: - # 累积内容 - accumulated_text = existing_content[3] + text - content_db.new_instance().update_content(content_id, accumulated_text) - elif content_id == 0 and text and text.strip(): - # content_id 为 0 表示可能会话 key 不匹配,记录警告 - util.log(1, f"警告:content_id=0,无法更新数据库。user={username}, conv={conv}, text片段={text[:50] if len(text) > 50 else text}") - - # 会话结束时清理 user_conv_map 中的对应条目,避免内存泄漏 - if is_end and conv_map_key in self.user_conv_map: - del self.user_conv_map[conv_map_key] - - # 推送给前端和数字人 - try: - user_for_stop = interact.data.get("user", "User") - conv_id_for_stop = interact.data.get("conversation_id") - if is_end or not stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): - self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end) - except Exception: - self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end) - - # 处理think标签 - # 第一步:处理结束标记 - if "" in text: - # 设置用户退出思考模式 - self.think_mode_users[uid] = False - - # 分割文本,提取后面的内容 - # 如果有多个,我们只关心最后一个后面的内容 - parts = text.split("") - text = parts[-1].strip() - - # 如果提取出的文本为空,则不需要继续处理 - if text == "": - return None - # 第二步:处理开始标记 - # 注意:这里要检查经过上面处理后的text - if "" in text: - self.think_mode_users[uid] = True - self.think_time_users[uid] = time.time() - - #”思考中“的输出 - if self.think_mode_users.get(uid, False): - try: - user_for_stop = interact.data.get("user", "User") - conv_id_for_stop = interact.data.get("conversation_id") - should_block = stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop) - except Exception: - should_block = False - if not should_block: - if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}) - if wsa_server.get_instance().is_connected(interact.data.get("user")): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'} - wsa_server.get_instance().add_cmd(content) - - #”请稍等“的音频输出(不影响文本输出) - if self.think_mode_users.get(uid, False) == True and time.time() - self.think_time_users[uid] >= 5: - self.think_time_users[uid] = time.time() - text = "请稍等..." - elif self.think_mode_users.get(uid, False) == True and "" not in text: - return None - - result = None - audio_url = interact.data.get('audio', None)#透传的音频 - - # 移除 prestart 标签内容,不进行TTS - tts_text = self.__remove_prestart_tags(text) if text else text - - if audio_url is not None:#透传音频下载 - file_name = 'sample-' + str(int(time.time() * 1000)) + audio_url[-4:] - result = self.download_wav(audio_url, './samples/', file_name) - elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().get_client_output(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts - if tts_text != None and tts_text.replace("*", "").strip() != "": - # 检查是否需要停止TTS处理(按会话) - if stream_manager.new_instance().should_stop_generation( - interact.data.get("user", "User"), - conversation_id=interact.data.get("conversation_id") - ): - util.printInfo(1, interact.data.get('user'), 'TTS处理被打断,跳过音频合成') - return None - - # 先过滤表情符号,然后再合成语音 - filtered_text = self.__remove_emojis(tts_text.replace("*", "")) - if filtered_text is not None and filtered_text.strip() != "": - util.printInfo(1, interact.data.get('user'), '合成音频...') - tm = time.time() - result = self.sp.to_sample(filtered_text, self.__get_mood_voice()) - # 合成完成后再次检查会话是否仍有效,避免继续输出旧会话结果 - try: - user_for_stop = interact.data.get("user", "User") - conv_id_for_stop = interact.data.get("conversation_id") - if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): - return None - except Exception: - pass - util.printInfo(1, interact.data.get("user"), "合成音频完成. 耗时: {} ms 文件:{}".format(math.floor((time.time() - tm) * 1000), result)) - else: - # prestart 内容不应该触发机器人表情重置 - if is_end and not is_prestart_content and wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) - - if result is not None or is_first or is_end: - # prestart 内容不需要进入音频处理流程 - if is_prestart_content: - return result - if is_end:#TODO 临时方案:如果结束标记,则延迟1秒处理,免得is end比前面的音频tts要快 - time.sleep(1) - MyThread(target=self.__process_output_audio, args=[result, interact, text]).start() - return result - - except BaseException as e: - print(e) - return None - - #下载wav - def download_wav(self, url, save_directory, filename): - try: - # 发送HTTP GET请求以获取WAV文件内容 - response = requests.get(url, stream=True) - response.raise_for_status() # 检查请求是否成功 - - # 确保保存目录存在 - if not os.path.exists(save_directory): - os.makedirs(save_directory) - - # 构建保存文件的路径 - save_path = os.path.join(save_directory, filename) - - # 将WAV文件内容保存到指定文件 - with open(save_path, 'wb') as f: - for chunk in response.iter_content(chunk_size=1024): - if chunk: - f.write(chunk) - - return save_path - except requests.exceptions.RequestException as e: - print(f"[Error] Failed to download file: {e}") - return None - - - #面板播放声音 - def __play_sound(self): - try: - import pygame - pygame.mixer.init() # 初始化pygame.mixer,只需要在此处初始化一次, 如果初始化失败,则不播放音频 - except Exception as e: - util.printInfo(1, "System", "音频播放初始化失败,本机无法播放音频") - return - - while self.__running: - time.sleep(0.01) - if not self.sound_query.empty(): # 如果队列不为空则播放音频 - file_url, audio_length, interact = self.sound_query.get() - - is_first = interact.data.get('isfirst') is True - is_end = interact.data.get('isend') is True - - - - if file_url is not None: - util.printInfo(1, interact.data.get('user'), '播放音频...') - - if is_first: - self.speaking = True - elif not is_end: - self.speaking = True - - #自动播报关闭 - global auto_play_lock - global can_auto_play - with auto_play_lock: - if self.timer is not None: - self.timer.cancel() - self.timer = None - can_auto_play = False - - if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}) - - if file_url is not None: - pygame.mixer.music.load(file_url) - pygame.mixer.music.play() - - # 播放过程中计时,直到音频播放完毕 - length = 0 - while length < audio_length: - try: - user_for_stop = interact.data.get("user", "User") - conv_id_for_stop = interact.data.get("conversation_id") - if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): - try: - pygame.mixer.music.stop() - except Exception: - pass - break - except Exception: - pass - length += 0.01 - time.sleep(0.01) - - if is_end: - self.play_end(interact) - - if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) - # 播放完毕后通知 - if wsa_server.get_web_instance().is_connected(interact.data.get("user")): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username': interact.data.get('user')}) - - #推送远程音频 - def __send_remote_device_audio(self, file_url, interact): - if file_url is None: - return - delkey = None - for key, value in fay_booter.DeviceInputListenerDict.items(): - if value.username == interact.data.get("user") and value.isOutput: #按username选择推送,booter.devicelistenerdice按用户名记录 - try: - value.deviceConnector.send(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08") # 发送音频开始标志,同时也检查设备是否在线 - wavfile = open(os.path.abspath(file_url), "rb") - data = wavfile.read(102400) - total = 0 - while data: - total += len(data) - value.deviceConnector.send(data) - data = wavfile.read(102400) - time.sleep(0.0001) - value.deviceConnector.send(b'\x08\x07\x06\x05\x04\x03\x02\x01\x00')# 发送音频结束标志 - util.printInfo(1, value.username, "远程音频发送完成:{}".format(total)) - except socket.error as serr: - util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key)) - value.stop() - delkey = key - if delkey: - value = fay_booter.DeviceInputListenerDict.pop(delkey) - if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"remote_audio_connect": False, "Username" : interact.data.get('user')}) - - def __is_send_remote_device_audio(self, interact): - for key, value in fay_booter.DeviceInputListenerDict.items(): - if value.username == interact.data.get("user") and value.isOutput: - return True - return False - - #输出音频处理 - def __process_output_audio(self, file_url, interact, text): - try: - # 会话有效性与中断检查(最早返回,避免向面板/数字人发送任何旧会话输出) - try: - user_for_stop = interact.data.get("user", "User") - conv_id_for_stop = interact.data.get("conversation_id") - if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): - return - except Exception: - pass - try: - if file_url is None: - audio_length = 0 - elif file_url.endswith('.wav'): - audio = AudioSegment.from_wav(file_url) - audio_length = len(audio) / 1000.0 # 时长以秒为单位 - elif file_url.endswith('.mp3'): - audio = AudioSegment.from_mp3(file_url) - audio_length = len(audio) / 1000.0 # 时长以秒为单位 - except Exception as e: - audio_length = 3 - - #推送远程音频 - if file_url is not None: - MyThread(target=self.__send_remote_device_audio, args=[file_url, interact]).start() - - #发送音频给数字人接口 - if file_url is not None and wsa_server.get_instance().get_client_output(interact.data.get("user")): - # 使用 (username, conversation_id) 作为 key 获取会话信息 - audio_username = interact.data.get("user", "User") - audio_conv_id = interact.data.get("conversation_id") or "" - audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {}) - content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : audio_conv_info.get("conversation_id", ""), 'CONV_MSG_NO' : audio_conv_info.get("conversation_msg_no", 0) }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'} - #计算lips - if platform.system() == "Windows": - try: - lip_sync_generator = LipSyncGenerator() - viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url)) - consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list) - content["Data"]["Lips"] = consolidated_visemes - except Exception as e: - print(e) - util.printInfo(1, interact.data.get("user"), "唇型数据生成失败") - wsa_server.get_instance().add_cmd(content) - util.printInfo(1, interact.data.get("user"), "数字人接口发送音频数据成功") - - #面板播放 - config_util.load_config() - # 检查是否是 prestart 内容 - is_prestart = self.__has_prestart(text) - if config_util.config["interact"]["playSound"]: - # prestart 内容不应该进入播放队列,避免触发 Normal 状态 - if not is_prestart: - self.sound_query.put((file_url, audio_length, interact)) - else: - # prestart 内容不应该重置机器人表情 - if not is_prestart and wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) - - except Exception as e: - print(e) - - def play_end(self, interact): - self.speaking = False - global can_auto_play - global auto_play_lock - with auto_play_lock: - if self.timer: - self.timer.cancel() - self.timer = None - if interact.interleaver != 'auto_play': #交互后暂停自动播报30秒 - self.timer = threading.Timer(30, self.set_auto_play) - self.timer.start() - else: - can_auto_play = True - - #恢复自动播报(如果有) - def set_auto_play(self): - global auto_play_lock - global can_auto_play - with auto_play_lock: - can_auto_play = True - self.timer = None - - #启动核心服务 - def start(self): - MyThread(target=self.__play_sound).start() - - #停止核心服务 - def stop(self): - self.__running = False - self.speaking = False - self.sp.close() - wsa_server.get_web_instance().add_cmd({"panelMsg": ""}) - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': ""}} - wsa_server.get_instance().add_cmd(content) - - def __record_response(self, text, username, uid): - """ - 记录AI的回复内容 - :param text: 回复文本 - :param username: 用户名 - :param uid: 用户ID - :return: content_id - """ - self.write_to_file("./logs", "answer_result.txt", text) - return content_db.new_instance().add_content('fay', 'speak', text, username, uid) - - def __remove_prestart_tags(self, text): - """ - 移除文本中的 prestart 标签及其内容 - :param text: 原始文本 - :return: 移除 prestart 标签后的文本 - """ - if not text: - return text - import re - # 移除 ... 标签及其内容(支持属性) - cleaned = re.sub(r']*>[\s\S]*?', '', text, flags=re.IGNORECASE) - return cleaned.strip() - - def __has_prestart(self, text): - """ - 判断文本中是否包含 prestart 标签(支持属性) - """ - if not text: - return False - return re.search(r']*>[\s\S]*?', text, flags=re.IGNORECASE) is not None - - - def __truncate_think_for_panel(self, text, uid, username): - - if not text or not isinstance(text, str): - - return text - - key = uid if uid is not None else username - - state = self.think_display_state.get(key) - - if state is None: - - state = {"in_think": False, "in_tool_output": False, "tool_count": 0, "tool_truncated": False} - - self.think_display_state[key] = state - - if not state["in_think"] and "" not in text and "" not in text: - - return text - - tool_output_regex = re.compile(r"\[TOOL\]\s*(?:Output|\u8f93\u51fa)[:\uff1a]", re.IGNORECASE) - - section_regex = re.compile(r"(?i)(^|[\r\n])(\[(?:TOOL|PLAN)\])") - - out = [] - - i = 0 - - while i < len(text): - - if not state["in_think"]: - - idx = text.find("", i) - - if idx == -1: - - out.append(text[i:]) - - break - - out.append(text[i:idx + len("")]) - - state["in_think"] = True - - i = idx + len("") - - continue - - if not state["in_tool_output"]: - - think_end = text.find("", i) - - tool_match = tool_output_regex.search(text, i) - - next_pos = None - - next_kind = None - - if tool_match: - - next_pos = tool_match.start() - - next_kind = "tool" - - if think_end != -1 and (next_pos is None or think_end < next_pos): - - next_pos = think_end - - next_kind = "think_end" - - if next_pos is None: - - out.append(text[i:]) - - break - - if next_pos > i: - - out.append(text[i:next_pos]) - - if next_kind == "think_end": - - out.append("") - - state["in_think"] = False - - state["in_tool_output"] = False - - state["tool_count"] = 0 - - state["tool_truncated"] = False - - i = next_pos + len("") - - else: - - marker_end = tool_match.end() - - out.append(text[next_pos:marker_end]) - - state["in_tool_output"] = True - - state["tool_count"] = 0 - - state["tool_truncated"] = False - - i = marker_end - - continue - - think_end = text.find("", i) - - section_match = section_regex.search(text, i) - - end_pos = None - - if section_match: - - end_pos = section_match.start(2) - - if think_end != -1 and (end_pos is None or think_end < end_pos): - - end_pos = think_end - - segment = text[i:] if end_pos is None else text[i:end_pos] - - if segment: - - if state["tool_truncated"]: - - pass - - else: - - remaining = self.think_display_limit - state["tool_count"] - - if remaining <= 0: - - out.append("...") - - state["tool_truncated"] = True - - elif len(segment) <= remaining: - - out.append(segment) - - state["tool_count"] += len(segment) - - else: - - out.append(segment[:remaining] + "...") - - state["tool_count"] += remaining - - state["tool_truncated"] = True - - if end_pos is None: - - break - - state["in_tool_output"] = False - - state["tool_count"] = 0 - - state["tool_truncated"] = False - - i = end_pos - - return "".join(out) - - def __send_panel_message(self, text, username, uid, content_id=None, type=None): - """ - 发送消息到Web面板 - :param text: 消息文本 - :param username: 用户名 - :param uid: 用户ID - :param content_id: 内容ID - :param type: 消息类型 - """ - if not wsa_server.get_web_instance().is_connected(username): - return - - # 检查是否是 prestart 内容,prestart 内容不应该更新日志区消息 - # 因为这会覆盖掉"思考中..."的状态显示 - is_prestart = self.__has_prestart(text) display_text = self.__truncate_think_for_panel(text, uid, username) +# -*- coding: utf-8 -*- + + +#作用是处理交互逻辑,文字输入,语音、文字及情绪的发送、播放及展示输出 + + +import math + + +from operator import index + + +import os + + +import time + + +import socket + + +import requests + + +from pydub import AudioSegment + + +from queue import Queue + + +import re # 添加正则表达式模块用于过滤表情符号 + + +import uuid + + + + + +# 适应模型使用 + + +import numpy as np + + +from ai_module import baidu_emotion + + +from core import wsa_server + + +from core.interact import Interact + + +from tts.tts_voice import EnumVoice + + +from scheduler.thread_manager import MyThread + + +from tts import tts_voice + + +from utils import util, config_util + + +from core import qa_service + + +from utils import config_util as cfg + + +from core import content_db + + +from ai_module import nlp_cemotion + + +from core import stream_manager + + + + + +from core import member_db + + +import threading + + + + + +#加载配置 + + +cfg.load_config() + + +if cfg.tts_module =='ali': + + + from tts.ali_tss import Speech + + +elif cfg.tts_module == 'gptsovits': + + + from tts.gptsovits import Speech + + +elif cfg.tts_module == 'gptsovits_v3': + + + from tts.gptsovits_v3 import Speech + + +elif cfg.tts_module == 'volcano': + + + from tts.volcano_tts import Speech + + +else: + + + from tts.ms_tts_sdk import Speech + + + + + +#windows运行推送唇形数据 + + +import platform + + +if platform.system() == "Windows": + + + import sys + + + sys.path.append("test/ovr_lipsync") + + + from test_olipsync import LipSyncGenerator + + + + + + + + +#可以使用自动播报的标记 + + +can_auto_play = True + + +auto_play_lock = threading.RLock() + + + + + +class FeiFei: + + + def __init__(self): + + + self.lock = threading.Lock() + + + self.nlp_streams = {} # 存储用户ID到句子缓存的映射 + + + self.nlp_stream_lock = threading.Lock() # 保护nlp_streams字典的锁 + + + self.mood = 0.0 # 情绪值 + + + self.old_mood = 0.0 + + + self.item_index = 0 + + + self.X = np.array([1, 0, 0, 0, 0, 0, 0, 0]).reshape(1, -1) # 适应模型变量矩阵 + + + # self.W = np.array([0.01577594,1.16119452,0.75828,0.207746,1.25017864,0.1044121,0.4294899,0.2770932]).reshape(-1,1) #适应模型变量矩阵 + + + self.W = np.array([0.0, 0.6, 0.1, 0.7, 0.3, 0.0, 0.0, 0.0]).reshape(-1, 1) # 适应模型变量矩阵 + + + + + + self.wsParam = None + + + self.wss = None + + + self.sp = Speech() + + + self.speaking = False #声音是否在播放 + + + self.__running = True + + + self.sp.connect() #TODO 预连接 + + + + + + self.timer = None + + + self.sound_query = Queue() + + + self.think_mode_users = {} # 使用字典存储每个用户的think模式状态 + + + self.think_time_users = {} #使用字典存储每个用户的think开始时间 + self.think_display_state = {} + self.think_display_limit = 400 + self.user_conv_map = {} #存储用户对话id及句子流序号,key为(username, conversation_id) + + self.pending_isfirst = {} # 存储因prestart被过滤而延迟的isfirst标记,key为username + + + + + def __remove_emojis(self, text): + + + """ + + + 改进的表情包过滤,避免误删除正常Unicode字符 + + + """ + + + # 更精确的emoji范围,避免误删除正常字符 + + + emoji_pattern = re.compile( + + + "[" + + + "\U0001F600-\U0001F64F" # 表情符号 (Emoticons) + + + "\U0001F300-\U0001F5FF" # 杂项符号和象形文字 (Miscellaneous Symbols and Pictographs) + + + "\U0001F680-\U0001F6FF" # 交通和地图符号 (Transport and Map Symbols) + + + "\U0001F1E0-\U0001F1FF" # 区域指示符号 (Regional Indicator Symbols) + + + "\U0001F900-\U0001F9FF" # 补充符号和象形文字 (Supplemental Symbols and Pictographs) + + + "\U0001FA70-\U0001FAFF" # 扩展A符号和象形文字 (Symbols and Pictographs Extended-A) + + + "\U00002600-\U000026FF" # 杂项符号 (Miscellaneous Symbols) + + + "\U00002700-\U000027BF" # 装饰符号 (Dingbats) + + + "\U0000FE00-\U0000FE0F" # 变体选择器 (Variation Selectors) + + + "\U0001F000-\U0001F02F" # 麻将牌 (Mahjong Tiles) + + + "\U0001F0A0-\U0001F0FF" # 扑克牌 (Playing Cards) + + + "]+", + + + flags=re.UNICODE, + + + ) + + + + + + # 保护常用的中文标点符号和特殊字符 + + + protected_chars = ["。", ",", "!", "?", ":", ";", "、", """, """, "'", "'", "(", ")", "【", "】", "《", "》"] + + + + + + # 先保存保护字符的位置 + + + protected_positions = {} + + + for i, char in enumerate(text): + + + if char in protected_chars: + + + protected_positions[i] = char + + + + + + # 执行emoji过滤 + + + filtered_text = emoji_pattern.sub('', text) + + + + + + # 如果过滤后文本长度变化太大,可能误删了正常字符,返回原文本 + + + if len(filtered_text) < len(text) * 0.5: # 如果删除了超过50%的内容 + + + return text + + + + + + return filtered_text + + + + + + def __process_stream_output(self, text, username, session_type="type2_stream", is_qa=False): + + + """ + + + 按流式方式分割和发送 type=2 的文本 + + + 使用安全的流式文本处理器和状态管理器 + + + """ + + + if not text or text.strip() == "": + + + return + + + + + + # 使用安全的流式文本处理器 + + + from utils.stream_text_processor import get_processor + + + from utils.stream_state_manager import get_state_manager + + + + + + processor = get_processor() + + + state_manager = get_state_manager() + + + + + + # 处理流式文本,is_qa=False表示普通模式 + + + success = processor.process_stream_text(text, username, is_qa=is_qa, session_type=session_type) + + + + + + if success: + + + # 普通模式结束会话 + + + state_manager.end_session(username, conversation_id=stream_manager.new_instance().get_conversation_id(username)) + + + else: + + + util.log(1, f"type=2流式处理失败,文本长度: {len(text)}") + + + # 失败时也要确保结束会话 + + + state_manager.force_reset_user_state(username) + + + + + + #语音消息处理检查是否命中q&a + + + def __get_answer(self, interleaver, text): + + + answer = None + + + # 全局问答 + + + answer, type = qa_service.QAService().question('qa',text) + + + if answer is not None: + + + return answer, type + + + else: + + + return None, None + + + + + + + + + #消息处理 + + + def __process_interact(self, interact: Interact): + + + if self.__running: + + + try: + + + index = interact.interact_type + + + username = interact.data.get("user", "User") + + + uid = member_db.new_instance().find_user(username) + no_reply = interact.data.get("no_reply", False) + if isinstance(no_reply, str): + no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on") + else: + no_reply = bool(no_reply) + + + + + + if index == 1: #语音、文字交互 + + + + + + #记录用户问题,方便obs等调用 + + + self.write_to_file("./logs", "asr_result.txt", interact.data["msg"]) + + + + + + #同步用户问题到数字人 + + + if wsa_server.get_instance().is_connected(username): + + + content = {'Topic': 'human', 'Data': {'Key': 'question', 'Value': interact.data["msg"]}, 'Username' : interact.data.get("user")} + + + wsa_server.get_instance().add_cmd(content) + + + + + + #记录用户问题 + + + if not no_reply: + content_id = content_db.new_instance().add_content('member','speak',interact.data["msg"], username, uid) + if wsa_server.get_web_instance().is_connected(username): + wsa_server.get_web_instance().add_cmd({"panelReply": {"type":"member","content":interact.data["msg"], "username":username, "uid":uid, "id":content_id}, "Username" : username}) + + + + + + observation = interact.data.get("observation", None) + obs_text = "" + if observation is not None: + obs_text = observation.strip() if isinstance(observation, str) else str(observation).strip() + if not obs_text and no_reply: + msg_text = interact.data.get("msg", "") + obs_text = msg_text.strip() if isinstance(msg_text, str) else str(msg_text).strip() + if obs_text: + from llm import nlp_cognitive_stream + nlp_cognitive_stream.record_observation(username, obs_text) + if no_reply: + return "" + + #确定是否命中q&a + + + answer, type = self.__get_answer(interact.interleaver, interact.data["msg"]) + + + + + + #大语言模型回复 + + + text = '' + + + if answer is None or type != "qa": + + + if wsa_server.get_web_instance().is_connected(username): + + + wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}) + + + if wsa_server.get_instance().is_connected(username): + + + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'} + + + wsa_server.get_instance().add_cmd(content) + + + + + + # 根据配置动态调用不同的NLP模块 + + + if cfg.config["memory"].get("use_bionic_memory", False): + + + from llm import nlp_bionicmemory_stream + + + text = nlp_bionicmemory_stream.question(interact.data["msg"], username, interact.data.get("observation", None)) + + + else: + + + from llm import nlp_cognitive_stream + + + text = nlp_cognitive_stream.question(interact.data["msg"], username, interact.data.get("observation", None)) + + + + + + else: + + + text = answer + + + # 使用流式分割处理Q&A答案 + + + self.__process_stream_output(text, username, session_type="qa", is_qa=True) + + + + + + + + + return text + + + + + + elif (index == 2):#透传模式:有音频则仅播音频;仅文本则流式+TTS + + + audio_url = interact.data.get("audio") + + + text = interact.data.get("text") + + + + + + # 1) 存在音频:忽略文本,仅播放音频 + + + if audio_url and str(audio_url).strip(): + + + try: + + + audio_interact = Interact( + + + "stream", 1, + + + {"user": username, "msg": "", "isfirst": True, "isend": True, "audio": audio_url} + + + ) + + + self.say(audio_interact, "") + + + except Exception: + + + pass + + + return 'success' + + + + + + # 2) 只有文本:执行流式切分并TTS + + + if text and str(text).strip(): + + + # 进行流式处理(用于TTS,流式处理中会记录到数据库) + + + self.__process_stream_output(text, username, f"type2_{interact.interleaver}", is_qa=False) + + + + + + # 不再需要额外记录,因为流式处理已经记录了 + + + # self.__process_text_output(text, username, uid) + + + + + + return 'success' + + + + + + # 没有有效内容 + + + return 'success' + + + + + + except BaseException as e: + + + print(e) + + + return e + + + else: + + + return "还没有开始运行" + + + + + + #记录问答到log + + + def write_to_file(self, path, filename, content): + + + if not os.path.exists(path): + + + os.makedirs(path) + + + full_path = os.path.join(path, filename) + + + with open(full_path, 'w', encoding='utf-8') as file: + + + file.write(content) + + + file.flush() + + + os.fsync(file.fileno()) + + + + + + #触发交互 + + + def on_interact(self, interact: Interact): + + + #创建用户 + + + username = interact.data.get("user", "User") + + + if member_db.new_instance().is_username_exist(username) == "notexists": + + + member_db.new_instance().add_user(username) + + + no_reply = interact.data.get("no_reply", False) + + if isinstance(no_reply, str): + + no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on") + + else: + + no_reply = bool(no_reply) + + + + if not no_reply: + + try: + + + from utils.stream_state_manager import get_state_manager + + + import uuid + + + if get_state_manager().is_session_active(username): + + + stream_manager.new_instance().clear_Stream_with_audio(username) + + + conv_id = "conv_" + str(uuid.uuid4()) + + + stream_manager.new_instance().set_current_conversation(username, conv_id) + + + # 将当前会话ID附加到交互数据 + + + interact.data["conversation_id"] = conv_id + + + # 允许新的生成 + + + stream_manager.new_instance().set_stop_generation(username, stop=False) + + + except Exception: + + + util.log(3, "开启新会话失败") + + + + + + if interact.interact_type == 1: + + + MyThread(target=self.__process_interact, args=[interact]).start() + + + else: + + + return self.__process_interact(interact) + + + + + + #获取不同情绪声音 + + + def __get_mood_voice(self): + + + voice = tts_voice.get_voice_of(config_util.config["attribute"]["voice"]) + + + if voice is None: + + + voice = EnumVoice.XIAO_XIAO + + + styleList = voice.value["styleList"] + + + sayType = styleList["calm"] + + + return sayType + + + + + + # 合成声音 + + + def say(self, interact, text, type = ""): + + + try: + + + uid = member_db.new_instance().find_user(interact.data.get("user")) + + + is_end = interact.data.get("isend", False) + + + is_first = interact.data.get("isfirst", False) + + + username = interact.data.get("user", "User") + + + + + + # 提前进行会话有效性与中断检查,避免产生多余面板/数字人输出 + + + try: + + + user_for_stop = interact.data.get("user", "User") + + + conv_id_for_stop = interact.data.get("conversation_id") + + + if not is_end and stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): + + + return None + + + except Exception: + + + pass + + + + + + #无效流式文本提前结束 + + + if not is_first and not is_end and (text is None or text.strip() == ""): + + + return None + + + + + + # 检查是否是 prestart 内容(不应该影响 thinking 状态) + + + is_prestart_content = self.__has_prestart(text) + + + + + # 流式文本拼接存库 + + + content_id = 0 + + + # 使用 (username, conversation_id) 作为 key,避免并发会话覆盖 + + + conv = interact.data.get("conversation_id") or "" + + + conv_map_key = (username, conv) + + + + + + if is_first == True: + + + # reset any leftover think-mode at the start of a new reply + + + # 但如果是 prestart 内容,不重置 thinking 状态 + + + try: + + + if uid is not None and not is_prestart_content: + + + self.think_mode_users[uid] = False + + + if uid in self.think_time_users: + + + del self.think_time_users[uid] + if uid in self.think_display_state: + del self.think_display_state[uid] + + + except Exception: + + + pass + + + # 如果没有 conversation_id,生成一个新的 + + + if not conv: + + + conv = "conv_" + str(uuid.uuid4()) + + + conv_map_key = (username, conv) + + + conv_no = 0 + + + # 创建第一条数据库记录,获得content_id + + + if text and text.strip(): + + + content_id = content_db.new_instance().add_content('fay', 'speak', text, username, uid) + + + else: + + + content_id = content_db.new_instance().add_content('fay', 'speak', '', username, uid) + + + + + + # 保存content_id到会话映射中,使用 (username, conversation_id) 作为 key + + + self.user_conv_map[conv_map_key] = { + + + "conversation_id": conv, + + + "conversation_msg_no": conv_no, + + + "content_id": content_id + + + } + + + util.log(1, f"流式会话开始: key={conv_map_key}, content_id={content_id}") + + + else: + + + # 获取之前保存的content_id + + + conv_info = self.user_conv_map.get(conv_map_key, {}) + + + content_id = conv_info.get("content_id", 0) + + + + + + # 如果 conv_map_key 不存在,尝试使用 username 作为备用查找 + + + if not conv_info and text and text.strip(): + + + # 查找所有匹配用户名的会话 + + + for (u, c), info in list(self.user_conv_map.items()): + + + if u == username and info.get("content_id", 0) > 0: + + + content_id = info.get("content_id", 0) + + + conv_info = info + + + util.log(1, f"警告:使用备用会话 ({u}, {c}) 的 content_id={content_id},原 key=({username}, {conv})") + + + break + + + + + + if conv_info: + + + conv_info["conversation_msg_no"] = conv_info.get("conversation_msg_no", 0) + 1 + + + + + + # 如果有新内容,更新数据库 + + + if content_id > 0 and text and text.strip(): + + + # 获取当前已有内容 + + + existing_content = content_db.new_instance().get_content_by_id(content_id) + + + if existing_content: + + + # 累积内容 + + + accumulated_text = existing_content[3] + text + + + content_db.new_instance().update_content(content_id, accumulated_text) + + + elif content_id == 0 and text and text.strip(): + + + # content_id 为 0 表示可能会话 key 不匹配,记录警告 + + + util.log(1, f"警告:content_id=0,无法更新数据库。user={username}, conv={conv}, text片段={text[:50] if len(text) > 50 else text}") + + + + + + # 会话结束时清理 user_conv_map 中的对应条目,避免内存泄漏 + + + if is_end and conv_map_key in self.user_conv_map: + + + del self.user_conv_map[conv_map_key] + + + + + + # 推送给前端和数字人 + + + try: + + + user_for_stop = interact.data.get("user", "User") + + + conv_id_for_stop = interact.data.get("conversation_id") + + + if is_end or not stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): + + + self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end) + + + except Exception: + + + self.__process_text_output(text, interact.data.get('user'), uid, content_id, type, is_first, is_end) + + + + + + # 处理think标签 + + + # 第一步:处理结束标记 + + + if "" in text: + + + # 设置用户退出思考模式 + + + self.think_mode_users[uid] = False + + + + + + # 分割文本,提取后面的内容 + + + # 如果有多个,我们只关心最后一个后面的内容 + + + parts = text.split("") + + + text = parts[-1].strip() + + + + + + # 如果提取出的文本为空,则不需要继续处理 + + + if text == "": + + + return None + + + # 第二步:处理开始标记 + + + # 注意:这里要检查经过上面处理后的text + + + if "" in text: + + + self.think_mode_users[uid] = True + + + self.think_time_users[uid] = time.time() + + + + + + #”思考中“的输出 + + + if self.think_mode_users.get(uid, False): + + + try: + + + user_for_stop = interact.data.get("user", "User") + + + conv_id_for_stop = interact.data.get("conversation_id") + + + should_block = stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop) + + + except Exception: + + + should_block = False + + + if not should_block: + + + if wsa_server.get_web_instance().is_connected(interact.data.get('user')): + + + wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}) + + + if wsa_server.get_instance().is_connected(interact.data.get("user")): + + + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'} + + + wsa_server.get_instance().add_cmd(content) + + + + + + #”请稍等“的音频输出(不影响文本输出) + + + if self.think_mode_users.get(uid, False) == True and time.time() - self.think_time_users[uid] >= 5: + + + self.think_time_users[uid] = time.time() + + + text = "请稍等..." + + + elif self.think_mode_users.get(uid, False) == True and "" not in text: + + + return None + + + + + + result = None + + + audio_url = interact.data.get('audio', None)#透传的音频 + + + + + + # 移除 prestart 标签内容,不进行TTS + + + tts_text = self.__remove_prestart_tags(text) if text else text + + + + + + if audio_url is not None:#透传音频下载 + + + file_name = 'sample-' + str(int(time.time() * 1000)) + audio_url[-4:] + + + result = self.download_wav(audio_url, './samples/', file_name) + + + elif config_util.config["interact"]["playSound"] or wsa_server.get_instance().get_client_output(interact.data.get("user")) or self.__is_send_remote_device_audio(interact):#tts + + + if tts_text != None and tts_text.replace("*", "").strip() != "": + + + # 检查是否需要停止TTS处理(按会话) + + + if stream_manager.new_instance().should_stop_generation( + + + interact.data.get("user", "User"), + + + conversation_id=interact.data.get("conversation_id") + + + ): + + + util.printInfo(1, interact.data.get('user'), 'TTS处理被打断,跳过音频合成') + + + return None + + + + + + # 先过滤表情符号,然后再合成语音 + + + filtered_text = self.__remove_emojis(tts_text.replace("*", "")) + + + if filtered_text is not None and filtered_text.strip() != "": + + + util.printInfo(1, interact.data.get('user'), '合成音频...') + + + tm = time.time() + + + result = self.sp.to_sample(filtered_text, self.__get_mood_voice()) + + + # 合成完成后再次检查会话是否仍有效,避免继续输出旧会话结果 + + + try: + + + user_for_stop = interact.data.get("user", "User") + + + conv_id_for_stop = interact.data.get("conversation_id") + + + if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): + + + return None + + + except Exception: + + + pass + + + util.printInfo(1, interact.data.get("user"), "合成音频完成. 耗时: {} ms 文件:{}".format(math.floor((time.time() - tm) * 1000), result)) + + + else: + + + # prestart 内容不应该触发机器人表情重置 + + + if is_end and not is_prestart_content and wsa_server.get_web_instance().is_connected(interact.data.get('user')): + + + wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) + + + + + + if result is not None or is_first or is_end: + + + # prestart 内容不需要进入音频处理流程 + + + if is_prestart_content: + + + return result + + + if is_end:#TODO 临时方案:如果结束标记,则延迟1秒处理,免得is end比前面的音频tts要快 + + + time.sleep(1) + + + MyThread(target=self.__process_output_audio, args=[result, interact, text]).start() + + + return result + + + + + + except BaseException as e: + + + print(e) + + + return None + + + + + + #下载wav + + + def download_wav(self, url, save_directory, filename): + + + try: + + + # 发送HTTP GET请求以获取WAV文件内容 + + + response = requests.get(url, stream=True) + + + response.raise_for_status() # 检查请求是否成功 + + + + + + # 确保保存目录存在 + + + if not os.path.exists(save_directory): + + + os.makedirs(save_directory) + + + + + + # 构建保存文件的路径 + + + save_path = os.path.join(save_directory, filename) + + + + + + # 将WAV文件内容保存到指定文件 + + + with open(save_path, 'wb') as f: + + + for chunk in response.iter_content(chunk_size=1024): + + + if chunk: + + + f.write(chunk) + + + + + + return save_path + + + except requests.exceptions.RequestException as e: + + + print(f"[Error] Failed to download file: {e}") + + + return None + + + + + + + + + #面板播放声音 + + + def __play_sound(self): + + + try: + + + import pygame + + + pygame.mixer.init() # 初始化pygame.mixer,只需要在此处初始化一次, 如果初始化失败,则不播放音频 + + + except Exception as e: + + + util.printInfo(1, "System", "音频播放初始化失败,本机无法播放音频") + + + return + + + + + + while self.__running: + + + time.sleep(0.01) + + + if not self.sound_query.empty(): # 如果队列不为空则播放音频 + + + file_url, audio_length, interact = self.sound_query.get() + + + + + + is_first = interact.data.get('isfirst') is True + + + is_end = interact.data.get('isend') is True + + + + + + + + + + + + if file_url is not None: + + + util.printInfo(1, interact.data.get('user'), '播放音频...') + + + + + + if is_first: + + + self.speaking = True + + + elif not is_end: + + + self.speaking = True + + + + + + #自动播报关闭 + + + global auto_play_lock + + + global can_auto_play + + + with auto_play_lock: + + + if self.timer is not None: + + + self.timer.cancel() + + + self.timer = None + + + can_auto_play = False + + + + + + if wsa_server.get_web_instance().is_connected(interact.data.get('user')): + + + wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}) + + + + + + if file_url is not None: + + + pygame.mixer.music.load(file_url) + + + pygame.mixer.music.play() + + + + + + # 播放过程中计时,直到音频播放完毕 + + + length = 0 + + + while length < audio_length: + + + try: + + + user_for_stop = interact.data.get("user", "User") + + + conv_id_for_stop = interact.data.get("conversation_id") + + + if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): + + + try: + + + pygame.mixer.music.stop() + + + except Exception: + + + pass + + + break + + + except Exception: + + + pass + + + length += 0.01 + + + time.sleep(0.01) + + + + + + if is_end: + + + self.play_end(interact) + + + + + + if wsa_server.get_web_instance().is_connected(interact.data.get('user')): + + + wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) + + + # 播放完毕后通知 + + + if wsa_server.get_web_instance().is_connected(interact.data.get("user")): + + + wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username': interact.data.get('user')}) + + + + + + #推送远程音频 + + + def __send_remote_device_audio(self, file_url, interact): + + + if file_url is None: + + + return + + + delkey = None + + + for key, value in fay_booter.DeviceInputListenerDict.items(): + + + if value.username == interact.data.get("user") and value.isOutput: #按username选择推送,booter.devicelistenerdice按用户名记录 + + + try: + + + value.deviceConnector.send(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08") # 发送音频开始标志,同时也检查设备是否在线 + + + wavfile = open(os.path.abspath(file_url), "rb") + + + data = wavfile.read(102400) + + + total = 0 + + + while data: + + + total += len(data) + + + value.deviceConnector.send(data) + + + data = wavfile.read(102400) + + + time.sleep(0.0001) + + + value.deviceConnector.send(b'\x08\x07\x06\x05\x04\x03\x02\x01\x00')# 发送音频结束标志 + + + util.printInfo(1, value.username, "远程音频发送完成:{}".format(total)) + + + except socket.error as serr: + + + util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key)) + + + value.stop() + + + delkey = key + + + if delkey: + + + value = fay_booter.DeviceInputListenerDict.pop(delkey) + + + if wsa_server.get_web_instance().is_connected(interact.data.get('user')): + + + wsa_server.get_web_instance().add_cmd({"remote_audio_connect": False, "Username" : interact.data.get('user')}) + + + + + + def __is_send_remote_device_audio(self, interact): + + + for key, value in fay_booter.DeviceInputListenerDict.items(): + + + if value.username == interact.data.get("user") and value.isOutput: + + + return True + + + return False + + + + + + #输出音频处理 + + + def __process_output_audio(self, file_url, interact, text): + + + try: + + + # 会话有效性与中断检查(最早返回,避免向面板/数字人发送任何旧会话输出) + + + try: + + + user_for_stop = interact.data.get("user", "User") + + + conv_id_for_stop = interact.data.get("conversation_id") + + + if stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop): + + + return + + + except Exception: + + + pass + + + try: + + + if file_url is None: + + + audio_length = 0 + + + elif file_url.endswith('.wav'): + + + audio = AudioSegment.from_wav(file_url) + + + audio_length = len(audio) / 1000.0 # 时长以秒为单位 + + + elif file_url.endswith('.mp3'): + + + audio = AudioSegment.from_mp3(file_url) + + + audio_length = len(audio) / 1000.0 # 时长以秒为单位 + + + except Exception as e: + + + audio_length = 3 + + + + + + #推送远程音频 + + + if file_url is not None: + + + MyThread(target=self.__send_remote_device_audio, args=[file_url, interact]).start() + + + + + + #发送音频给数字人接口 + + + if file_url is not None and wsa_server.get_instance().get_client_output(interact.data.get("user")): + + + # 使用 (username, conversation_id) 作为 key 获取会话信息 + + + audio_username = interact.data.get("user", "User") + + + audio_conv_id = interact.data.get("conversation_id") or "" + + + audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {}) + + + content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : audio_conv_info.get("conversation_id", ""), 'CONV_MSG_NO' : audio_conv_info.get("conversation_msg_no", 0) }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'} + + + #计算lips + + + if platform.system() == "Windows": + + + try: + + + lip_sync_generator = LipSyncGenerator() + + + viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url)) + + + consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list) + + + content["Data"]["Lips"] = consolidated_visemes + + + except Exception as e: + + + print(e) + + + util.printInfo(1, interact.data.get("user"), "唇型数据生成失败") + + + wsa_server.get_instance().add_cmd(content) + + + util.printInfo(1, interact.data.get("user"), "数字人接口发送音频数据成功") + + + + + + #面板播放 + + + config_util.load_config() + + + # 检查是否是 prestart 内容 + + + is_prestart = self.__has_prestart(text) + + if config_util.config["interact"]["playSound"]: + + + # prestart 内容不应该进入播放队列,避免触发 Normal 状态 + + + if not is_prestart: + + + self.sound_query.put((file_url, audio_length, interact)) + + + else: + + + # prestart 内容不应该重置机器人表情 + + + if not is_prestart and wsa_server.get_web_instance().is_connected(interact.data.get('user')): + + + wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) + + + + + + except Exception as e: + + + print(e) + + + + + + def play_end(self, interact): + + + self.speaking = False + + + global can_auto_play + + + global auto_play_lock + + + with auto_play_lock: + + + if self.timer: + + + self.timer.cancel() + + + self.timer = None + + + if interact.interleaver != 'auto_play': #交互后暂停自动播报30秒 + + + self.timer = threading.Timer(30, self.set_auto_play) + + + self.timer.start() + + + else: + + + can_auto_play = True + + + + + + #恢复自动播报(如果有) + + + def set_auto_play(self): + + + global auto_play_lock + + + global can_auto_play + + + with auto_play_lock: + + + can_auto_play = True + + + self.timer = None + + + + + + #启动核心服务 + + + def start(self): + + + MyThread(target=self.__play_sound).start() + + + + + + #停止核心服务 + + + def stop(self): + + + self.__running = False + + + self.speaking = False + + + self.sp.close() + + + wsa_server.get_web_instance().add_cmd({"panelMsg": ""}) + + + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': ""}} + + + wsa_server.get_instance().add_cmd(content) + + + + + + def __record_response(self, text, username, uid): + + + """ + + + 记录AI的回复内容 + + + :param text: 回复文本 + + + :param username: 用户名 + + + :param uid: 用户ID + + + :return: content_id + + + """ + + + self.write_to_file("./logs", "answer_result.txt", text) + + + return content_db.new_instance().add_content('fay', 'speak', text, username, uid) + + + + + + def __remove_prestart_tags(self, text): + + + """ + + + 移除文本中的 prestart 标签及其内容 + + + :param text: 原始文本 + + + :return: 移除 prestart 标签后的文本 + + + """ + + + if not text: + + + return text + + + import re + + + # 移除 ... 标签及其内容(支持属性) + + cleaned = re.sub(r']*>[\s\S]*?', '', text, flags=re.IGNORECASE) + + return cleaned.strip() + + + + def __has_prestart(self, text): + + """ + + 判断文本中是否包含 prestart 标签(支持属性) + + """ + + if not text: + + return False + + return re.search(r']*>[\s\S]*?', text, flags=re.IGNORECASE) is not None + + + + + + def __truncate_think_for_panel(self, text, uid, username): + + if not text or not isinstance(text, str): + + return text + + key = uid if uid is not None else username + + state = self.think_display_state.get(key) + + if state is None: + + state = {"in_think": False, "in_tool_output": False, "tool_count": 0, "tool_truncated": False} + + self.think_display_state[key] = state + + if not state["in_think"] and "" not in text and "" not in text: + + return text + + tool_output_regex = re.compile(r"\[TOOL\]\s*(?:Output|\u8f93\u51fa)[:\uff1a]", re.IGNORECASE) + + section_regex = re.compile(r"(?i)(^|[\r\n])(\[(?:TOOL|PLAN)\])") + + out = [] + + i = 0 + + while i < len(text): + + if not state["in_think"]: + + idx = text.find("", i) + + if idx == -1: + + out.append(text[i:]) + + break + + out.append(text[i:idx + len("")]) + + state["in_think"] = True + + i = idx + len("") + + continue + + if not state["in_tool_output"]: + + think_end = text.find("", i) + + tool_match = tool_output_regex.search(text, i) + + next_pos = None + + next_kind = None + + if tool_match: + + next_pos = tool_match.start() + + next_kind = "tool" + + if think_end != -1 and (next_pos is None or think_end < next_pos): + + next_pos = think_end + + next_kind = "think_end" + + if next_pos is None: + + out.append(text[i:]) + + break + + if next_pos > i: + + out.append(text[i:next_pos]) + + if next_kind == "think_end": + + out.append("") + + state["in_think"] = False + + state["in_tool_output"] = False + + state["tool_count"] = 0 + + state["tool_truncated"] = False + + i = next_pos + len("") + + else: + + marker_end = tool_match.end() + + out.append(text[next_pos:marker_end]) + + state["in_tool_output"] = True + + state["tool_count"] = 0 + + state["tool_truncated"] = False + + i = marker_end + + continue + + think_end = text.find("", i) + + section_match = section_regex.search(text, i) + + end_pos = None + + if section_match: + + end_pos = section_match.start(2) + + if think_end != -1 and (end_pos is None or think_end < end_pos): + + end_pos = think_end + + segment = text[i:] if end_pos is None else text[i:end_pos] + + if segment: + + if state["tool_truncated"]: + + pass + + else: + + remaining = self.think_display_limit - state["tool_count"] + + if remaining <= 0: + + out.append("...") + + state["tool_truncated"] = True + + elif len(segment) <= remaining: + + out.append(segment) + + state["tool_count"] += len(segment) + + else: + + out.append(segment[:remaining] + "...") + + state["tool_count"] += remaining + + state["tool_truncated"] = True + + if end_pos is None: + + break + + state["in_tool_output"] = False + + state["tool_count"] = 0 + + state["tool_truncated"] = False + + i = end_pos + + return "".join(out) + + def __send_panel_message(self, text, username, uid, content_id=None, type=None): + + + """ + + + 发送消息到Web面板 + + + :param text: 消息文本 + + + :param username: 用户名 + + + :param uid: 用户ID + + + :param content_id: 内容ID + + + :param type: 消息类型 + + + """ + + + if not wsa_server.get_web_instance().is_connected(username): + + + return + + + + + + # 检查是否是 prestart 内容,prestart 内容不应该更新日志区消息 + + + # 因为这会覆盖掉"思考中..."的状态显示 + + + is_prestart = self.__has_prestart(text) + display_text = self.__truncate_think_for_panel(text, uid, username) + + + + + # gui日志区消息(prestart 内容跳过,保持"思考中..."状态) + + + if not is_prestart: + + + wsa_server.get_web_instance().add_cmd({ + + + "panelMsg": display_text, + + + "Username": username + + + }) + + + + + + # 聊天窗消息 + + + if content_id is not None: + + + wsa_server.get_web_instance().add_cmd({ + + + "panelReply": { + + + "type": "fay", + + + "content": display_text, + + + "username": username, + + + "uid": uid, + + + "id": content_id, + + + "is_adopted": type == 'qa' + + + }, + + + "Username": username + + + }) + + + + + + def __send_digital_human_message(self, text, username, is_first=False, is_end=False): + + + """ + + + 发送消息到数字人(语音应该在say方法驱动数字人输出) + + + :param text: 消息文本 + + + :param username: 用户名 + + + :param is_first: 是否是第一段文本 + + + :param is_end: 是否是最后一段文本 + + + """ + + + # 移除 prestart 标签内容,不发送给数字人 + + + cleaned_text = self.__remove_prestart_tags(text) if text else "" + + + full_text = self.__remove_emojis(cleaned_text.replace("*", "")) if cleaned_text else "" + + + + + + # 如果文本为空且不是结束标记,则不发送,但需保留 is_first + + if not full_text and not is_end: + + if is_first: + + self.pending_isfirst[username] = True + + return + + + + # 检查是否有延迟的 is_first 需要应用 + + if self.pending_isfirst.get(username, False): + + is_first = True + + self.pending_isfirst[username] = False + + + + + if wsa_server.get_instance().is_connected(username): + + + content = { + + + 'Topic': 'human', + + + 'Data': { + + + 'Key': 'text', + + + 'Value': full_text, + + + 'IsFirst': 1 if is_first else 0, + + + 'IsEnd': 1 if is_end else 0 + + + }, + + + 'Username': username + + + } + + + wsa_server.get_instance().add_cmd(content) + + + + + + def __process_text_output(self, text, username, uid, content_id, type, is_first=False, is_end=False): + + + """ + + + 完整文本输出到各个终端 + + + :param text: 主要回复文本 + + + :param textlist: 额外回复列表 + + + :param username: 用户名 + + + :param uid: 用户ID + + + :param type: 消息类型 + + + :param is_first: 是否是第一段文本 + + + :param is_end: 是否是最后一段文本 + + + """ + + + if text: + + + text = text.strip() + + + + + + # 记录主回复 + + + # content_id = self.__record_response(text, username, uid) + + + + + + # 发送主回复到面板和数字人 + + + self.__send_panel_message(text, username, uid, content_id, type) + + + self.__send_digital_human_message(text, username, is_first, is_end) + + + + + + # 打印日志 + + + util.printInfo(1, username, '({}) {}'.format("llm", text)) + + + + + +import importlib + + +fay_booter = importlib.import_module('fay_booter') + + + + - - # gui日志区消息(prestart 内容跳过,保持"思考中..."状态) - if not is_prestart: - wsa_server.get_web_instance().add_cmd({ - "panelMsg": display_text, - "Username": username - }) - - # 聊天窗消息 - if content_id is not None: - wsa_server.get_web_instance().add_cmd({ - "panelReply": { - "type": "fay", - "content": display_text, - "username": username, - "uid": uid, - "id": content_id, - "is_adopted": type == 'qa' - }, - "Username": username - }) - - def __send_digital_human_message(self, text, username, is_first=False, is_end=False): - """ - 发送消息到数字人(语音应该在say方法驱动数字人输出) - :param text: 消息文本 - :param username: 用户名 - :param is_first: 是否是第一段文本 - :param is_end: 是否是最后一段文本 - """ - # 移除 prestart 标签内容,不发送给数字人 - cleaned_text = self.__remove_prestart_tags(text) if text else "" - full_text = self.__remove_emojis(cleaned_text.replace("*", "")) if cleaned_text else "" - - # 如果文本为空且不是结束标记,则不发送,但需保留 is_first - if not full_text and not is_end: - if is_first: - self.pending_isfirst[username] = True - return - - # 检查是否有延迟的 is_first 需要应用 - if self.pending_isfirst.get(username, False): - is_first = True - self.pending_isfirst[username] = False - - if wsa_server.get_instance().is_connected(username): - content = { - 'Topic': 'human', - 'Data': { - 'Key': 'text', - 'Value': full_text, - 'IsFirst': 1 if is_first else 0, - 'IsEnd': 1 if is_end else 0 - }, - 'Username': username - } - wsa_server.get_instance().add_cmd(content) - - def __process_text_output(self, text, username, uid, content_id, type, is_first=False, is_end=False): - """ - 完整文本输出到各个终端 - :param text: 主要回复文本 - :param textlist: 额外回复列表 - :param username: 用户名 - :param uid: 用户ID - :param type: 消息类型 - :param is_first: 是否是第一段文本 - :param is_end: 是否是最后一段文本 - """ - if text: - text = text.strip() - - # 记录主回复 - # content_id = self.__record_response(text, username, uid) - - # 发送主回复到面板和数字人 - self.__send_panel_message(text, username, uid, content_id, type) - self.__send_digital_human_message(text, username, is_first, is_end) - - # 打印日志 - util.printInfo(1, username, '({}) {}'.format("llm", text)) - -import importlib -fay_booter = importlib.import_module('fay_booter') - diff --git a/genagents/modules/memory_stream.py b/genagents/modules/memory_stream.py index 360c0b6..27ca8bd 100644 --- a/genagents/modules/memory_stream.py +++ b/genagents/modules/memory_stream.py @@ -1,368 +1,381 @@ -import math -import sys -import datetime -import random -import string -import re - -from numpy import dot -from numpy.linalg import norm - -from simulation_engine.settings import * -from simulation_engine.global_methods import * -from simulation_engine.gpt_structure import * -from simulation_engine.llm_json_parser import * - - -def run_gpt_generate_importance( - records, - prompt_version="1", - gpt_version="GPT4o", - verbose=False): - - def create_prompt_input(records): - records_str = "" - for count, r in enumerate(records): - records_str += f"Item {str(count+1)}:\n" - records_str += f"{r}\n" - return [records_str] - - def _func_clean_up(gpt_response, prompt=""): - gpt_response = extract_first_json_dict(gpt_response) - # 处理gpt_response为None的情况 - if gpt_response is None: - print("警告: extract_first_json_dict返回None,使用默认值") - return [50] # 返回默认重要性分数 - return list(gpt_response.values()) - - def _get_fail_safe(): - return 25 - - if len(records) > 1: - prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/batch_v1.txt" - else: - prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/singular_v1.txt" - - prompt_input = create_prompt_input(records) - fail_safe = _get_fail_safe() - - output, prompt, prompt_input, fail_safe = chat_safe_generate( - prompt_input, prompt_lib_file, gpt_version, 1, fail_safe, - _func_clean_up, verbose) - - return output, [output, prompt, prompt_input, fail_safe] - - -def generate_importance_score(records): - return run_gpt_generate_importance(records, "1", LLM_VERS)[0] - - -def run_gpt_generate_reflection( - records, - anchor, - reflection_count, - prompt_version="1", - gpt_version="GPT4o", - verbose=False): - - def create_prompt_input(records, anchor, reflection_count): - records_str = "" - for count, r in enumerate(records): - records_str += f"Item {str(count+1)}:\n" - records_str += f"{r}\n" - return [records_str, reflection_count, anchor] - - def _func_clean_up(gpt_response, prompt=""): - return extract_first_json_dict(gpt_response)["reflection"] - - def _get_fail_safe(): - return [] - - if reflection_count > 1: - prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/batch_v1.txt" - else: - prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/singular_v1.txt" - - prompt_input = create_prompt_input(records, anchor, reflection_count) - fail_safe = _get_fail_safe() - - output, prompt, prompt_input, fail_safe = chat_safe_generate( - prompt_input, prompt_lib_file, gpt_version, 1, fail_safe, - _func_clean_up, verbose) - - return output, [output, prompt, prompt_input, fail_safe] - - -def generate_reflection(records, anchor, reflection_count): - records = [i.content for i in records] - return run_gpt_generate_reflection(records, anchor, reflection_count, "1", - LLM_VERS)[0] - - -# ############################################################################## -# ### HELPER FUNCTIONS FOR GENERATIVE AGENTS ### -# ############################################################################## - -def get_random_str(length): - """ - Generates a random string of alphanumeric characters with the specified - length. This function creates a random string by selecting characters from - the set of uppercase letters, lowercase letters, and digits. The length of - the random string is determined by the 'length' parameter. - - Parameters: - length (int): The desired length of the random string. - Returns: - random_string: A randomly generated string of the specified length. - - Example: - >>> get_random_str(8) - 'aB3R7tQ2' - """ - characters = string.ascii_letters + string.digits - random_string = ''.join(random.choice(characters) for _ in range(length)) - return random_string - - -def cos_sim(a, b): - """ - This function calculates the cosine similarity between two input vectors - 'a' and 'b'. Cosine similarity is a measure of similarity between two - non-zero vectors of an inner product space that measures the cosine - of the angle between them. - - Parameters: - a: 1-D array object - b: 1-D array object - Returns: - A scalar value representing the cosine similarity between the input - vectors 'a' and 'b'. - - Example: - >>> a = [0.3, 0.2, 0.5] - >>> b = [0.2, 0.2, 0.5] - >>> cos_sim(a, b) - """ - return dot(a, b)/(norm(a)*norm(b)) - - -def normalize_dict_floats(d, target_min, target_max): - """ - This function normalizes the float values of a given dictionary 'd' between - a target minimum and maximum value. The normalization is done by scaling the - values to the target range while maintaining the same relative proportions - between the original values. - - Parameters: - d: Dictionary. The input dictionary whose float values need to be - normalized. - target_min: Integer or float. The minimum value to which the original - values should be scaled. - target_max: Integer or float. The maximum value to which the original - values should be scaled. - Returns: - d: A new dictionary with the same keys as the input but with the float - values normalized between the target_min and target_max. - - Example: - >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8} - >>> target_min = -5 - >>> target_max = 5 - >>> normalize_dict_floats(d, target_min, target_max) - """ - # 检查字典是否为None或为空 - if d is None: - print("警告: normalize_dict_floats接收到None字典") - return {} - - if not d: - print("警告: normalize_dict_floats接收到空字典") - return {} - - try: - min_val = min(val for val in d.values()) - max_val = max(val for val in d.values()) - range_val = max_val - min_val - - if range_val == 0: - for key, val in d.items(): - d[key] = (target_max - target_min)/2 - else: - for key, val in d.items(): - d[key] = ((val - min_val) * (target_max - target_min) - / range_val + target_min) - return d - except Exception as e: - print(f"normalize_dict_floats处理字典时出错: {str(e)}") - # 返回原始字典,避免处理失败 - return d - - -def top_highest_x_values(d, x): - """ - This function takes a dictionary 'd' and an integer 'x' as input, and - returns a new dictionary containing the top 'x' key-value pairs from the - input dictionary 'd' with the highest values. - - Parameters: - d: Dictionary. The input dictionary from which the top 'x' key-value pairs - with the highest values are to be extracted. - x: Integer. The number of top key-value pairs with the highest values to - be extracted from the input dictionary. - Returns: - A new dictionary containing the top 'x' key-value pairs from the input - dictionary 'd' with the highest values. - - Example: - >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8} - >>> x = 3 - >>> top_highest_x_values(d, x) - """ - top_v = dict(sorted(d.items(), - key=lambda item: item[1], - reverse=True)[:x]) - return top_v - - -def extract_recency(seq_nodes): - """ - Gets the current Persona object and a list of nodes that are in a - chronological order, and outputs a dictionary that has the recency score - calculated. - - Parameters: - nodes: A list of Node object in a chronological order. - Returns: - recency_out: A dictionary whose keys are the node.node_id and whose values - are the float that represents the recency score. - """ - # 检查seq_nodes是否为None或为空 - if seq_nodes is None: - print("警告: extract_recency接收到None节点列表") - return {} - - if not seq_nodes: - print("警告: extract_recency接收到空节点列表") - return {} - - try: - # 确保所有的last_retrieved都是整数类型 - normalized_timestamps = [] - for node in seq_nodes: - if node is None: - print("警告: 节点为None,跳过") - continue - - if not hasattr(node, 'last_retrieved'): - print(f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0") - normalized_timestamps.append(0) - continue - - if isinstance(node.last_retrieved, str): - try: - normalized_timestamps.append(int(node.last_retrieved)) - except ValueError: - # 如果无法转换为整数,使用0作为默认值 - normalized_timestamps.append(0) - else: - normalized_timestamps.append(node.last_retrieved) - - if not normalized_timestamps: - return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} - - max_timestep = max(normalized_timestamps) - - recency_decay = 0.99 - recency_out = dict() - for count, node in enumerate(seq_nodes): - if node is None or not hasattr(node, 'node_id') or not hasattr(node, 'last_retrieved'): - continue - - # 获取标准化后的时间戳 - try: - last_retrieved = normalized_timestamps[count] - recency_out[node.node_id] = (recency_decay - ** (max_timestep - last_retrieved)) - except Exception as e: - print(f"计算节点 {node.node_id} 的recency时出错: {str(e)}") - # 使用默认值 - recency_out[node.node_id] = 1.0 - - return recency_out - except Exception as e: - print(f"extract_recency处理节点列表时出错: {str(e)}") - # 返回一个默认字典 - return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} - - +import math +import sys +import datetime +import random +import string +import re + +from numpy import dot +from numpy.linalg import norm + +from simulation_engine.settings import * +from simulation_engine.global_methods import * +from simulation_engine.gpt_structure import * +from simulation_engine.llm_json_parser import * + + +def run_gpt_generate_importance( + records, + prompt_version="1", + gpt_version="GPT4o", + verbose=False): + + def create_prompt_input(records): + records_str = "" + for count, r in enumerate(records): + records_str += f"Item {str(count+1)}:\n" + records_str += f"{r}\n" + return [records_str] + + def _func_clean_up(gpt_response, prompt=""): + gpt_response = extract_first_json_dict(gpt_response) + # 处理gpt_response为None的情况 + if gpt_response is None: + print("警告: extract_first_json_dict返回None,使用默认值") + return [50] # 返回默认重要性分数 + return list(gpt_response.values()) + + def _get_fail_safe(): + return 25 + + if len(records) > 1: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/batch_v1.txt" + else: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/singular_v1.txt" + + prompt_input = create_prompt_input(records) + fail_safe = _get_fail_safe() + + output, prompt, prompt_input, fail_safe = chat_safe_generate( + prompt_input, prompt_lib_file, gpt_version, 1, fail_safe, + _func_clean_up, verbose) + + return output, [output, prompt, prompt_input, fail_safe] + + +def generate_importance_score(records): + return run_gpt_generate_importance(records, "1", LLM_VERS)[0] + + +def run_gpt_generate_reflection( + records, + anchor, + reflection_count, + prompt_version="1", + gpt_version="GPT4o", + verbose=False): + + def create_prompt_input(records, anchor, reflection_count): + records_str = "" + for count, r in enumerate(records): + records_str += f"Item {str(count+1)}:\n" + records_str += f"{r}\n" + return [records_str, reflection_count, anchor] + + def _func_clean_up(gpt_response, prompt=""): + return extract_first_json_dict(gpt_response)["reflection"] + + def _get_fail_safe(): + return [] + + if reflection_count > 1: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/batch_v1.txt" + else: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/singular_v1.txt" + + prompt_input = create_prompt_input(records, anchor, reflection_count) + fail_safe = _get_fail_safe() + + output, prompt, prompt_input, fail_safe = chat_safe_generate( + prompt_input, prompt_lib_file, gpt_version, 1, fail_safe, + _func_clean_up, verbose) + + return output, [output, prompt, prompt_input, fail_safe] + + +def generate_reflection(records, anchor, reflection_count): + records = [i.content for i in records] + return run_gpt_generate_reflection(records, anchor, reflection_count, "1", + LLM_VERS)[0] + + +# ############################################################################## +# ### HELPER FUNCTIONS FOR GENERATIVE AGENTS ### +# ############################################################################## + +def get_random_str(length): + """ + Generates a random string of alphanumeric characters with the specified + length. This function creates a random string by selecting characters from + the set of uppercase letters, lowercase letters, and digits. The length of + the random string is determined by the 'length' parameter. + + Parameters: + length (int): The desired length of the random string. + Returns: + random_string: A randomly generated string of the specified length. + + Example: + >>> get_random_str(8) + 'aB3R7tQ2' + """ + characters = string.ascii_letters + string.digits + random_string = ''.join(random.choice(characters) for _ in range(length)) + return random_string + + +def cos_sim(a, b): + """ + This function calculates the cosine similarity between two input vectors + 'a' and 'b'. Cosine similarity is a measure of similarity between two + non-zero vectors of an inner product space that measures the cosine + of the angle between them. + + Parameters: + a: 1-D array object + b: 1-D array object + Returns: + A scalar value representing the cosine similarity between the input + vectors 'a' and 'b'. + + Example: + >>> a = [0.3, 0.2, 0.5] + >>> b = [0.2, 0.2, 0.5] + >>> cos_sim(a, b) + """ + return dot(a, b)/(norm(a)*norm(b)) + + +def normalize_dict_floats(d, target_min, target_max): + """ + This function normalizes the float values of a given dictionary 'd' between + a target minimum and maximum value. The normalization is done by scaling the + values to the target range while maintaining the same relative proportions + between the original values. + + Parameters: + d: Dictionary. The input dictionary whose float values need to be + normalized. + target_min: Integer or float. The minimum value to which the original + values should be scaled. + target_max: Integer or float. The maximum value to which the original + values should be scaled. + Returns: + d: A new dictionary with the same keys as the input but with the float + values normalized between the target_min and target_max. + + Example: + >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8} + >>> target_min = -5 + >>> target_max = 5 + >>> normalize_dict_floats(d, target_min, target_max) + """ + # 检查字典是否为None或为空 + if d is None: + print("警告: normalize_dict_floats接收到None字典") + return {} + + if not d: + print("警告: normalize_dict_floats接收到空字典") + return {} + + try: + min_val = min(val for val in d.values()) + max_val = max(val for val in d.values()) + range_val = max_val - min_val + + if range_val == 0: + for key, val in d.items(): + d[key] = (target_max - target_min)/2 + else: + for key, val in d.items(): + d[key] = ((val - min_val) * (target_max - target_min) + / range_val + target_min) + return d + except Exception as e: + print(f"normalize_dict_floats处理字典时出错: {str(e)}") + # 返回原始字典,避免处理失败 + return d + + +def top_highest_x_values(d, x): + """ + This function takes a dictionary 'd' and an integer 'x' as input, and + returns a new dictionary containing the top 'x' key-value pairs from the + input dictionary 'd' with the highest values. + + Parameters: + d: Dictionary. The input dictionary from which the top 'x' key-value pairs + with the highest values are to be extracted. + x: Integer. The number of top key-value pairs with the highest values to + be extracted from the input dictionary. + Returns: + A new dictionary containing the top 'x' key-value pairs from the input + dictionary 'd' with the highest values. + + Example: + >>> d = {'a':1.2,'b':3.4,'c':5.6,'d':7.8} + >>> x = 3 + >>> top_highest_x_values(d, x) + """ + top_v = dict(sorted(d.items(), + key=lambda item: item[1], + reverse=True)[:x]) + return top_v + + +def extract_recency(seq_nodes): + """ + Gets the current Persona object and a list of nodes that are in a + chronological order, and outputs a dictionary that has the recency score + calculated. + + Parameters: + nodes: A list of Node object in a chronological order. + Returns: + recency_out: A dictionary whose keys are the node.node_id and whose values + are the float that represents the recency score. + """ + # 检查seq_nodes是否为None或为空 + if seq_nodes is None: + print("警告: extract_recency接收到None节点列表") + return {} + + if not seq_nodes: + print("警告: extract_recency接收到空节点列表") + return {} + + try: + # 确保所有的last_retrieved都是整数类型 + normalized_timestamps = [] + for node in seq_nodes: + if node is None: + print("警告: 节点为None,跳过") + continue + + if not hasattr(node, 'last_retrieved'): + print(f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0") + normalized_timestamps.append(0) + continue + + if isinstance(node.last_retrieved, str): + try: + normalized_timestamps.append(int(node.last_retrieved)) + except ValueError: + # 如果无法转换为整数,使用0作为默认值 + normalized_timestamps.append(0) + else: + normalized_timestamps.append(node.last_retrieved) + + if not normalized_timestamps: + return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} + + max_timestep = max(normalized_timestamps) + + recency_decay = 0.99 + recency_out = dict() + for count, node in enumerate(seq_nodes): + if node is None or not hasattr(node, 'node_id') or not hasattr(node, 'last_retrieved'): + continue + + # 获取标准化后的时间戳 + try: + last_retrieved = normalized_timestamps[count] + recency_out[node.node_id] = (recency_decay + ** (max_timestep - last_retrieved)) + except Exception as e: + print(f"计算节点 {node.node_id} 的recency时出错: {str(e)}") + # 使用默认值 + recency_out[node.node_id] = 1.0 + + return recency_out + except Exception as e: + print(f"extract_recency处理节点列表时出错: {str(e)}") + # 返回一个默认字典 + return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} + + def extract_importance(seq_nodes): - """ - Gets the current Persona object and a list of nodes that are in a - chronological order, and outputs a dictionary that has the importance score - calculated. - - Parameters: - seq_nodes: A list of Node object in a chronological order. - Returns: - importance_out: A dictionary whose keys are the node.node_id and whose - values are the float that represents the importance score. - """ - # 检查seq_nodes是否为None或为空 - if seq_nodes is None: - print("警告: extract_importance接收到None节点列表") - return {} - - if not seq_nodes: - print("警告: extract_importance接收到空节点列表") - return {} - - try: - importance_out = dict() - for count, node in enumerate(seq_nodes): - if node is None: - print("警告: 节点为None,跳过") - continue - - if not hasattr(node, 'node_id') or not hasattr(node, 'importance'): - print(f"警告: 节点缺少必要属性,跳过") - continue - - # 确保importance是数值类型 - if isinstance(node.importance, str): - try: - importance_out[node.node_id] = float(node.importance) - except ValueError: - # 如果无法转换为数值,使用默认值 - print(f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值") - importance_out[node.node_id] = 50.0 - else: - importance_out[node.node_id] = node.importance - - return importance_out + """ + Gets the current Persona object and a list of nodes that are in a + chronological order, and outputs a dictionary that has the importance score + calculated. + + Parameters: + seq_nodes: A list of Node object in a chronological order. + Returns: + importance_out: A dictionary whose keys are the node.node_id and whose + values are the float that represents the importance score. + """ + # 检查seq_nodes是否为None或为空 + if seq_nodes is None: + print("警告: extract_importance接收到None节点列表") + return {} + + if not seq_nodes: + print("警告: extract_importance接收到空节点列表") + return {} + + try: + importance_out = dict() + for count, node in enumerate(seq_nodes): + if node is None: + print("警告: 节点为None,跳过") + continue + + if not hasattr(node, 'node_id') or not hasattr(node, 'importance'): + print(f"警告: 节点缺少必要属性,跳过") + continue + + # 确保importance是数值类型 + if isinstance(node.importance, str): + try: + importance_out[node.node_id] = float(node.importance) + except ValueError: + # 如果无法转换为数值,使用默认值 + print(f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值") + importance_out[node.node_id] = 50.0 + else: + importance_out[node.node_id] = node.importance + + return importance_out except Exception as e: print(f"extract_importance处理节点列表时出错: {str(e)}") # 返回一个默认字典 return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} -def extract_relevance(seq_nodes, embeddings, focal_pt): - """ - Gets the current Persona object, a list of seq_nodes that are in a - chronological order, and the focal_pt string and outputs a dictionary - that has the relevance score calculated. +def _is_valid_embedding(vec, expected_dim): + if vec is None: + return False + if not isinstance(vec, (list, tuple)): + return False + if expected_dim is not None and len(vec) != expected_dim: + return False + for val in vec: + if not isinstance(val, (int, float)): + return False + return True - Parameters: - seq_nodes: A list of Node object in a chronological order. - focal_pt: A string describing the current thought of revent of focus. - Returns: - relevance_out: A dictionary whose keys are the node.node_id and whose - values are the float that represents the relevance score. - """ - # 确保embeddings不为None - if embeddings is None: - print("警告: embeddings为None,使用空字典代替") - embeddings = {} - + +def extract_relevance(seq_nodes, embeddings, focal_pt): + """ + Gets the current Persona object, a list of seq_nodes that are in a + chronological order, and the focal_pt string and outputs a dictionary + that has the relevance score calculated. + + Parameters: + seq_nodes: A list of Node object in a chronological order. + focal_pt: A string describing the current thought of revent of focus. + Returns: + relevance_out: A dictionary whose keys are the node.node_id and whose + values are the float that represents the relevance score. + """ + # 确保embeddings不为None + if embeddings is None: + print("警告: embeddings为None,使用空字典代替") + embeddings = {} + try: focal_embedding = get_text_embedding(focal_pt) except Exception as e: @@ -370,237 +383,260 @@ def extract_relevance(seq_nodes, embeddings, focal_pt): # 如果无法获取嵌入向量,返回默认值 return {node.node_id: 0.5 for node in seq_nodes} + expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None relevance_out = dict() for count, node in enumerate(seq_nodes): try: # 检查节点内容是否在embeddings中 if node.content in embeddings: node_embedding = embeddings[node.content] + if not _is_valid_embedding(node_embedding, expected_dim): + try: + regenerated = get_text_embedding(node.content) + if _is_valid_embedding(regenerated, expected_dim): + embeddings[node.content] = regenerated + node_embedding = regenerated + else: + print("Warning: regenerated embedding has unexpected dimension, using default relevance") + node_embedding = None + except Exception as e: + print(f"Regenerate embedding failed: {str(e)}") + node_embedding = None # 计算余弦相似度 - relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding) + if node_embedding is None: + relevance_out[node.node_id] = 0.5 + else: + relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding) else: # 如果没有对应的嵌入向量,使用默认值 relevance_out[node.node_id] = 0.5 - except Exception as e: - print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}") - # 如果计算过程中出错,使用默认值 - relevance_out[node.node_id] = 0.5 - - return relevance_out - - -# ############################################################################## -# ### CONCEPT NODE ### -# ############################################################################## - -class ConceptNode: - def __init__(self, node_dict): - # Loading the content of a memory node in the memory stream. + except Exception as e: + print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}") + # 如果计算过程中出错,使用默认值 + relevance_out[node.node_id] = 0.5 + + return relevance_out + + +# ############################################################################## +# ### CONCEPT NODE ### +# ############################################################################## + +class ConceptNode: + def __init__(self, node_dict): + # Loading the content of a memory node in the memory stream. self.node_id = node_dict["node_id"] self.node_type = node_dict["node_type"] self.content = node_dict["content"] self.importance = node_dict["importance"] + self.datetime = node_dict.get("datetime", "") # 确保created是整数类型 - self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0 - # 确保last_retrieved是整数类型 - self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0 - self.pointer_id = node_dict["pointer_id"] - - - def package(self): - """ - Packaging the ConceptNode - - Parameters: - None - Returns: - packaged dictionary - """ - curr_package = {} - curr_package["node_id"] = self.node_id - curr_package["node_type"] = self.node_type + self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0 + # 确保last_retrieved是整数类型 + self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0 + self.pointer_id = node_dict["pointer_id"] + + + def package(self): + """ + Packaging the ConceptNode + + Parameters: + None + Returns: + packaged dictionary + """ + curr_package = {} + curr_package["node_id"] = self.node_id + curr_package["node_type"] = self.node_type curr_package["content"] = self.content curr_package["importance"] = self.importance + curr_package["datetime"] = self.datetime curr_package["created"] = self.created - curr_package["last_retrieved"] = self.last_retrieved - curr_package["pointer_id"] = self.pointer_id - - return curr_package - - -# ############################################################################## -# ### MEMORY STREAM ### -# ############################################################################## - -class MemoryStream: - def __init__(self, nodes, embeddings): - # Loading the memory stream for the agent. - self.seq_nodes = [] - self.id_to_node = dict() - for node in nodes: - new_node = ConceptNode(node) - self.seq_nodes += [new_node] - self.id_to_node[new_node.node_id] = new_node - - self.embeddings = embeddings - - - def count_observations(self): - """ - Counting the number of observations (basically, the number of all nodes in - memory stream except for the reflections) - - Parameters: - None - Returns: - Count - """ - count = 0 - for i in self.seq_nodes: - if i.node_type == "observation": - count += 1 - return count - - - def retrieve(self, focal_points, time_step, n_count=120, curr_filter="all", - hp=[0, 1, 0.5], stateless=False, verbose=False): - """ - Retrieve elements from the memory stream. - - Parameters: - focal_points: This is the query sentence. It is in a list form where - the elemnts of the list are the query sentences. - time_step: Current time_step - n_count: The number of nodes that we want to retrieve. + curr_package["last_retrieved"] = self.last_retrieved + curr_package["pointer_id"] = self.pointer_id + + return curr_package + + +# ############################################################################## +# ### MEMORY STREAM ### +# ############################################################################## + +class MemoryStream: + def __init__(self, nodes, embeddings): + # Loading the memory stream for the agent. + self.seq_nodes = [] + self.id_to_node = dict() + for node in nodes: + new_node = ConceptNode(node) + self.seq_nodes += [new_node] + self.id_to_node[new_node.node_id] = new_node + + self.embeddings = embeddings + + + def count_observations(self): + """ + Counting the number of observations (basically, the number of all nodes in + memory stream except for the reflections) + + Parameters: + None + Returns: + Count + """ + count = 0 + for i in self.seq_nodes: + if i.node_type == "observation": + count += 1 + return count + + + def retrieve(self, focal_points, time_step, n_count=120, curr_filter="all", + hp=[0, 1, 0.5], stateless=False, verbose=False): + """ + Retrieve elements from the memory stream. + + Parameters: + focal_points: This is the query sentence. It is in a list form where + the elemnts of the list are the query sentences. + time_step: Current time_step + n_count: The number of nodes that we want to retrieve. curr_filter: Filtering the node.type that we want to retrieve. - Acceptable values are 'all', 'reflection', 'observation' - hp: Hyperparameter for [recency_w, relevance_w, importance_w] - verbose: verbose - Returns: - retrieved: A dictionary whose keys are a focal_pt query str, and whose - values are a list of nodes that are retrieved for that query str. - """ - curr_nodes = [] - - # If the memory stream is empty, we return an empty dictionary. - if len(self.seq_nodes) == 0: - return dict() - - # Filtering for the desired node type. curr_filter can be one of the three - # elements: 'all', 'reflection', 'observation' - if curr_filter == "all": - curr_nodes = self.seq_nodes - else: - for curr_node in self.seq_nodes: - if curr_node.node_type == curr_filter: - curr_nodes += [curr_node] - - # 确保embeddings不为None - if self.embeddings is None: - print("警告: 在retrieve方法中,embeddings为None,初始化为空字典") - self.embeddings = {} - - # is the main dictionary that we are returning - retrieved = dict() - for focal_pt in focal_points: - # Calculating the component dictionaries and normalizing them. - x = extract_recency(curr_nodes) - recency_out = normalize_dict_floats(x, 0, 1) - x = extract_importance(curr_nodes) - importance_out = normalize_dict_floats(x, 0, 1) - x = extract_relevance(curr_nodes, self.embeddings, focal_pt) - relevance_out = normalize_dict_floats(x, 0, 1) - - # Computing the final scores that combines the component values. - master_out = dict() - for key in recency_out.keys(): - recency_w = hp[0] - relevance_w = hp[1] - importance_w = hp[2] - master_out[key] = (recency_w * recency_out[key] - + relevance_w * relevance_out[key] - + importance_w * importance_out[key]) - - if verbose: - master_out = top_highest_x_values(master_out, len(master_out.keys())) - for key, val in master_out.items(): - print (self.id_to_node[key].content, val) - print (recency_w*recency_out[key]*1, - relevance_w*relevance_out[key]*1, - importance_w*importance_out[key]*1) - - # Extracting the highest x values. - # has the key of node.id and value of float. Once we get - # the highest x values, we want to translate the node.id into nodes - # and return the list of nodes. - master_out = top_highest_x_values(master_out, n_count) - master_nodes = [self.id_to_node[key] for key in list(master_out.keys())] - - # **Sort the master_nodes list by last_retrieved in descending order** - master_nodes = sorted(master_nodes, key=lambda node: node.created, reverse=False) - - # We do not want to update the last retrieved time_step for these nodes - # if we are in a stateless mode. - if not stateless: - for n in master_nodes: - n.last_retrieved = time_step - - retrieved[focal_pt] = master_nodes - - return retrieved - - - def _add_node(self, time_step, node_type, content, importance, pointer_id): - """ - Adding a new node to the memory stream. - - Parameters: - time_step: Current time_step - node_type: type of node -- it's either reflection, observation - content: the str content of the memory record - importance: int score of the importance score - pointer_id: the str of the parent node - Returns: - retrieved: A dictionary whose keys are a focal_pt query str, and whose - values are a list of nodes that are retrieved for that query str. - """ - node_dict = dict() - node_dict["node_id"] = len(self.seq_nodes) + Acceptable values are 'all', 'reflection', 'observation', 'conversation' + hp: Hyperparameter for [recency_w, relevance_w, importance_w] + verbose: verbose + Returns: + retrieved: A dictionary whose keys are a focal_pt query str, and whose + values are a list of nodes that are retrieved for that query str. + """ + curr_nodes = [] + + # If the memory stream is empty, we return an empty dictionary. + if len(self.seq_nodes) == 0: + return dict() + + # Filtering for the desired node type. curr_filter can be one of the + # elements: 'all', 'reflection', 'observation', 'conversation' + if curr_filter == "all": + curr_nodes = self.seq_nodes + else: + for curr_node in self.seq_nodes: + if curr_node.node_type == curr_filter: + curr_nodes += [curr_node] + + # 确保embeddings不为None + if self.embeddings is None: + print("警告: 在retrieve方法中,embeddings为None,初始化为空字典") + self.embeddings = {} + + # is the main dictionary that we are returning + retrieved = dict() + for focal_pt in focal_points: + # Calculating the component dictionaries and normalizing them. + x = extract_recency(curr_nodes) + recency_out = normalize_dict_floats(x, 0, 1) + x = extract_importance(curr_nodes) + importance_out = normalize_dict_floats(x, 0, 1) + x = extract_relevance(curr_nodes, self.embeddings, focal_pt) + relevance_out = normalize_dict_floats(x, 0, 1) + + # Computing the final scores that combines the component values. + master_out = dict() + for key in recency_out.keys(): + recency_w = hp[0] + relevance_w = hp[1] + importance_w = hp[2] + master_out[key] = (recency_w * recency_out[key] + + relevance_w * relevance_out[key] + + importance_w * importance_out[key]) + + if verbose: + master_out = top_highest_x_values(master_out, len(master_out.keys())) + for key, val in master_out.items(): + print (self.id_to_node[key].content, val) + print (recency_w*recency_out[key]*1, + relevance_w*relevance_out[key]*1, + importance_w*importance_out[key]*1) + + # Extracting the highest x values. + # has the key of node.id and value of float. Once we get + # the highest x values, we want to translate the node.id into nodes + # and return the list of nodes. + master_out = top_highest_x_values(master_out, n_count) + master_nodes = [self.id_to_node[key] for key in list(master_out.keys())] + + # **Sort the master_nodes list by last_retrieved in descending order** + master_nodes = sorted(master_nodes, key=lambda node: node.created, reverse=False) + + # We do not want to update the last retrieved time_step for these nodes + # if we are in a stateless mode. + if not stateless: + for n in master_nodes: + n.last_retrieved = time_step + + retrieved[focal_pt] = master_nodes + + return retrieved + + + def _add_node(self, time_step, node_type, content, importance, pointer_id): + """ + Adding a new node to the memory stream. + + Parameters: + time_step: Current time_step + node_type: type of node -- it's either reflection, observation, conversation + content: the str content of the memory record + importance: int score of the importance score + pointer_id: the str of the parent node + Returns: + retrieved: A dictionary whose keys are a focal_pt query str, and whose + values are a list of nodes that are retrieved for that query str. + """ + node_dict = dict() + node_dict["node_id"] = len(self.seq_nodes) node_dict["node_type"] = node_type node_dict["content"] = content node_dict["importance"] = importance + node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S") node_dict["created"] = time_step - node_dict["last_retrieved"] = time_step - node_dict["pointer_id"] = pointer_id - new_node = ConceptNode(node_dict) - - self.seq_nodes += [new_node] - self.id_to_node[new_node.node_id] = new_node - - # 确保embeddings不为None - if self.embeddings is None: - self.embeddings = {} - - try: - self.embeddings[content] = get_text_embedding(content) - except Exception as e: - print(f"获取文本嵌入时出错: {str(e)}") - # 如果获取嵌入失败,使用空列表代替 - self.embeddings[content] = [] - - + node_dict["last_retrieved"] = time_step + node_dict["pointer_id"] = pointer_id + new_node = ConceptNode(node_dict) + + self.seq_nodes += [new_node] + self.id_to_node[new_node.node_id] = new_node + + # 确保embeddings不为None + if self.embeddings is None: + self.embeddings = {} + + try: + self.embeddings[content] = get_text_embedding(content) + except Exception as e: + print(f"获取文本嵌入时出错: {str(e)}") + # 如果获取嵌入失败,使用空列表代替 + self.embeddings[content] = [] + + def remember(self, content, time_step=0): score = generate_importance_score([content])[0] self._add_node(time_step, "observation", content, score, None) + def remember_conversation(self, content, time_step=0): + score = generate_importance_score([content])[0] + self._add_node(time_step, "conversation", content, score, None) - def reflect(self, anchor, reflection_count=5, - retrieval_count=120, time_step=0): - records = self.retrieve([anchor], time_step, retrieval_count)[anchor] - record_ids = [i.node_id for i in records] - reflections = generate_reflection(records, anchor, reflection_count) - scores = generate_importance_score(reflections) - - for count, reflection in enumerate(reflections): - self._add_node(time_step, "reflection", reflections[count], - scores[count], record_ids) + + def reflect(self, anchor, reflection_count=5, + retrieval_count=120, time_step=0): + records = self.retrieve([anchor], time_step, retrieval_count)[anchor] + record_ids = [i.node_id for i in records] + reflections = generate_reflection(records, anchor, reflection_count) + scores = generate_importance_score(reflections) + + for count, reflection in enumerate(reflections): + self._add_node(time_step, "reflection", reflections[count], + scores[count], record_ids) diff --git a/gui/flask_server.py b/gui/flask_server.py index d22e0e9..c64873a 100644 --- a/gui/flask_server.py +++ b/gui/flask_server.py @@ -74,9 +74,9 @@ def __get_template(): except Exception as e: return f"Error rendering template: {e}", 500 -def __get_device_list(): - try: - if config_util.start_mode == 'common': +def __get_device_list(): + try: + if config_util.start_mode == 'common': audio = pyaudio.PyAudio() device_list = [] for i in range(audio.get_device_count()): @@ -86,33 +86,33 @@ def __get_device_list(): return list(set(device_list)) else: return [] - except Exception as e: - print(f"Error getting device list: {e}") - return [] - -def _as_bool(value): - if isinstance(value, bool): - return value - if value is None: - return False - if isinstance(value, (int, float)): - return value != 0 - if isinstance(value, str): - return value.strip().lower() in ("1", "true", "yes", "y", "on") - return False - -def _build_llm_url(base_url: str) -> str: - if not base_url: - return "" - url = base_url.rstrip("/") - if url.endswith("/chat/completions"): - return url - if url.endswith("/v1"): - return url + "/chat/completions" - return url + "/v1/chat/completions" - -@__app.route('/api/submit', methods=['post']) -def api_submit(): + except Exception as e: + print(f"Error getting device list: {e}") + return [] + +def _as_bool(value): + if isinstance(value, bool): + return value + if value is None: + return False + if isinstance(value, (int, float)): + return value != 0 + if isinstance(value, str): + return value.strip().lower() in ("1", "true", "yes", "y", "on") + return False + +def _build_llm_url(base_url: str) -> str: + if not base_url: + return "" + url = base_url.rstrip("/") + if url.endswith("/chat/completions"): + return url + if url.endswith("/v1"): + return url + "/chat/completions" + return url + "/v1/chat/completions" + +@__app.route('/api/submit', methods=['post']) +def api_submit(): data = request.values.get('data') if not data: return jsonify({'result': 'error', 'message': '未提供数据'}) @@ -309,27 +309,27 @@ def api_send(): # 获取指定用户的消息记录(支持分页) @__app.route('/api/get-msg', methods=['post']) -def api_get_Msg(): - try: - data = request.form.get('data') - if data is None: - data = request.get_json(silent=True) or {} - else: - data = json.loads(data) - if not isinstance(data, dict): - data = {} - username = data.get("username") - limit = data.get("limit", 30) # 默认每页30条 - offset = data.get("offset", 0) # 默认从0开始 - contentdb = content_db.new_instance() - uid = 0 - if username: - uid = member_db.new_instance().find_user(username) - if uid == 0: - return json.dumps({'list': [], 'total': 0, 'hasMore': False}) - # 获取总数用于判断是否还有更多 - total = contentdb.get_message_count(uid) - list = contentdb.get_list('all', 'desc', limit, uid, offset) +def api_get_Msg(): + try: + data = request.form.get('data') + if data is None: + data = request.get_json(silent=True) or {} + else: + data = json.loads(data) + if not isinstance(data, dict): + data = {} + username = data.get("username") + limit = data.get("limit", 30) # 默认每页30条 + offset = data.get("offset", 0) # 默认从0开始 + contentdb = content_db.new_instance() + uid = 0 + if username: + uid = member_db.new_instance().find_user(username) + if uid == 0: + return json.dumps({'list': [], 'total': 0, 'hasMore': False}) + # 获取总数用于判断是否还有更多 + total = contentdb.get_message_count(uid) + list = contentdb.get_list('all', 'desc', limit, uid, offset) relist = [] i = len(list) - 1 while i >= 0: @@ -354,134 +354,143 @@ def api_send_v1_chat_completions(): data = request.get_json() if not data: return jsonify({'error': '未提供数据'}) - try: - model = data.get('model', 'fay') - if model == 'llm': - try: - config_util.load_config() - llm_url = _build_llm_url(config_util.gpt_base_url) - api_key = config_util.key_gpt_api_key - model_engine = config_util.gpt_model_engine - except Exception as exc: - return jsonify({'error': f'LLM config load failed: {exc}'}), 500 - - if not llm_url: - return jsonify({'error': 'LLM base_url is not configured'}), 500 - - payload = dict(data) - if payload.get('model') == 'llm' and model_engine: - payload['model'] = model_engine - - stream_requested = _as_bool(payload.get('stream', False)) - headers = {'Content-Type': 'application/json'} - if api_key: - headers['Authorization'] = f'Bearer {api_key}' - - try: - if stream_requested: - resp = requests.post(llm_url, headers=headers, json=payload, stream=True) - - def generate(): - try: - for chunk in resp.iter_content(chunk_size=8192): - if not chunk: - continue - yield chunk - finally: - resp.close() - - content_type = resp.headers.get("Content-Type", "text/event-stream") - if "charset=" not in content_type.lower(): - content_type = f"{content_type}; charset=utf-8" - return Response( - generate(), - status=resp.status_code, - content_type=content_type, - ) - - resp = requests.post(llm_url, headers=headers, json=payload, timeout=60) - content_type = resp.headers.get("Content-Type", "application/json") - if "charset=" not in content_type.lower(): - content_type = f"{content_type}; charset=utf-8" - return Response( - resp.content, - status=resp.status_code, - content_type=content_type, - ) - except Exception as exc: - return jsonify({'error': f'LLM request failed: {exc}'}), 500 - - last_content = "" - if 'messages' in data and data['messages']: - last_message = data['messages'][-1] - username = last_message.get('role', 'User') - if username == 'user': - username = 'User' - last_content = last_message.get('content', 'No content provided') - else: - last_content = 'No messages found' - username = 'User' + try: + model = data.get('model', 'fay') + if model == 'llm': + try: + config_util.load_config() + llm_url = _build_llm_url(config_util.gpt_base_url) + api_key = config_util.key_gpt_api_key + model_engine = config_util.gpt_model_engine + except Exception as exc: + return jsonify({'error': f'LLM config load failed: {exc}'}), 500 - observation = data.get('observation', '') + if not llm_url: + return jsonify({'error': 'LLM base_url is not configured'}), 500 + + payload = dict(data) + if payload.get('model') == 'llm' and model_engine: + payload['model'] = model_engine + + stream_requested = _as_bool(payload.get('stream', False)) + headers = {'Content-Type': 'application/json'} + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + + try: + if stream_requested: + resp = requests.post(llm_url, headers=headers, json=payload, stream=True) + + def generate(): + try: + for chunk in resp.iter_content(chunk_size=8192): + if not chunk: + continue + yield chunk + finally: + resp.close() + + content_type = resp.headers.get("Content-Type", "text/event-stream") + if "charset=" not in content_type.lower(): + content_type = f"{content_type}; charset=utf-8" + return Response( + generate(), + status=resp.status_code, + content_type=content_type, + ) + + resp = requests.post(llm_url, headers=headers, json=payload, timeout=60) + content_type = resp.headers.get("Content-Type", "application/json") + if "charset=" not in content_type.lower(): + content_type = f"{content_type}; charset=utf-8" + return Response( + resp.content, + status=resp.status_code, + content_type=content_type, + ) + except Exception as exc: + return jsonify({'error': f'LLM request failed: {exc}'}), 500 + + last_content = "" + username = "User" + messages = data.get("messages") + if isinstance(messages, list) and messages: + last_message = messages[-1] or {} + username = last_message.get("role", "User") or "User" + if username == "user": + username = "User" + last_content = last_message.get("content") or "" + elif isinstance(messages, str): + last_content = messages + + observation = data.get('observation', '') # 检查请求中是否指定了流式传输 - stream_requested = data.get('stream', False) + stream_requested = data.get('stream', False) no_reply = _as_bool(data.get('no_reply', data.get('noReply', False))) - if no_reply: - interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream': bool(stream_requested), 'no_reply': True}) - util.printInfo(1, username, '[text chat no_reply]{}'.format(interact.data["msg"]), time.time()) - fay_booter.feiFei.on_interact(interact) - if stream_requested or model == 'fay-streaming': - def generate(): - message = { - "id": "faystreaming-" + str(uuid.uuid4()), - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "delta": { - "content": "" - }, - "index": 0, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": len(last_content), - "completion_tokens": 0, - "total_tokens": len(last_content) - }, - "system_fingerprint": "", - "no_reply": True - } - yield f"data: {json.dumps(message)}\n\n" - yield 'data: [DONE]\n\n' - return Response(generate(), mimetype='text/event-stream') - return jsonify({ - "id": "fay-" + str(uuid.uuid4()), - "object": "chat.completion", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "" - }, - "logprobs": "", - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": len(last_content), - "completion_tokens": 0, - "total_tokens": len(last_content) - }, - "system_fingerprint": "", - "no_reply": True - }) - if stream_requested or model == 'fay-streaming': + obs_text = "" + if observation is not None: + obs_text = observation.strip() if isinstance(observation, str) else str(observation).strip() + message_text = last_content.strip() if isinstance(last_content, str) else str(last_content).strip() + if not message_text and not obs_text: + return jsonify({'error': 'messages and observation are both empty'}), 400 + if not message_text and obs_text: + no_reply = True + if no_reply: + interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream': bool(stream_requested), 'no_reply': True}) + util.printInfo(1, username, '[text chat no_reply]{}'.format(interact.data["msg"]), time.time()) + fay_booter.feiFei.on_interact(interact) + if stream_requested or model == 'fay-streaming': + def generate(): + message = { + "id": "faystreaming-" + str(uuid.uuid4()), + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "delta": { + "content": "" + }, + "index": 0, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": len(last_content), + "completion_tokens": 0, + "total_tokens": len(last_content) + }, + "system_fingerprint": "", + "no_reply": True + } + yield f"data: {json.dumps(message)}\n\n" + yield 'data: [DONE]\n\n' + return Response(generate(), mimetype='text/event-stream') + return jsonify({ + "id": "fay-" + str(uuid.uuid4()), + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "" + }, + "logprobs": "", + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": len(last_content), + "completion_tokens": 0, + "total_tokens": len(last_content) + }, + "system_fingerprint": "", + "no_reply": True + }) + if stream_requested or model == 'fay-streaming': interact = Interact("text", 1, {'user': username, 'msg': last_content, 'observation': str(observation), 'stream':True}) util.printInfo(1, username, '[文字沟通接口(流式)]{}'.format(interact.data["msg"]), time.time()) fay_booter.feiFei.on_interact(interact) @@ -588,7 +597,6 @@ def api_delete_user(): # 清除缓存的 agent 对象 try: - from llm import nlp_cognitive_stream if hasattr(nlp_cognitive_stream, 'agents') and username in nlp_cognitive_stream.agents: del nlp_cognitive_stream.agents[username] except Exception: diff --git a/llm/nlp_cognitive_stream.py b/llm/nlp_cognitive_stream.py index 776c0cc..c41aade 100644 --- a/llm/nlp_cognitive_stream.py +++ b/llm/nlp_cognitive_stream.py @@ -65,6 +65,11 @@ memory_cleared = False # 添加记忆清除标记 # 新增: 当前会话用户名及按用户获取memory目录的辅助函数 current_username = None # 当前会话用户名 +def _log_prompt(messages: List[SystemMessage | HumanMessage], tag: str = ""): + """No-op placeholder for prompt logging (disabled).""" + return + + llm = ChatOpenAI( model=cfg.gpt_model_engine, base_url=cfg.gpt_base_url, @@ -314,14 +319,26 @@ def _remove_prestart_from_text(text: str, keep_marked: bool = True) -> str: return text.strip() -def _remove_think_from_text(text: str) -> str: - """从文本中移除 think 标签及其内容""" - if not text: - return text - import re - cleaned = re.sub(r'[\s\S]*?', '', text, flags=re.IGNORECASE) - cleaned = re.sub(r'', '', cleaned, flags=re.IGNORECASE) - return cleaned.strip() +def _remove_think_from_text(text: str) -> str: + """从文本中移除 think 标签及其内容""" + if not text: + return text + import re + cleaned = re.sub(r'[\s\S]*?', '', text, flags=re.IGNORECASE) + cleaned = re.sub(r'', '', cleaned, flags=re.IGNORECASE) + return cleaned.strip() + + +def _strip_json_code_fence(text: str) -> str: + """Strip ```json ... ``` wrappers if present.""" + if not text: + return text + import re + trimmed = text.strip() + match = re.match(r"^```(?:json)?\\s*(.*?)\\s*```$", trimmed, flags=re.IGNORECASE | re.DOTALL) + if match: + return match.group(1).strip() + return text def _format_conversation_block(conversation: List[Dict], username: str = "User") -> str: @@ -440,35 +457,35 @@ def _run_prestart_tools(user_question: str) -> List[Dict[str, Any]]: return results -def _truncate_history( - history: List[ToolResult], - limit: Optional[int] = None, - output_limit: Optional[int] = None, -) -> str: - if not history: - return "(暂无)" - lines: List[str] = [] - selected = history if limit is None else history[-limit:] - for item in selected: - call = item.get("call", {}) - name = call.get("name", "未知工具") - attempt = item.get("attempt", 0) - success = item.get("success", False) - status = "成功" if success else "失败" - lines.append(f"- {name} 第 {attempt} 次 → {status}") - output = item.get("output") - if output is not None: - output_text = str(output) - if output_limit is not None: - output_text = _truncate_text(output_text, output_limit) - lines.append(" 输出:" + output_text) - error = item.get("error") - if error is not None: - error_text = str(error) - if output_limit is not None: - error_text = _truncate_text(error_text, output_limit) - lines.append(" 错误:" + error_text) - return "\n".join(lines) +def _truncate_history( + history: List[ToolResult], + limit: Optional[int] = None, + output_limit: Optional[int] = None, +) -> str: + if not history: + return "(暂无)" + lines: List[str] = [] + selected = history if limit is None else history[-limit:] + for item in selected: + call = item.get("call", {}) + name = call.get("name", "未知工具") + attempt = item.get("attempt", 0) + success = item.get("success", False) + status = "成功" if success else "失败" + lines.append(f"- {name} 第 {attempt} 次 → {status}") + output = item.get("output") + if output is not None: + output_text = str(output) + if output_limit is not None: + output_text = _truncate_text(output_text, output_limit) + lines.append(" 输出:" + output_text) + error = item.get("error") + if error is not None: + error_text = str(error) + if output_limit is not None: + error_text = _truncate_text(error_text, output_limit) + lines.append(" 错误:" + error_text) + return "\n".join(lines) def _format_schema_parameters(schema: Dict[str, Any]) -> List[str]: @@ -598,9 +615,9 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess # 生成对话文本,使用代码块包裹每条消息 convo_text = _format_conversation_block(conversation, username) - history_text = _truncate_history(history) - tools_text = _format_tools_for_prompt(tool_specs) - preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else "" + history_text = _truncate_history(history) + tools_text = _format_tools_for_prompt(tool_specs) + preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else "" # 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内 if prestart_context and prestart_context.strip(): @@ -638,7 +655,6 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess {observation or '(无补充)'} --- -**关联记忆** {memory_context or '(无相关记忆)'} --- {prestart_section} @@ -682,9 +698,9 @@ def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessag display_username = "主人" if username == "User" else username # 生成对话文本,使用代码块包裹每条消息 - conversation_block = _format_conversation_block(conversation, username) - history_text = _truncate_history(state.get("tool_results", [])) - preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else "" + conversation_block = _format_conversation_block(conversation, username) + history_text = _truncate_history(state.get("tool_results", [])) + preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else "" # 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内 if prestart_context and prestart_context.strip(): @@ -755,6 +771,7 @@ def _call_planner_llm( 解析后的决策字典 """ messages = _build_planner_messages(state) + _log_prompt(messages, tag="planner") # 如果有流式回调,使用流式模式检测 finish+message if stream_callback is not None: @@ -826,6 +843,7 @@ def _call_planner_llm( # 处理完整响应 trimmed = _remove_think_from_text(accumulated.strip()) + trimmed = _strip_json_code_fence(trimmed) if in_message_mode: # 提取完整 message 内容,去掉结尾的 "} @@ -869,6 +887,7 @@ def _call_planner_llm( raise RuntimeError("规划器返回内容异常,未获得字符串。") # 先移除 think 标签,兼容带思考标签的模型(如 DeepSeek R1) trimmed = _remove_think_from_text(content.strip()) + trimmed = _strip_json_code_fence(trimmed) try: decision = json.loads(trimmed) except json.JSONDecodeError as exc: @@ -929,6 +948,7 @@ def _plan_next_action(state: AgentState) -> AgentState: "规划器:检测到工具重复调用,使用最新结果产出最终回复。" ) final_messages = _build_final_messages(state) + _log_prompt(final_messages, tag="final") preview = last_entry.get("output") return { "status": "completed", @@ -1669,27 +1689,57 @@ def load_agent_memory(agent, username=None): # 记忆对话内容的线程函数 def remember_conversation_thread(username, content, response_text): - """ - 在单独线程中记录对话内容到代理记忆 - - 参数: - username: 用户名 - content: 用户问题内容 - response_text: 代理回答内容 - """ - global agents + """Background task to store a conversation memory node.""" try: + ag = create_agent(username) + if ag is None: + return + questioner = username + if isinstance(questioner, str) and questioner.lower() == "user": + questioner = "主人" + answerer = ag.scratch.get("first_name", "Fay") + question_text = content if content is not None else "" + answer_text = response_text if response_text is not None else "" + memory_content = f"{questioner}:{question_text}\n{answerer}:{answer_text}" with agent_lock: - ag = agents.get(username) - if ag is None: - return time_step = get_current_time_step(username) - name = "主人" if username == "User" else username - # 记录对话内容 - memory_content = f"在对话中,我回答了{name}的问题:{content}\n,我的回答是:{response_text}" - ag.remember(memory_content, time_step) + if ag.memory_stream and hasattr(ag.memory_stream, "remember_conversation"): + ag.memory_stream.remember_conversation(memory_content, time_step) + else: + ag.remember(memory_content, time_step) except Exception as e: - util.log(1, f"记忆对话内容出错: {str(e)}") + util.log(1, f"记录对话记忆失败: {str(e)}") + +def remember_observation_thread(username, observation_text): + """Background task to store an observation memory node.""" + try: + ag = create_agent(username) + if ag is None: + return + text = observation_text if observation_text is not None else "" + memory_content = text + with agent_lock: + time_step = get_current_time_step(username) + if ag.memory_stream and hasattr(ag.memory_stream, "remember"): + ag.memory_stream.remember(memory_content, time_step) + else: + ag.remember(memory_content, time_step) + except Exception as e: + util.log(1, f"记录观察记忆失败: {str(e)}") + +def record_observation(username, observation_text): + """Persist an observation memory node asynchronously.""" + if observation_text is None: + return False, "observation text is required" + text = observation_text.strip() if isinstance(observation_text, str) else str(observation_text).strip() + if not text: + return False, "observation text is required" + try: + MyThread(target=remember_observation_thread, args=(username, text)).start() + return True, "observation recorded" + except Exception as exc: + util.log(1, f"记录观察记忆失败: {exc}") + return False, f"observation record failed: {exc}" def question(content, username, observation=None): """处理用户提问并返回回复。""" @@ -1725,24 +1775,49 @@ def question(content, username, observation=None): ), } + memory_sections = [ + ("观察记忆", "observation"), + ("对话记忆", "conversation"), + ("反思记忆", "reflection"), + ] memory_context = "" - if agent.memory_stream and len(agent.memory_stream.seq_nodes) > 0: + if agent.memory_stream and len(agent.memory_stream.seq_nodes) > 0 and content: current_time_step = get_current_time_step(username) + query = content.strip() if isinstance(content, str) else str(content) + max_per_type = 10 + section_texts = [] try: - query = f"{'主人' if username == 'User' else username}提出了问题:{content}" - related_memories = agent.memory_stream.retrieve( + combined = agent.memory_stream.retrieve( [query], current_time_step, - n_count=30, + n_count=max_per_type * len(memory_sections), curr_filter="all", hp=[0.8, 0.5, 0.5], stateless=False, ) - if related_memories and query in related_memories: - memory_nodes = related_memories[query] - memory_context = "\n".join(f"- {node.content}" for node in memory_nodes) + all_nodes = combined.get(query, []) if combined else [] except Exception as exc: - util.log(1, f"获取相关记忆时出错: {exc}") + util.log(1, f"获取关联记忆时出错: {exc}") + all_nodes = [] + + for label, node_type in memory_sections: + section_lines = "(无)" + try: + memory_nodes = [n for n in all_nodes if getattr(n, "node_type", "") == node_type][:max_per_type] + if memory_nodes: + formatted = [] + for node in memory_nodes: + ts = (getattr(node, "datetime", "") or "").strip() + prefix = f"[{ts}] " if ts else "" + formatted.append(f"- {prefix}{node.content}") + section_lines = "\n".join(formatted) + except Exception as exc: + util.log(1, f"获取{label}时出错: {exc}") + section_lines = "(获取失败)" + section_texts.append(f"**{label}**\n{section_lines}") + memory_context = "\n".join(section_texts) + else: + memory_context = "\n".join(f"**{label}**\n(无)" for label, _ in memory_sections) prestart_context = "" prestart_stream_text = "" @@ -1866,22 +1941,22 @@ def question(content, username, observation=None): or messages_buffer[-1]['content'] != content ): messages_buffer.append({"role": "user", "content": content}) - else: - # 不隔离:按独立消息存储,保留用户名信息 - def append_to_buffer_multi(role: str, text_value: str, msg_username: str = "") -> None: - if not text_value: - return - messages_buffer.append({"role": role, "content": text_value, "username": msg_username}) - if len(messages_buffer) > 60: - del messages_buffer[:-60] - - def append_to_buffer(role: str, text_value: str) -> None: - append_to_buffer_multi(role, text_value, "") - - for record in history_records: - msg_type, msg_text, msg_username = record - if not msg_text: - continue + else: + # 不隔离:按独立消息存储,保留用户名信息 + def append_to_buffer_multi(role: str, text_value: str, msg_username: str = "") -> None: + if not text_value: + return + messages_buffer.append({"role": role, "content": text_value, "username": msg_username}) + if len(messages_buffer) > 60: + del messages_buffer[:-60] + + def append_to_buffer(role: str, text_value: str) -> None: + append_to_buffer_multi(role, text_value, "") + + for record in history_records: + msg_type, msg_text, msg_username = record + if not msg_text: + continue if msg_type and msg_type.lower() in ('member', 'user'): append_to_buffer_multi("user", msg_text, msg_username) else: @@ -2020,22 +2095,22 @@ def question(content, username, observation=None): # 创建规划器流式回调,用于实时输出 finish+message 响应 planner_stream_buffer = {"text": "", "first_chunk": True} - def planner_stream_callback(chunk_text: str) -> None: - """规划器流式回调,将 message 内容实时输出""" - nonlocal accumulated_text, full_response_text, is_first_sentence, is_agent_think_start - if not chunk_text: - return - planner_stream_buffer["text"] += chunk_text - if planner_stream_buffer["first_chunk"]: - planner_stream_buffer["first_chunk"] = False - if is_agent_think_start: - closing = "" - accumulated_text += closing - full_response_text += closing - is_agent_think_start = False - # 使用 stream_response_chunks 的逻辑进行分句流式输出 - accumulated_text += chunk_text - full_response_text += chunk_text + def planner_stream_callback(chunk_text: str) -> None: + """规划器流式回调,将 message 内容实时输出""" + nonlocal accumulated_text, full_response_text, is_first_sentence, is_agent_think_start + if not chunk_text: + return + planner_stream_buffer["text"] += chunk_text + if planner_stream_buffer["first_chunk"]: + planner_stream_buffer["first_chunk"] = False + if is_agent_think_start: + closing = "" + accumulated_text += closing + full_response_text += closing + is_agent_think_start = False + # 使用 stream_response_chunks 的逻辑进行分句流式输出 + accumulated_text += chunk_text + full_response_text += chunk_text # 检查是否有完整句子可以输出 if len(accumulated_text) >= 20: while True: @@ -2507,46 +2582,46 @@ def save_agent_memory(): agent.scratch = {} # 保存记忆前进行完整性检查 - try: - # 检查seq_nodes中的每个节点是否有效 - valid_nodes = [] - for node in agent.memory_stream.seq_nodes: - if node is None: - util.log(1, "发现无效节点(None),跳过") - continue - - if not hasattr(node, 'node_id') or not hasattr(node, 'content'): - util.log(1, f"发现无效节点(缺少必要属性),跳过") - continue - raw_content = node.content if isinstance(node.content, str) else str(node.content) - cleaned_content = _remove_think_from_text(raw_content) - if cleaned_content != raw_content: - old_content = raw_content - node.content = cleaned_content - if ( - agent.memory_stream.embeddings is not None - and old_content in agent.memory_stream.embeddings - and cleaned_content not in agent.memory_stream.embeddings - ): - agent.memory_stream.embeddings[cleaned_content] = agent.memory_stream.embeddings[old_content] - else: - node.content = raw_content - valid_nodes.append(node) - - # 更新seq_nodes为有效节点列表 - agent.memory_stream.seq_nodes = valid_nodes - - # 重建id_to_node字典 - agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')} - if agent.memory_stream.embeddings is not None: - kept_contents = {node.content for node in valid_nodes if hasattr(node, 'content')} - agent.memory_stream.embeddings = { - key: value - for key, value in agent.memory_stream.embeddings.items() - if key in kept_contents - } - except Exception as e: - util.log(1, f"检查记忆完整性时出错: {str(e)}") + try: + # 检查seq_nodes中的每个节点是否有效 + valid_nodes = [] + for node in agent.memory_stream.seq_nodes: + if node is None: + util.log(1, "发现无效节点(None),跳过") + continue + + if not hasattr(node, 'node_id') or not hasattr(node, 'content'): + util.log(1, f"发现无效节点(缺少必要属性),跳过") + continue + raw_content = node.content if isinstance(node.content, str) else str(node.content) + cleaned_content = _remove_think_from_text(raw_content) + if cleaned_content != raw_content: + old_content = raw_content + node.content = cleaned_content + if ( + agent.memory_stream.embeddings is not None + and old_content in agent.memory_stream.embeddings + and cleaned_content not in agent.memory_stream.embeddings + ): + agent.memory_stream.embeddings[cleaned_content] = agent.memory_stream.embeddings[old_content] + else: + node.content = raw_content + valid_nodes.append(node) + + # 更新seq_nodes为有效节点列表 + agent.memory_stream.seq_nodes = valid_nodes + + # 重建id_to_node字典 + agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')} + if agent.memory_stream.embeddings is not None: + kept_contents = {node.content for node in valid_nodes if hasattr(node, 'content')} + agent.memory_stream.embeddings = { + key: value + for key, value in agent.memory_stream.embeddings.items() + if key in kept_contents + } + except Exception as e: + util.log(1, f"检查记忆完整性时出错: {str(e)}") # 保存记忆 try: diff --git a/test/test_fay_gpt_stream.py b/test/test_fay_gpt_stream.py index 2be854d..6574361 100644 --- a/test/test_fay_gpt_stream.py +++ b/test/test_fay_gpt_stream.py @@ -91,7 +91,7 @@ if __name__ == "__main__": print("=" * 60) print("示例1:张三的对话(带观察数据)") print("=" * 60) - test_gpt("你好,今天天气不错啊", username="user", observation=OBSERVATION_SAMPLES["张三"]) + test_gpt(prompt="", username="user", observation=OBSERVATION_SAMPLES["张三"], no_reply=False) print("\n")