diff --git a/fay_booter.py b/fay_booter.py index f37e569..0ea0543 100644 --- a/fay_booter.py +++ b/fay_booter.py @@ -1,393 +1,385 @@ -#核心启动模块 -import time -import re -import pyaudio -import socket -import requests -from core.interact import Interact -from core.recorder import Recorder -from scheduler.thread_manager import MyThread -from utils import util, config_util, stream_util -from core.wsa_server import MyServer -from core import wsa_server -from core import socket_bridge_service -# from llm.nlp_cognitive_stream import save_agent_memory - -# 全局变量声明 -feiFei = None -recorderListener = None -__running = False -deviceSocketServer = None -DeviceInputListenerDict = {} -ngrok = None -socket_service_instance = None - -# 延迟导入fay_core -def get_fay_core(): - from core import fay_core - return fay_core - -#启动状态 -def is_running(): - return __running - -#录制麦克风音频输入并传给aliyun -class RecorderListener(Recorder): - - def __init__(self, device, fei): - self.__device = device - self.__FORMAT = pyaudio.paInt16 - self.__running = False - self.username = 'User' - # 这两个参数会在 get_stream 中根据实际设备更新 - self.channels = None - self.sample_rate = None - super().__init__(fei) - - def on_speaking(self, text): - if len(text) > 1: - interact = Interact("mic", 1, {'user': 'User', 'msg': text}) - util.printInfo(3, "语音", '{}'.format(interact.data["msg"]), time.time()) - feiFei.on_interact(interact) - - def get_stream(self): - try: - while True: - config_util.load_config() - record = config_util.config['source']['record'] - if record['enabled']: - break - time.sleep(0.1) - - self.paudio = pyaudio.PyAudio() - - # 获取默认输入设备的信息 - default_device = self.paudio.get_default_input_device_info() - self.channels = min(int(default_device.get('maxInputChannels', 1)), 2) # 最多使用2个通道 - # self.sample_rate = int(default_device.get('defaultSampleRate', 16000)) - - util.printInfo(1, "系统", f"默认麦克风信息 - 采样率: {self.sample_rate}Hz, 通道数: {self.channels}") - - # 使用系统默认麦克风 - self.stream = self.paudio.open( - format=self.__FORMAT, 
- channels=self.channels, - rate=self.sample_rate, - input=True, - frames_per_buffer=1024 - ) - - self.__running = True - MyThread(target=self.__pyaudio_clear).start() - - except Exception as e: - util.log(1, f"打开麦克风时出错: {str(e)}") - util.printInfo(1, self.username, "请检查录音设备是否有误,再重新启动!") - time.sleep(10) - return self.stream - - def __pyaudio_clear(self): - try: - while self.__running: - time.sleep(30) - except Exception as e: - util.log(1, f"音频清理线程出错: {str(e)}") - finally: - if hasattr(self, 'stream') and self.stream: - try: - self.stream.stop_stream() - self.stream.close() - except Exception as e: - util.log(1, f"关闭音频流时出错: {str(e)}") - - def stop(self): - super().stop() - self.__running = False - time.sleep(0.1)#给清理线程一点处理时间 - try: - while self.is_reading:#是为了确保停止的时候麦克风没有刚好在读取音频的 - time.sleep(0.1) - if self.stream is not None: - self.stream.stop_stream() - self.stream.close() - self.paudio.terminate() - except Exception as e: - print(e) - util.log(1, "请检查设备是否有误,再重新启动!") - - def is_remote(self): - return False - - - - -#Edit by xszyou on 20230113:录制远程设备音频输入并传给aliyun -class DeviceInputListener(Recorder): - def __init__(self, deviceConnector, fei): - super().__init__(fei) - self.__running = True - self.streamCache = None - self.thread = MyThread(target=self.run) - self.thread.start() #启动远程音频输入设备监听线程 - self.username = 'User' - self.isOutput = True - self.deviceConnector = deviceConnector - - def run(self): - #启动ngork - self.streamCache = stream_util.StreamCache(1024*1024*20) - addr = None - while self.__running: - try: - - data = b"" - while self.deviceConnector: - data = self.deviceConnector.recv(2048) - if b"" in data: - data_str = data.decode("utf-8") - match = re.search(r"(.*?)", data_str) - if match: - self.username = match.group(1) - else: - self.streamCache.write(data) - if b"" in data: - data_str = data.decode("utf-8") - match = re.search(r"(.*?)", data_str) - if match: - self.isOutput = (match.group(1) == "True") - else: - self.streamCache.write(data) - if 
not b"" in data and not b"" in data: - self.streamCache.write(data) - time.sleep(0.005) - self.streamCache.clear() - - except Exception as err: - pass - time.sleep(1) - - def on_speaking(self, text): - global feiFei - if len(text) > 1: - interact = Interact("socket", 1, {"user": self.username, "msg": text, "socket": self.deviceConnector}) - util.printInfo(3, "(" + self.username + ")远程音频输入", '{}'.format(interact.data["msg"]), time.time()) - feiFei.on_interact(interact) - - #recorder会等待stream不为空才开始录音 - def get_stream(self): - while not self.deviceConnector: - time.sleep(1) - pass - return self.streamCache - - def stop(self): - super().stop() - self.__running = False - - def is_remote(self): - return True - -#检查远程音频连接状态 -def device_socket_keep_alive(): - global DeviceInputListenerDict - while __running: - delkey = None - for key, value in DeviceInputListenerDict.items(): - try: - value.deviceConnector.send(b'\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8')#发送心跳包 - if wsa_server.get_web_instance().is_connected(value.username): - wsa_server.get_web_instance().add_cmd({"remote_audio_connect": True, "Username" : value.username}) - except Exception as serr: - util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key)) - value.stop() - delkey = key - break - if delkey: - value = DeviceInputListenerDict.pop(delkey) - if wsa_server.get_web_instance().is_connected(value.username): - wsa_server.get_web_instance().add_cmd({"remote_audio_connect": False, "Username" : value.username}) - time.sleep(10) - -#远程音频连接 -def accept_audio_device_output_connect(): - global deviceSocketServer - global __running - global DeviceInputListenerDict - deviceSocketServer = socket.socket(socket.AF_INET,socket.SOCK_STREAM) - deviceSocketServer.bind(("0.0.0.0",10001)) - deviceSocketServer.listen(1) - MyThread(target = device_socket_keep_alive).start() # 开启心跳包检测 - addr = None - - while __running: - try: - deviceConnector,addr = deviceSocketServer.accept() #接受TCP连接,并返回新的套接字与IP地址 - deviceInputListener = 
DeviceInputListener(deviceConnector, feiFei) # 设备音频输入输出麦克风 - deviceInputListener.start() - - #把DeviceInputListenner对象记录下来 - peername = str(deviceConnector.getpeername()[0]) + ":" + str(deviceConnector.getpeername()[1]) - DeviceInputListenerDict[peername] = deviceInputListener - util.log(1,"远程音频{}输入输出设备连接上:{}".format(len(DeviceInputListenerDict), addr)) - except Exception as e: - pass - -#数字人端请求获取最新的自动播报消息,若自动播报服务关闭会自动退出自动播报 -def start_auto_play_service(): #TODO 评估一下有无优化的空间 - if config_util.config['source'].get('automatic_player_url') is None or config_util.config['source'].get('automatic_player_status') is None: - return - url = f"{config_util.config['source']['automatic_player_url']}/get_auto_play_item" - user = "User" #TODO 临时固死了 - is_auto_server_error = False - while __running: - if config_util.config['source']['wake_word_enabled'] and config_util.config['source']['wake_word_type'] == 'common' and recorderListener.wakeup_matched == True: - time.sleep(0.01) - continue - if is_auto_server_error: - util.printInfo(1, user, '60s后重连自动播报服务器') - time.sleep(60) - # 请求自动播报服务器 - with get_fay_core().auto_play_lock: - if config_util.config['source']['automatic_player_status'] and config_util.config['source']['automatic_player_url'] is not None and get_fay_core().can_auto_play == True and (config_util.config["interact"]["playSound"] or wsa_server.get_instance().is_connected(user)): - get_fay_core().can_auto_play = False - post_data = {"user": user} - try: - response = requests.post(url, json=post_data, timeout=5) - if response.status_code == 200: - is_auto_server_error = False - data = response.json() - audio_url = data.get('audio') - if not audio_url or audio_url.strip()[0:4] != "http": - audio_url = None - response_text = data.get('text') - if audio_url is None and (response_text is None or '' == response_text.strip()): - continue - timestamp = data.get('timestamp') - interact = Interact("auto_play", 2, {'user': user, 'text': response_text, 'audio': audio_url}) - 
util.printInfo(1, user, '自动播报:{},{}'.format(response_text, audio_url), time.time()) - feiFei.on_interact(interact) - else: - is_auto_server_error = True - get_fay_core().can_auto_play = True - util.printInfo(1, user, '请求自动播报服务器失败,错误代码是:{}'.format(response.status_code)) - except requests.exceptions.RequestException as e: - is_auto_server_error = True - get_fay_core().can_auto_play = True - util.printInfo(1, user, '请求自动播报服务器失败,错误信息是:{}'.format(e)) - time.sleep(0.01) - - - -#停止服务 -def stop(): - global feiFei - global recorderListener - global __running - global DeviceInputListenerDict - global ngrok - global socket_service_instance - global deviceSocketServer - - util.log(1, '正在关闭服务...') - __running = False - - # 断开所有MCP服务连接 - util.log(1, '正在断开所有MCP服务连接...') - try: - from faymcp import mcp_service - mcp_service.disconnect_all_mcp_servers() - util.log(1, '所有MCP服务连接已断开') - except Exception as e: - util.log(1, f'断开MCP服务连接失败: {str(e)}') - - # 保存代理记忆(仅在未使用仿生记忆时) - if not config_util.config["memory"].get("use_bionic_memory", False): - util.log(1, '正在保存代理记忆...') - try: - from llm.nlp_cognitive_stream import save_agent_memory - save_agent_memory() - util.log(1, '代理记忆保存成功') - except Exception as e: - util.log(1, f'保存代理记忆失败: {str(e)}') - - if recorderListener is not None: - util.log(1, '正在关闭录音服务...') - recorderListener.stop() - time.sleep(0.1) - util.log(1, '正在关闭远程音频输入输出服务...') - try: - if len(DeviceInputListenerDict) > 0: - for key in list(DeviceInputListenerDict.keys()): - value = DeviceInputListenerDict.pop(key) - value.stop() - deviceSocketServer.close() - if socket_service_instance is not None: - socket_service_instance.stop_server() - socket_service_instance = None - except: - pass - - util.log(1, '正在关闭核心服务...') - feiFei.stop() - util.log(1, '服务已关闭!') - - -#开启服务 -def start(): - global feiFei - global recorderListener - global __running - global socket_service_instance - - util.log(1, '开启服务...') - __running = True - - #读取配置 - util.log(1, '读取配置...') - 
config_util.load_config() - - #开启核心服务 - util.log(1, '开启核心服务...') - feiFei = get_fay_core().FeiFei() - feiFei.start() - - #根据配置决定是否初始化认知记忆系统 - if not config_util.config["memory"].get("use_bionic_memory", False): - util.log(1, '初始化定时保存记忆及反思的任务...') - from llm.nlp_cognitive_stream import init_memory_scheduler - init_memory_scheduler() - - #初始化知识库(两个模块共用) - util.log(1, '初始化本地知识库...') - if config_util.config["memory"].get("use_bionic_memory", False): - from llm.nlp_bionicmemory_stream import init_knowledge_base - else: - from llm.nlp_cognitive_stream import init_knowledge_base - init_knowledge_base() - - #开启录音服务 - record = config_util.config['source']['record'] - if record['enabled']: - util.log(1, '开启录音服务...') - recorderListener = RecorderListener('device', feiFei) # 监听麦克风 - recorderListener.start() - - #启动声音沟通接口服务 - util.log(1,'启动声音沟通接口服务...') - deviceSocketThread = MyThread(target=accept_audio_device_output_connect) - deviceSocketThread.start() - socket_service_instance = socket_bridge_service.new_instance() - socket_bridge_service_Thread = MyThread(target=socket_service_instance.start_service) - socket_bridge_service_Thread.start() - - #启动自动播报服务 - util.log(1,'启动自动播报服务...') - MyThread(target=start_auto_play_service).start() - - util.log(1, '服务启动完成!') - -if __name__ == '__main__': - ws_server: MyServer = None - feiFei: get_fay_core().FeiFei = None - recorderListener: Recorder = None - start() +#核心启动模块 +import time +import re +import pyaudio +import socket +import requests +from core.interact import Interact +from core.recorder import Recorder +from scheduler.thread_manager import MyThread +from utils import util, config_util, stream_util +from core.wsa_server import MyServer +from core import wsa_server +from core import socket_bridge_service +# from llm.nlp_cognitive_stream import save_agent_memory + +# 全局变量声明 +feiFei = None +recorderListener = None +__running = False +deviceSocketServer = None +DeviceInputListenerDict = {} +ngrok = None +socket_service_instance = 
None + +# 延迟导入fay_core +def get_fay_core(): + from core import fay_core + return fay_core + +#启动状态 +def is_running(): + return __running + +#录制麦克风音频输入并传给aliyun +class RecorderListener(Recorder): + + def __init__(self, device, fei): + self.__device = device + self.__FORMAT = pyaudio.paInt16 + self.__running = False + self.username = 'User' + # 这两个参数会在 get_stream 中根据实际设备更新 + self.channels = None + self.sample_rate = None + super().__init__(fei) + + def on_speaking(self, text): + if len(text) > 1: + interact = Interact("mic", 1, {'user': 'User', 'msg': text}) + util.printInfo(3, "语音", '{}'.format(interact.data["msg"]), time.time()) + feiFei.on_interact(interact) + + def get_stream(self): + try: + while True: + config_util.load_config() + record = config_util.config['source']['record'] + if record['enabled']: + break + time.sleep(0.1) + + self.paudio = pyaudio.PyAudio() + + # 获取默认输入设备的信息 + default_device = self.paudio.get_default_input_device_info() + self.channels = min(int(default_device.get('maxInputChannels', 1)), 2) # 最多使用2个通道 + # self.sample_rate = int(default_device.get('defaultSampleRate', 16000)) + + util.printInfo(1, "系统", f"默认麦克风信息 - 采样率: {self.sample_rate}Hz, 通道数: {self.channels}") + + # 使用系统默认麦克风 + self.stream = self.paudio.open( + format=self.__FORMAT, + channels=self.channels, + rate=self.sample_rate, + input=True, + frames_per_buffer=1024 + ) + + self.__running = True + MyThread(target=self.__pyaudio_clear).start() + + except Exception as e: + util.log(1, f"打开麦克风时出错: {str(e)}") + util.printInfo(1, self.username, "请检查录音设备是否有误,再重新启动!") + time.sleep(10) + return self.stream + + def __pyaudio_clear(self): + try: + while self.__running: + time.sleep(30) + except Exception as e: + util.log(1, f"音频清理线程出错: {str(e)}") + finally: + if hasattr(self, 'stream') and self.stream: + try: + self.stream.stop_stream() + self.stream.close() + except Exception as e: + util.log(1, f"关闭音频流时出错: {str(e)}") + + def stop(self): + super().stop() + self.__running = False + 
time.sleep(0.1)#给清理线程一点处理时间 + try: + while self.is_reading:#是为了确保停止的时候麦克风没有刚好在读取音频的 + time.sleep(0.1) + if self.stream is not None: + self.stream.stop_stream() + self.stream.close() + self.paudio.terminate() + except Exception as e: + print(e) + util.log(1, "请检查设备是否有误,再重新启动!") + + def is_remote(self): + return False + + + + +#Edit by xszyou on 20230113:录制远程设备音频输入并传给aliyun +class DeviceInputListener(Recorder): + def __init__(self, deviceConnector, fei): + super().__init__(fei) + self.__running = True + self.streamCache = None + self.thread = MyThread(target=self.run) + self.thread.start() #启动远程音频输入设备监听线程 + self.username = 'User' + self.isOutput = True + self.deviceConnector = deviceConnector + + def run(self): + #启动ngork + self.streamCache = stream_util.StreamCache(1024*1024*20) + addr = None + while self.__running: + try: + + data = b"" + while self.deviceConnector: + data = self.deviceConnector.recv(2048) + if b"" in data: + data_str = data.decode("utf-8") + match = re.search(r"(.*?)", data_str) + if match: + self.username = match.group(1) + else: + self.streamCache.write(data) + if b"" in data: + data_str = data.decode("utf-8") + match = re.search(r"(.*?)", data_str) + if match: + self.isOutput = (match.group(1) == "True") + else: + self.streamCache.write(data) + if not b"" in data and not b"" in data: + self.streamCache.write(data) + time.sleep(0.005) + self.streamCache.clear() + + except Exception as err: + pass + time.sleep(1) + + def on_speaking(self, text): + global feiFei + if len(text) > 1: + interact = Interact("socket", 1, {"user": self.username, "msg": text, "socket": self.deviceConnector}) + util.printInfo(3, "(" + self.username + ")远程音频输入", '{}'.format(interact.data["msg"]), time.time()) + feiFei.on_interact(interact) + + #recorder会等待stream不为空才开始录音 + def get_stream(self): + while not self.deviceConnector: + time.sleep(1) + pass + return self.streamCache + + def stop(self): + super().stop() + self.__running = False + + def is_remote(self): + return 
True + +#检查远程音频连接状态 +def device_socket_keep_alive(): + global DeviceInputListenerDict + while __running: + delkey = None + for key, value in DeviceInputListenerDict.items(): + try: + value.deviceConnector.send(b'\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8')#发送心跳包 + if wsa_server.get_web_instance().is_connected(value.username): + wsa_server.get_web_instance().add_cmd({"remote_audio_connect": True, "Username" : value.username}) + except Exception as serr: + util.printInfo(1, value.username, "远程音频输入输出设备已经断开:{}".format(key)) + value.stop() + delkey = key + break + if delkey: + value = DeviceInputListenerDict.pop(delkey) + if wsa_server.get_web_instance().is_connected(value.username): + wsa_server.get_web_instance().add_cmd({"remote_audio_connect": False, "Username" : value.username}) + time.sleep(10) + +#远程音频连接 +def accept_audio_device_output_connect(): + global deviceSocketServer + global __running + global DeviceInputListenerDict + deviceSocketServer = socket.socket(socket.AF_INET,socket.SOCK_STREAM) + deviceSocketServer.bind(("0.0.0.0",10001)) + deviceSocketServer.listen(1) + MyThread(target = device_socket_keep_alive).start() # 开启心跳包检测 + addr = None + + while __running: + try: + deviceConnector,addr = deviceSocketServer.accept() #接受TCP连接,并返回新的套接字与IP地址 + deviceInputListener = DeviceInputListener(deviceConnector, feiFei) # 设备音频输入输出麦克风 + deviceInputListener.start() + + #把DeviceInputListenner对象记录下来 + peername = str(deviceConnector.getpeername()[0]) + ":" + str(deviceConnector.getpeername()[1]) + DeviceInputListenerDict[peername] = deviceInputListener + util.log(1,"远程音频{}输入输出设备连接上:{}".format(len(DeviceInputListenerDict), addr)) + except Exception as e: + pass + +#数字人端请求获取最新的自动播报消息,若自动播报服务关闭会自动退出自动播报 +def start_auto_play_service(): #TODO 评估一下有无优化的空间 + if config_util.config['source'].get('automatic_player_url') is None or config_util.config['source'].get('automatic_player_status') is None: + return + url = 
f"{config_util.config['source']['automatic_player_url']}/get_auto_play_item" + user = "User" #TODO 临时固死了 + is_auto_server_error = False + while __running: + if config_util.config['source']['wake_word_enabled'] and config_util.config['source']['wake_word_type'] == 'common' and recorderListener.wakeup_matched == True: + time.sleep(0.01) + continue + if is_auto_server_error: + util.printInfo(1, user, '60s后重连自动播报服务器') + time.sleep(60) + # 请求自动播报服务器 + with get_fay_core().auto_play_lock: + if config_util.config['source']['automatic_player_status'] and config_util.config['source']['automatic_player_url'] is not None and get_fay_core().can_auto_play == True and (config_util.config["interact"]["playSound"] or wsa_server.get_instance().is_connected(user)): + get_fay_core().can_auto_play = False + post_data = {"user": user} + try: + response = requests.post(url, json=post_data, timeout=5) + if response.status_code == 200: + is_auto_server_error = False + data = response.json() + audio_url = data.get('audio') + if not audio_url or audio_url.strip()[0:4] != "http": + audio_url = None + response_text = data.get('text') + if audio_url is None and (response_text is None or '' == response_text.strip()): + continue + timestamp = data.get('timestamp') + interact = Interact("auto_play", 2, {'user': user, 'text': response_text, 'audio': audio_url}) + util.printInfo(1, user, '自动播报:{},{}'.format(response_text, audio_url), time.time()) + feiFei.on_interact(interact) + else: + is_auto_server_error = True + get_fay_core().can_auto_play = True + util.printInfo(1, user, '请求自动播报服务器失败,错误代码是:{}'.format(response.status_code)) + except requests.exceptions.RequestException as e: + is_auto_server_error = True + get_fay_core().can_auto_play = True + util.printInfo(1, user, '请求自动播报服务器失败,错误信息是:{}'.format(e)) + time.sleep(0.01) + + + +#停止服务 +def stop(): + global feiFei + global recorderListener + global __running + global DeviceInputListenerDict + global ngrok + global socket_service_instance + global 
deviceSocketServer + + util.log(1, '正在关闭服务...') + __running = False + + # 断开所有MCP服务连接 + util.log(1, '正在断开所有MCP服务连接...') + try: + from faymcp import mcp_service + mcp_service.disconnect_all_mcp_servers() + util.log(1, '所有MCP服务连接已断开') + except Exception as e: + util.log(1, f'断开MCP服务连接失败: {str(e)}') + + # 保存代理记忆(仅在未使用仿生记忆时) + if not config_util.config["memory"].get("use_bionic_memory", False): + util.log(1, '正在保存代理记忆...') + try: + from llm.nlp_cognitive_stream import save_agent_memory + save_agent_memory() + util.log(1, '代理记忆保存成功') + except Exception as e: + util.log(1, f'保存代理记忆失败: {str(e)}') + + if recorderListener is not None: + util.log(1, '正在关闭录音服务...') + recorderListener.stop() + time.sleep(0.1) + util.log(1, '正在关闭远程音频输入输出服务...') + try: + if len(DeviceInputListenerDict) > 0: + for key in list(DeviceInputListenerDict.keys()): + value = DeviceInputListenerDict.pop(key) + value.stop() + deviceSocketServer.close() + if socket_service_instance is not None: + socket_service_instance.stop_server() + socket_service_instance = None + except: + pass + + util.log(1, '正在关闭核心服务...') + feiFei.stop() + util.log(1, '服务已关闭!') + + +#开启服务 +def start(): + global feiFei + global recorderListener + global __running + global socket_service_instance + + util.log(1, '开启服务...') + __running = True + + #读取配置 + util.log(1, '读取配置...') + config_util.load_config() + + #开启核心服务 + util.log(1, '开启核心服务...') + feiFei = get_fay_core().FeiFei() + feiFei.start() + + #根据配置决定是否初始化认知记忆系统 + if not config_util.config["memory"].get("use_bionic_memory", False): + util.log(1, '初始化定时保存记忆及反思的任务...') + from llm.nlp_cognitive_stream import init_memory_scheduler + init_memory_scheduler() + + #开启录音服务 + record = config_util.config['source']['record'] + if record['enabled']: + util.log(1, '开启录音服务...') + recorderListener = RecorderListener('device', feiFei) # 监听麦克风 + recorderListener.start() + + #启动声音沟通接口服务 + util.log(1,'启动声音沟通接口服务...') + deviceSocketThread = MyThread(target=accept_audio_device_output_connect) + 
deviceSocketThread.start() + socket_service_instance = socket_bridge_service.new_instance() + socket_bridge_service_Thread = MyThread(target=socket_service_instance.start_service) + socket_bridge_service_Thread.start() + + #启动自动播报服务 + util.log(1,'启动自动播报服务...') + MyThread(target=start_auto_play_service).start() + + util.log(1, '服务启动完成!') + +if __name__ == '__main__': + ws_server: MyServer = None + feiFei: get_fay_core().FeiFei = None + recorderListener: Recorder = None + start() diff --git a/faymcp/data/mcp_servers.json b/faymcp/data/mcp_servers.json index cb4286b..0dfb388 100644 --- a/faymcp/data/mcp_servers.json +++ b/faymcp/data/mcp_servers.json @@ -1,46 +1,81 @@ -[ - { - "id": 1, - "name": "tools", - "ip": "", - "connection_time": "2025-11-11 11:44:56", - "key": "", - "transport": "stdio", - "command": "python", - "args": [ - "test/mcp_stdio_example.py" - ], - "cwd": "", - "env": {} - }, - { - "id": 2, - "name": "Fay日程管理", - "ip": "", - "connection_time": "2025-11-11 11:44:59", - "key": "", - "transport": "stdio", - "command": "python", - "args": [ - "server.py" - ], - "cwd": "mcp_servers/schedule_manager", - "env": {} - }, - { - "id": 3, - "name": "logseq", - "ip": "", - "connection_time": "2025-10-21 11:07:20", - "key": "", - "transport": "stdio", - "command": "python", - "args": [ - "server.py" - ], - "cwd": "mcp_servers/logseq", - "env": { - "LOGSEQ_GRAPH_DIR": "D:/iCloudDrive/iCloud~com~logseq~logseq/第二大脑" - } - } +[ + { + "id": 1, + "name": "tools", + "ip": "", + "connection_time": "2025-12-10 21:16:35", + "key": "", + "transport": "stdio", + "command": "python", + "args": [ + "test/mcp_stdio_example.py" + ], + "cwd": "", + "env": {} + }, + { + "id": 2, + "name": "Fay日程管理", + "ip": "", + "connection_time": "2025-12-10 21:16:38", + "key": "", + "transport": "stdio", + "command": "python", + "args": [ + "server.py" + ], + "cwd": "mcp_servers/schedule_manager", + "env": {} + }, + { + "id": 3, + "name": "logseq", + "ip": "", + "connection_time": "2025-12-10 
21:16:39", + "key": "", + "transport": "stdio", + "command": "python", + "args": [ + "server.py" + ], + "cwd": "mcp_servers/logseq", + "env": { + "LOGSEQ_GRAPH_DIR": "D:/iCloudDrive/iCloud~com~logseq~logseq/第二大脑" + } + }, + { + "id": 4, + "name": "yueshen rag", + "ip": "", + "connection_time": "2025-12-10 21:16:44", + "key": "", + "transport": "stdio", + "command": "C:\\Users\\Lenovo\\anaconda3\\envs\\rag\\python.exe", + "args": [ + "mcp_servers/yueshen_rag/server.py" + ], + "cwd": "", + "env": { + "YUESHEN_AUTO_INGEST": "1", + "YUESHEN_AUTO_INTERVAL": "300", + "YUESHEN_AUTO_RESET_ON_START": "0", + "YUESHEN_EMBED_API_KEY": "sk-izmvqrzyhjghzyghiofqfpusxprmfljntxzggkcovtneqpas", + "YUESHEN_EMBED_BASE_URL": "https://api.siliconflow.cn/v1", + "YUESHEN_EMBED_MODEL": "Qwen/Qwen3-Embedding-8B" + } + }, + { + "id": 5, + "name": "window capture", + "ip": "", + "connection_time": "2025-12-10 21:16:45", + "key": "", + "transport": "stdio", + "command": "python", + "args": [ + "mcp_servers/window_capture/server.py" + ], + "cwd": "", + "env": {} + } ] \ No newline at end of file diff --git a/faymcp/data/mcp_tool_states.json b/faymcp/data/mcp_tool_states.json index c0a6beb..ab80d07 100644 --- a/faymcp/data/mcp_tool_states.json +++ b/faymcp/data/mcp_tool_states.json @@ -1,21 +1,44 @@ -{ - "5": { - "now": true, - "add": false, - "upper": false, - "echo": false, - "ping": false - }, - "1": { - "add": true, - "upper": false, - "echo": false, - "ping": false, - "now": true - }, - "2": { - "('meta', None)": true, - "('nextCursor', None)": true - }, - "3": {} -} +{ + "5": { + "now": true, + "add": false, + "upper": false, + "echo": false, + "ping": false, + "ingest_yueshen": false, + "yueshen_stats": false + }, + "1": { + "add": true, + "upper": false, + "echo": false, + "ping": false, + "now": true, + "display_media": true + }, + "2": { + "('meta', None)": true, + "('nextCursor', None)": true, + "browser_snapshot": false + }, + "3": { + "get_file_info": false, + 
"list_allowed_directories": true, + "edit_file": false, + "list_directory": false, + "directory_tree": false, + "create_directory": false, + "read_text_file": false, + "read_multiple_files": false, + "read_media_file": false, + "read_file": false, + "move_file": false, + "list_directory_with_sizes": false, + "search_files": false, + "write_file": false + }, + "4": {}, + "6": { + "list_windows": true + } +} \ No newline at end of file diff --git a/gui/flask_server.py b/gui/flask_server.py index cd9c9f1..febc073 100644 --- a/gui/flask_server.py +++ b/gui/flask_server.py @@ -883,6 +883,63 @@ def api_start_genagents(): util.log(1, f"启动决策分析页面时出错: {str(e)}") return jsonify({'success': False, 'message': f'启动决策分析页面时出错: {str(e)}'}), 500 +# 获取本地图片(用于在网页中显示本地图片) +@__app.route('/api/local-image') +def api_local_image(): + try: + file_path = request.args.get('path', '') + if not file_path: + return jsonify({'error': '缺少文件路径参数'}), 400 + + # 检查文件是否存在 + if not os.path.exists(file_path): + return jsonify({'error': f'文件不存在: {file_path}'}), 404 + + # 检查是否为图片文件 + valid_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp') + if not file_path.lower().endswith(valid_extensions): + return jsonify({'error': '不是有效的图片文件'}), 400 + + # 返回图片文件 + return send_file(file_path) + except Exception as e: + return jsonify({'error': f'获取图片时出错: {str(e)}'}), 500 + +# 打开图片文件(使用系统默认程序) +@__app.route('/api/open-image', methods=['POST']) +def api_open_image(): + try: + data = request.get_json() + if not data or 'path' not in data: + return jsonify({'success': False, 'message': '缺少文件路径参数'}), 400 + + file_path = data['path'] + + # 检查文件是否存在 + if not os.path.exists(file_path): + return jsonify({'success': False, 'message': f'文件不存在: {file_path}'}), 404 + + # 检查是否为图片文件 + valid_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp') + if not file_path.lower().endswith(valid_extensions): + return jsonify({'success': False, 'message': '不是有效的图片文件'}), 400 + + # 使用系统默认程序打开图片 + import subprocess + import 
platform + + system = platform.system() + if system == 'Windows': + os.startfile(file_path) + elif system == 'Darwin': # macOS + subprocess.run(['open', file_path]) + else: # Linux + subprocess.run(['xdg-open', file_path]) + + return jsonify({'success': True, 'message': '已打开图片'}), 200 + except Exception as e: + return jsonify({'success': False, 'message': f'打开图片时出错: {str(e)}'}), 500 + def run(): class NullLogHandler: def write(self, *args, **kwargs): diff --git a/gui/static/css/index.css b/gui/static/css/index.css index 3653ce3..ca133d2 100644 --- a/gui/static/css/index.css +++ b/gui/static/css/index.css @@ -591,6 +591,181 @@ html { } .message-text { - white-space: pre-wrap; word-wrap: break-word; +} + +/* Markdown 样式 */ +.markdown-body { + line-height: 1.6; + font-size: 14px; +} + +.markdown-body p { + margin: 0 0 8px 0; +} + +.markdown-body p:last-child { + margin-bottom: 0; +} + +.markdown-body h1, .markdown-body h2, .markdown-body h3, +.markdown-body h4, .markdown-body h5, .markdown-body h6 { + margin: 12px 0 8px 0; + font-weight: 600; + line-height: 1.4; +} + +.markdown-body h1 { font-size: 1.5em; } +.markdown-body h2 { font-size: 1.3em; } +.markdown-body h3 { font-size: 1.15em; } +.markdown-body h4 { font-size: 1em; } + +.markdown-body ul, .markdown-body ol { + margin: 8px 0; + padding-left: 20px; +} + +.markdown-body li { + margin: 4px 0; +} + +.markdown-body code { + background-color: rgba(175, 184, 193, 0.2); + padding: 2px 6px; + border-radius: 4px; + font-family: 'Consolas', 'Monaco', monospace; + font-size: 0.9em; +} + +.markdown-body pre { + background-color: #f6f8fa; + padding: 12px; + border-radius: 6px; + overflow-x: auto; + margin: 8px 0; +} + +.markdown-body pre code { + background-color: transparent; + padding: 0; + font-size: 0.85em; + line-height: 1.5; +} + +.markdown-body blockquote { + margin: 8px 0; + padding: 8px 12px; + border-left: 4px solid #dfe2e5; + color: #6a737d; + background-color: #f6f8fa; +} + +.markdown-body blockquote p { + 
margin: 0; +} + +.markdown-body a { + color: #0366d6; + text-decoration: none; +} + +.markdown-body a:hover { + text-decoration: underline; +} + +.markdown-body table { + border-collapse: collapse; + margin: 8px 0; + width: 100%; +} + +.markdown-body th, .markdown-body td { + border: 1px solid #dfe2e5; + padding: 6px 12px; + text-align: left; +} + +.markdown-body th { + background-color: #f6f8fa; + font-weight: 600; +} + +.markdown-body hr { + border: none; + border-top: 1px solid #dfe2e5; + margin: 12px 0; +} + +.markdown-body strong { + font-weight: 600; +} + +.markdown-body em { + font-style: italic; +} + +/* 图片缩略图样式 */ +.image-thumbnail-container { + display: inline-block; + position: relative; + cursor: pointer; + margin: 4px 2px; + border-radius: 6px; + overflow: hidden; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15); + transition: transform 0.2s ease, box-shadow 0.2s ease; + max-width: 200px; + vertical-align: middle; +} + +.image-thumbnail-container:hover { + transform: scale(1.02); + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.25); +} + +.message-image-thumbnail { + display: block; + max-width: 200px; + max-height: 150px; + object-fit: cover; + border-radius: 6px; +} + +.image-zoom-hint { + position: absolute; + bottom: 0; + left: 0; + right: 0; + background: rgba(0, 0, 0, 0.6); + color: #fff; + font-size: 11px; + text-align: center; + padding: 3px 0; + opacity: 0; + transition: opacity 0.2s ease; +} + +.image-thumbnail-container:hover .image-zoom-hint { + opacity: 1; +} + +.image-path-text { + display: inline-block; + padding: 4px 8px; + background-color: #f0f4ff; + border: 1px solid #d0d8e8; + border-radius: 4px; + font-size: 12px; + color: #617bab; + word-break: break-all; +} + +/* prestart 内容中的图片缩略图 */ +.prestart-content-inline .image-thumbnail-container { + max-width: 150px; +} + +.prestart-content-inline .message-image-thumbnail { + max-width: 150px; + max-height: 100px; } \ No newline at end of file diff --git a/gui/static/js/index.js 
b/gui/static/js/index.js index b2b5db8..48411d4 100644 --- a/gui/static/js/index.js +++ b/gui/static/js/index.js @@ -1,4 +1,30 @@ // fayApp.js + +// 全局函数:打开图片文件 +window.openImageFile = function(encodedPath) { + const filePath = decodeURIComponent(encodedPath); + const baseUrl = window.location.protocol + '//' + window.location.hostname + ':' + window.location.port; + + fetch(`${baseUrl}/api/open-image`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ path: filePath }) + }) + .then(response => response.json()) + .then(data => { + if (!data.success) { + console.error('打开图片失败:', data.message); + alert('打开图片失败: ' + data.message); + } + }) + .catch(error => { + console.error('请求失败:', error); + alert('打开图片时发生错误'); + }); +}; + class FayInterface { constructor(baseWsUrl, baseApiUrl, vueInstance) { this.baseWsUrl = baseWsUrl; @@ -634,12 +660,14 @@ unadoptText(id) { let mainContent = content; let prestartContent = ''; - // 解析 prestart 标签 - const prestartRegex = /([\s\S]*?)<\/prestart>/i; + // 解析 prestart 标签 - 使用贪婪匹配确保匹配到最后一个 + // 同时支持多个 prestart 标签的情况 + const prestartRegex = /([\s\S]*)<\/prestart>/i; const prestartMatch = mainContent.match(prestartRegex); if (prestartMatch && prestartMatch[1]) { prestartContent = this.trimThinkLines(prestartMatch[1]); - mainContent = mainContent.replace(prestartRegex, ''); + // 移除所有 prestart 标签及其内容 + mainContent = mainContent.replace(/[\s\S]*<\/prestart>/gi, ''); } // 先尝试匹配完整的 think 标签 @@ -693,7 +721,64 @@ unadoptText(id) { const message = this.messages[index]; this.$set(message, 'prestartExpanded', !message.prestartExpanded); }, - + + // 检测并转换图片路径为缩略图 + convertImagePaths(content) { + if (!content) return content; + // 匹配常见图片路径格式: + // Windows: D:\path\to\image.png 或 D:/path/to/image.png + // Unix: /path/to/image.png + // 支持的图片格式: png, jpg, jpeg, gif, bmp, webp + const imagePathRegex = /([A-Za-z]:[\\\/][^\s<>"']+\.(png|jpg|jpeg|gif|bmp|webp)|\/[^\s<>"']+\.(png|jpg|jpeg|gif|bmp|webp))/gi; + 
+ const baseUrl = window.location.protocol + '//' + window.location.hostname + ':' + window.location.port; + + return content.replace(imagePathRegex, (match) => { + // 对原始路径进行编码 + const encodedPath = encodeURIComponent(match); + // 通过后端 API 获取图片(解决浏览器安全限制) + const imgSrc = `${baseUrl}/api/local-image?path=${encodedPath}`; + // 用于显示的安全路径 + const displayPath = match.replace(/\\/g, '/').replace(/'/g, ''').replace(/"/g, '"'); + return ` + 图片 + 点击查看 + `; + }); + }, + + // 渲染 Markdown 内容 + renderMarkdown(content) { + if (!content) return ''; + try { + // 配置 marked 选项 + if (typeof marked !== 'undefined') { + marked.setOptions({ + breaks: true, // 支持换行 + gfm: true, // 支持 GitHub 风格的 Markdown + }); + // 预处理:确保 ** 和 * 标记能正确解析 + // 处理中文加粗:**文字** 后面可能有空格或其他字符 + let processed = content; + // 手动处理加粗语法 **text** + processed = processed.replace(/\*\*([^*\n]+)\*\*/g, '$1'); + // 手动处理斜体语法 *text*(避免与加粗冲突) + processed = processed.replace(/(?$1'); + // 对剩余内容使用 marked 解析 + let result = marked.parse(processed); + // 转换图片路径为缩略图 + result = this.convertImagePaths(result); + return result; + } + } catch (e) { + console.error('Markdown rendering error:', e); + } + // 如果 marked 不可用,返回简单处理的内容 + let result = content.replace(/\n/g, '
'); + result = this.convertImagePaths(result); + return result; + }, + // 检查MCP服务器状态 checkMcpStatus() { const mcpUrl = `http://${this.hostname}:5010/api/mcp/servers`; diff --git a/gui/templates/index.html b/gui/templates/index.html index 1b000db..2584945 100644 --- a/gui/templates/index.html +++ b/gui/templates/index.html @@ -10,6 +10,7 @@ + @@ -51,12 +52,12 @@
[[parseThinkContent(item.content).thinkContent]]
-
[[parseThinkContent(item.content).mainContent]]
+
预启动工具
-
[[parseThinkContent(item.content).prestartContent]]
+
[[item.timetext]] diff --git a/llm/data/测试.docx b/llm/data/测试.docx deleted file mode 100644 index 616b7a1..0000000 Binary files a/llm/data/测试.docx and /dev/null differ diff --git a/llm/nlp_bionicmemory_stream.py b/llm/nlp_bionicmemory_stream.py index 9322412..914bb53 100644 --- a/llm/nlp_bionicmemory_stream.py +++ b/llm/nlp_bionicmemory_stream.py @@ -1,1499 +1,1473 @@ -# -*- coding: utf-8 -*- -import os -import json -import time -import threading -import requests -import datetime -import schedule -import textwrap -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Literal, Optional, TypedDict, Tuple -from collections.abc import Mapping, Sequence -from langchain_openai import ChatOpenAI -from langchain_core.messages import HumanMessage, SystemMessage -from langgraph.graph import END, START, StateGraph - -# 新增:本地知识库相关导入 -import re -from pathlib import Path -import docx -from docx.document import Document -from docx.oxml.table import CT_Tbl -from docx.oxml.text.paragraph import CT_P -from docx.table import _Cell, Table -from docx.text.paragraph import Paragraph -try: - from pptx import Presentation - PPTX_AVAILABLE = True -except ImportError: - PPTX_AVAILABLE = False - -# 用于处理 .doc 文件的库 -try: - import win32com.client - WIN32COM_AVAILABLE = True -except ImportError: - WIN32COM_AVAILABLE = False - -from utils import util -import utils.config_util as cfg -from urllib3.exceptions import InsecureRequestWarning -from scheduler.thread_manager import MyThread -from core import content_db -from core import stream_manager -from faymcp import tool_registry as mcp_tool_registry - -# 新增:长短期记忆系统相关导入 -from bionicmemory.core.chroma_service import ChromaService -from bionicmemory.core.memory_system import LongShortTermMemorySystem, SourceType - -# 加载配置 -cfg.load_config() - -# 禁用不安全请求警告 -requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning) - -# 记忆系统全局变量 -chroma_service = None # ChromaDB服务实例 -memory_system = None # 长短期记忆系统实例 
-memory_system_lock = threading.RLock() # 保护记忆系统的锁 - -# 当前会话用户名(保留,用于兼容性) -current_username = None - -llm = ChatOpenAI( - model=cfg.gpt_model_engine, - base_url=cfg.gpt_base_url, - api_key=cfg.key_gpt_api_key, - streaming=True - ) - - -def init_memory_system(): - """ - 初始化长短期记忆系统 - - Returns: - bool: 是否初始化成功 - """ - global chroma_service, memory_system - - try: - util.log(1, "正在初始化记忆系统...") - - # 启动时检查并清除数据库(如果存在清除标记) - if ChromaService.check_and_clear_database_on_startup(): - util.log(1, "检测到记忆清除标记,已清除ChromaDB数据库") - - # 初始化ChromaDB服务 - chroma_service = ChromaService() - if not chroma_service: - util.log(1, "ChromaDB服务初始化失败") - return False - - # 初始化长短期记忆系统 - memory_system = LongShortTermMemorySystem( - chroma_service=chroma_service, - summary_threshold=500, - max_retrieval_results=10, - cluster_multiplier=3, - retrieval_multiplier=2 - ) - - util.log(1, "记忆系统初始化成功") - return True - - except Exception as e: - util.log(1, f"记忆系统初始化失败: {e}") - return False - - -# 在模块加载时初始化记忆系统 -init_memory_system() - - -@dataclass -class WorkflowToolSpec: - name: str - description: str - schema: Dict[str, Any] - executor: Callable[[Dict[str, Any], int], Tuple[bool, Optional[str], Optional[str]]] - example_args: Dict[str, Any] - - -class ToolCall(TypedDict): - name: str - args: Dict[str, Any] - - -class ToolResult(TypedDict, total=False): - call: ToolCall - success: bool - output: Optional[str] - error: Optional[str] - attempt: int - - -class ConversationMessage(TypedDict): - role: Literal["user", "assistant"] - content: str - - -class AgentState(TypedDict, total=False): - request: str - messages: List[ConversationMessage] - tool_results: List[ToolResult] - next_action: Optional[ToolCall] - status: Literal["planning", "needs_tool", "completed", "failed"] - final_response: Optional[str] - final_messages: Optional[List[SystemMessage | HumanMessage]] - planner_preview: Optional[str] - audit_log: List[str] - context: Dict[str, Any] - error: Optional[str] - max_steps: int - - -def 
_truncate_text(text: Any, limit: int = 400) -> str: - text_str = "" if text is None else str(text) - if len(text_str) <= limit: - return text_str - return text_str[:limit] + "..." - - -def _extract_text_from_result(value: Any, *, depth: int = 0) -> List[str]: - """Try to pull human-readable text snippets from tool results.""" - if value is None: - return [] - if depth > 6: - return [] - if isinstance(value, (str, int, float, bool)): - text = str(value).strip() - return [text] if text else [] - if isinstance(value, Mapping): - # Prefer explicit text/content fields - if "text" in value and not isinstance(value["text"], (dict, list, tuple)): - text = str(value["text"]).strip() - return [text] if text else [] - if "content" in value: - segments: List[str] = [] - for item in value.get("content", []): - segments.extend(_extract_text_from_result(item, depth=depth + 1)) - if segments: - return segments - segments = [] - for key, item in value.items(): - if key in {"meta", "annotations", "uid", "id", "messageId"}: - continue - segments.extend(_extract_text_from_result(item, depth=depth + 1)) - return segments - if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)): - segments: List[str] = [] - for item in value: - segments.extend(_extract_text_from_result(item, depth=depth + 1)) - return segments - if hasattr(value, "text") and not callable(getattr(value, "text")): - text = str(getattr(value, "text", "")).strip() - return [text] if text else [] - if hasattr(value, "__dict__"): - return _extract_text_from_result(vars(value), depth=depth + 1) - text = str(value).strip() - return [text] if text else [] - - -def _normalize_tool_output(result: Any) -> str: - """Convert structured tool output to a concise human-readable string.""" - if result is None: - return "" - segments = _extract_text_from_result(result) - if segments: - cleaned = [segment for segment in segments if segment] - if cleaned: - return "\n".join(dict.fromkeys(cleaned)) - try: - return 
json.dumps(result, ensure_ascii=False, default=lambda o: getattr(o, "__dict__", str(o))) - except TypeError: - return str(result) - - -def _truncate_history(history: List[ToolResult], limit: int = 6) -> str: - if not history: - return "(暂无)" - lines: List[str] = [] - for item in history[-limit:]: - call = item.get("call", {}) - name = call.get("name", "未知工具") - attempt = item.get("attempt", 0) - success = item.get("success", False) - status = "成功" if success else "失败" - lines.append(f"- {name} 第 {attempt} 次 → {status}") - if item.get("output"): - lines.append(" 输出:" + _truncate_text(item["output"], 200)) - if item.get("error"): - lines.append(" 错误:" + _truncate_text(item["error"], 200)) - return "\n".join(lines) - - -def _format_schema_parameters(schema: Dict[str, Any]) -> List[str]: - if not schema: - return [" - 无参数"] - props = schema.get("properties") or {} - if not props: - return [" - 无参数"] - required = set(schema.get("required") or []) - lines: List[str] = [] - for field, meta in props.items(): - meta = meta or {} - field_type = meta.get("type", "string") - desc = (meta.get("description") or "").strip() - req_label = "必填" if field in required else "可选" - line = f" - {field} ({field_type},{req_label})" - if desc: - line += f":{desc}" - lines.append(line) - return lines or [" - 无参数"] - - -def _generate_example_args(schema: Dict[str, Any]) -> Dict[str, Any]: - example: Dict[str, Any] = {} - if not schema: - return example - props = schema.get("properties") or {} - for field, meta in props.items(): - meta = meta or {} - if "default" in meta: - example[field] = meta["default"] - continue - enum_values = meta.get("enum") or [] - if enum_values: - example[field] = enum_values[0] - continue - field_type = meta.get("type", "string") - if field_type in ("number", "integer"): - example[field] = 0 - elif field_type == "boolean": - example[field] = True - elif field_type == "array": - example[field] = [] - elif field_type == "object": - example[field] = {} - else: - 
description_hint = meta.get("description") or "" - example[field] = description_hint or "" - return example - - -def _format_tool_block(spec: WorkflowToolSpec) -> str: - param_lines = _format_schema_parameters(spec.schema) - example = json.dumps(spec.example_args, ensure_ascii=False) if spec.example_args else "{}" - lines = [ - f"- 工具名:{spec.name}", - f" 功能:{spec.description or '暂无描述'}", - " 参数:", - *param_lines, - f" 示例:{example}", - ] - return "\n".join(lines) - - -def _build_workflow_tool_spec(tool_def: Dict[str, Any]) -> Optional[WorkflowToolSpec]: - if not tool_def: - return None - name = tool_def.get("name") - if not name: - return None - description = tool_def.get("description") or tool_def.get("summary") or "" - schema = tool_def.get("inputSchema") or {} - example_args = _generate_example_args(schema) - - def _executor(args: Dict[str, Any], attempt: int) -> Tuple[bool, Optional[str], Optional[str]]: - try: - resp = requests.post( - f"http://127.0.0.1:5010/api/mcp/tools/{name}", - json=args, - timeout=120, - ) - resp.raise_for_status() - data = resp.json() - except Exception as exc: - util.log(1, f"调用工具 {name} 异常: {exc}") - return False, None, str(exc) - - if data.get("success"): - result = data.get("result") - output = _normalize_tool_output(result) - return True, output, None - - error_msg = data.get("error") or "未知错误" - util.log(1, f"调用工具 {name} 失败: {error_msg}") - return False, None, error_msg - - return WorkflowToolSpec( - name=name, - description=description, - schema=schema, - executor=_executor, - example_args=example_args, - ) - - -def _format_tools_for_prompt(tool_specs: Dict[str, WorkflowToolSpec]) -> str: - if not tool_specs: - return "(暂无可用工具)" - return "\n".join(_format_tool_block(spec) for spec in tool_specs.values()) - - -def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMessage]: - context = state.get("context", {}) or {} - system_prompt = context.get("system_prompt", "") - request = state.get("request", "") - 
tool_specs = context.get("tool_registry", {}) or {} - planner_preview = state.get("planner_preview") - conversation = state.get("messages", []) or [] - history = state.get("tool_results", []) or [] - knowledge_context = context.get("knowledge_context", "") - observation = context.get("observation", "") - - convo_text = "\n".join(f"{msg['role']}: {msg['content']}" for msg in conversation) or "(暂无对话)" - history_text = _truncate_history(history) - tools_text = _format_tools_for_prompt(tool_specs) - preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else "" - - user_block = textwrap.dedent( - f""" - -**当前请求** -{request} - -{system_prompt} - -**额外观察** -{observation or '(无补充)'} - -**关联知识** -{knowledge_context or '(无相关知识)'} - -**可用工具** -{tools_text} - -**历史工具执行** -{history_text}{preview_section} - -**对话及工具记录** -{convo_text} - -请返回 JSON,格式如下: -- 若需要调用工具: - {{"action": "tool", "tool": "工具名", "args": {{...}}}} -- 若直接回复: - {{"action": "finish_text"}}""" - ).strip() - - return [ - SystemMessage(content="你负责规划下一步行动,请严格输出合法 JSON。"), - HumanMessage(content=user_block), - ] - - -def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessage]: - context = state.get("context", {}) or {} - system_prompt = context.get("system_prompt", "") - request = state.get("request", "") - knowledge_context = context.get("knowledge_context", "") - observation = context.get("observation", "") - conversation = state.get("messages", []) or [] - planner_preview = state.get("planner_preview") - conversation_block = "\n".join(f"{msg['role']}: {msg['content']}" for msg in conversation) or "(暂无对话)" - history_text = _truncate_history(state.get("tool_results", [])) - preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else "" - - user_block = textwrap.dedent( - f""" -**当前请求** -{request} - -{system_prompt} - -**关联知识** -{knowledge_context or '(无相关知识)'} - -**其他观察** -{observation or '(无补充)'} - -**工具执行摘要** -{history_text}{preview_section} - -**对话及工具记录** 
-{conversation_block}""" - ).strip() - - return [ - SystemMessage(content="你是最终回复的口播助手,请用中文自然表达。"), - HumanMessage(content=user_block), - ] - - -def _call_planner_llm(state: AgentState) -> Dict[str, Any]: - response = llm.invoke(_build_planner_messages(state)) - content = getattr(response, "content", None) - if not isinstance(content, str): - raise RuntimeError("规划器返回内容异常,未获得字符串。") - trimmed = content.strip() - try: - decision = json.loads(trimmed) - except json.JSONDecodeError as exc: - raise RuntimeError(f"规划器返回的 JSON 无法解析: {trimmed}") from exc - decision.setdefault("_raw", trimmed) - return decision - - -def _plan_next_action(state: AgentState) -> AgentState: - context = state.get("context", {}) or {} - audit_log = list(state.get("audit_log", [])) - history = state.get("tool_results", []) or [] - max_steps = state.get("max_steps", 12) - if len(history) >= max_steps: - audit_log.append("规划器:超过最大步数,终止流程。") - return { - "status": "failed", - "audit_log": audit_log, - "error": "工具调用步数超限", - "context": context, - } - - decision = _call_planner_llm(state) - audit_log.append(f"规划器:决策 -> {decision.get('_raw', decision)}") - - action = decision.get("action") - if action == "tool": - tool_name = decision.get("tool") - tool_registry: Dict[str, WorkflowToolSpec] = context.get("tool_registry", {}) - if tool_name not in tool_registry: - audit_log.append(f"规划器:未知工具 {tool_name}") - return { - "status": "failed", - "audit_log": audit_log, - "error": f"未知工具 {tool_name}", - "context": context, - } - args = decision.get("args") or {} - - if history: - last_entry = history[-1] - last_call = last_entry.get("call", {}) or {} - if ( - last_entry.get("success") - and last_call.get("name") == tool_name - and (last_call.get("args") or {}) == args - and last_entry.get("output") - ): - recent_attempts = sum( - 1 - for item in reversed(history) - if item.get("call", {}).get("name") == tool_name - ) - if recent_attempts >= 1: - audit_log.append( - "规划器:检测到工具重复调用,使用最新结果产出最终回复。" - ) - 
final_messages = _build_final_messages(state) - preview = last_entry.get("output") - return { - "status": "completed", - "planner_preview": preview, - "final_response": None, - "final_messages": final_messages, - "audit_log": audit_log, - "context": context, - } - return { - "next_action": {"name": tool_name, "args": args}, - "status": "needs_tool", - "audit_log": audit_log, - "context": context, - } - - if action in {"finish", "finish_text"}: - preview = decision.get("message") - final_messages = _build_final_messages(state) - audit_log.append("规划器:任务完成,准备输出最终回复。") - return { - "status": "completed", - "planner_preview": preview, - "final_response": preview if action == "finish" else None, - "final_messages": final_messages, - "audit_log": audit_log, - "context": context, - } - - raise RuntimeError(f"未知的规划器决策: {decision}") - - -def _execute_tool(state: AgentState) -> AgentState: - context = dict(state.get("context", {}) or {}) - action = state.get("next_action") - if not action: - return { - "status": "failed", - "error": "缺少要执行的工具指令", - "context": context, - } - - history = list(state.get("tool_results", []) or []) - audit_log = list(state.get("audit_log", []) or []) - conversation = list(state.get("messages", []) or []) - - name = action.get("name") - args = action.get("args", {}) - tool_registry: Dict[str, WorkflowToolSpec] = context.get("tool_registry", {}) - spec = tool_registry.get(name) - if not spec: - return { - "status": "failed", - "error": f"未知工具 {name}", - "context": context, - } - - attempts = sum(1 for item in history if item.get("call", {}).get("name") == name) - success, output, error = spec.executor(args, attempts) - result: ToolResult = { - "call": {"name": name, "args": args}, - "success": success, - "output": output, - "error": error, - "attempt": attempts + 1, - } - history.append(result) - audit_log.append(f"执行器:{name} 第 {result['attempt']} 次 -> {'成功' if success else '失败'}") - - message_lines = [ - f"[TOOL] {name} {'成功' if success else 
'失败'}。", - ] - if output: - message_lines.append(f"[TOOL] 输出:{_truncate_text(output, 200)}") - if error: - message_lines.append(f"[TOOL] 错误:{_truncate_text(error, 200)}") - conversation.append({"role": "assistant", "content": "\n".join(message_lines)}) - - return { - "tool_results": history, - "messages": conversation, - "next_action": None, - "audit_log": audit_log, - "status": "planning", - "error": error if not success else None, - "context": context, - } - - -def _route_decision(state: AgentState) -> str: - return "call_tool" if state.get("status") == "needs_tool" else "end" - - -def _build_workflow_app() -> StateGraph: - graph = StateGraph(AgentState) - graph.add_node("plan_next", _plan_next_action) - graph.add_node("call_tool", _execute_tool) - graph.add_edge(START, "plan_next") - graph.add_conditional_edges( - "plan_next", - _route_decision, - { - "call_tool": "call_tool", - "end": END, - }, - ) - graph.add_edge("call_tool", "plan_next") - return graph.compile() - - -_WORKFLOW_APP = _build_workflow_app() - -# 新增:本地知识库相关函数 -def read_doc_file(file_path): - """ - 读取doc文件内容 - - 参数: - file_path: doc文件路径 - - 返回: - str: 文档内容 - """ - try: - # 方法1: 使用 win32com.client(Windows系统,推荐用于.doc文件) - if WIN32COM_AVAILABLE: - word = None - doc = None - try: - import pythoncom - pythoncom.CoInitialize() # 初始化COM组件 - - word = win32com.client.Dispatch("Word.Application") - word.Visible = False - doc = word.Documents.Open(file_path) - content = doc.Content.Text - - # 先保存内容,再尝试关闭 - if content and content.strip(): - try: - doc.Close() - word.Quit() - except Exception as close_e: - util.log(1, f"关闭Word应用程序时出错: {str(close_e)},但内容已成功提取") - - try: - pythoncom.CoUninitialize() # 清理COM组件 - except: - pass - - return content.strip() - - except Exception as e: - util.log(1, f"使用 win32com 读取 .doc 文件失败: {str(e)}") - finally: - # 确保资源被释放 - try: - if doc: - doc.Close() - except: - pass - try: - if word: - word.Quit() - except: - pass - try: - pythoncom.CoUninitialize() - except: - pass - - # 方法2: 
简单的二进制文本提取(备选方案) - try: - with open(file_path, 'rb') as f: - raw_data = f.read() - # 尝试提取可打印的文本 - text_parts = [] - current_text = "" - - for byte in raw_data: - char = chr(byte) if 32 <= byte <= 126 or byte in [9, 10, 13] else None - if char: - current_text += char - else: - if len(current_text) > 3: # 只保留长度大于3的文本片段 - text_parts.append(current_text.strip()) - current_text = "" - - if len(current_text) > 3: - text_parts.append(current_text.strip()) - - # 过滤和清理文本 - filtered_parts = [] - for part in text_parts: - # 移除过多的重复字符和无意义的片段 - if (len(part) > 5 and - not part.startswith('Microsoft') and - not all(c in '0123456789-_.' for c in part) and - len(set(part)) > 3): # 字符种类要多样 - filtered_parts.append(part) - - if filtered_parts: - return '\n'.join(filtered_parts) - - except Exception as e: - util.log(1, f"使用二进制方法读取 .doc 文件失败: {str(e)}") - - util.log(1, f"无法读取 .doc 文件 {file_path},建议转换为 .docx 格式") - return "" - - except Exception as e: - util.log(1, f"读取doc文件 {file_path} 时出错: {str(e)}") - return "" - -def read_docx_file(file_path): - """ - 读取docx文件内容 - - 参数: - file_path: docx文件路径 - - 返回: - str: 文档内容 - """ - try: - doc = docx.Document(file_path) - content = [] - - for element in doc.element.body: - if isinstance(element, CT_P): - paragraph = Paragraph(element, doc) - if paragraph.text.strip(): - content.append(paragraph.text.strip()) - elif isinstance(element, CT_Tbl): - table = Table(element, doc) - for row in table.rows: - row_text = [] - for cell in row.cells: - if cell.text.strip(): - row_text.append(cell.text.strip()) - if row_text: - content.append(" | ".join(row_text)) - - return "\n".join(content) - except Exception as e: - util.log(1, f"读取docx文件 {file_path} 时出错: {str(e)}") - return "" - -def read_pptx_file(file_path): - """ - 读取pptx文件内容 - - 参数: - file_path: pptx文件路径 - - 返回: - str: 演示文稿内容 - """ - if not PPTX_AVAILABLE: - util.log(1, "python-pptx 库未安装,无法读取 PowerPoint 文件") - return "" - - try: - prs = Presentation(file_path) - content = [] - - for i, slide in 
enumerate(prs.slides): - slide_content = [f"第{i+1}页:"] - - for shape in slide.shapes: - if hasattr(shape, "text") and shape.text.strip(): - slide_content.append(shape.text.strip()) - - if len(slide_content) > 1: # 有内容才添加 - content.append("\n".join(slide_content)) - - return "\n\n".join(content) - except Exception as e: - util.log(1, f"读取pptx文件 {file_path} 时出错: {str(e)}") - return "" - -def load_local_knowledge_base(): - """ - 加载本地知识库内容 - - 返回: - dict: 文件名到内容的映射 - """ - knowledge_base = {} - - # 获取llm/data目录路径 - current_dir = os.path.dirname(os.path.abspath(__file__)) - data_dir = os.path.join(current_dir, "data") - - if not os.path.exists(data_dir): - util.log(1, f"知识库目录不存在: {data_dir}") - return knowledge_base - - # 遍历data目录中的文件 - for file_path in Path(data_dir).iterdir(): - if not file_path.is_file(): - continue - - file_name = file_path.name - file_extension = file_path.suffix.lower() - - try: - if file_extension == '.docx': - content = read_docx_file(str(file_path)) - elif file_extension == '.doc': - content = read_doc_file(str(file_path)) - elif file_extension == '.pptx': - content = read_pptx_file(str(file_path)) - else: - # 尝试作为文本文件读取 - try: - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - except UnicodeDecodeError: - try: - with open(file_path, 'r', encoding='gbk') as f: - content = f.read() - except UnicodeDecodeError: - util.log(1, f"无法解码文件: {file_name}") - continue - - if content.strip(): - knowledge_base[file_name] = content - util.log(1, f"成功加载知识库文件: {file_name} ({len(content)} 字符)") - - except Exception as e: - util.log(1, f"加载知识库文件 {file_name} 时出错: {str(e)}") - - return knowledge_base - -def search_knowledge_base(query, knowledge_base, max_results=3): - """ - 在知识库中搜索相关内容 - - 参数: - query: 查询内容 - knowledge_base: 知识库字典 - max_results: 最大返回结果数 - - 返回: - list: 相关内容列表 - """ - if not knowledge_base: - return [] - - results = [] - query_lower = query.lower() - - # 搜索关键词 - query_keywords = re.findall(r'\w+', query_lower) - - for 
file_name, content in knowledge_base.items(): - content_lower = content.lower() - - # 计算匹配度 - score = 0 - matched_sentences = [] - - # 按句子分割内容 - sentences = re.split(r'[。!?\n]', content) - - for sentence in sentences: - if not sentence.strip(): - continue - - sentence_lower = sentence.lower() - sentence_score = 0 - - # 计算关键词匹配度 - for keyword in query_keywords: - if keyword in sentence_lower: - sentence_score += 1 - - # 如果句子有匹配,记录 - if sentence_score > 0: - matched_sentences.append((sentence.strip(), sentence_score)) - score += sentence_score - - # 如果有匹配的内容 - if score > 0: - # 按匹配度排序句子 - matched_sentences.sort(key=lambda x: x[1], reverse=True) - - # 取前几个最相关的句子 - relevant_sentences = [sent[0] for sent in matched_sentences[:5] if sent[0]] - - if relevant_sentences: - results.append({ - 'file_name': file_name, - 'score': score, - 'content': '\n'.join(relevant_sentences) - }) - - # 按匹配度排序 - results.sort(key=lambda x: x['score'], reverse=True) - - return results[:max_results] - -# 全局知识库缓存 -_knowledge_base_cache = None -_knowledge_base_load_time = None -_knowledge_base_file_times = {} # 存储文件的最后修改时间 - -def check_knowledge_base_changes(): - """ - 检查知识库文件是否有变化 - - 返回: - bool: 如果有文件变化返回True,否则返回False - """ - global _knowledge_base_file_times - - # 获取llm/data目录路径 - current_dir = os.path.dirname(os.path.abspath(__file__)) - data_dir = os.path.join(current_dir, "data") - - if not os.path.exists(data_dir): - return False - - current_file_times = {} - - # 遍历data目录中的文件 - for file_path in Path(data_dir).iterdir(): - if not file_path.is_file(): - continue - - file_name = file_path.name - file_extension = file_path.suffix.lower() - - # 只检查支持的文件格式 - if file_extension in ['.docx', '.doc', '.pptx', '.txt'] or file_extension == '': - try: - mtime = os.path.getmtime(str(file_path)) - current_file_times[file_name] = mtime - except OSError: - continue - - # 检查是否有变化 - if not _knowledge_base_file_times: - # 第一次检查,保存文件时间 - _knowledge_base_file_times = current_file_times - return True - - # 
比较文件时间 - if set(current_file_times.keys()) != set(_knowledge_base_file_times.keys()): - # 文件数量发生变化 - _knowledge_base_file_times = current_file_times - return True - - for file_name, mtime in current_file_times.items(): - if file_name not in _knowledge_base_file_times or _knowledge_base_file_times[file_name] != mtime: - # 文件被修改 - _knowledge_base_file_times = current_file_times - return True - - return False - -def init_knowledge_base(): - """ - 初始化知识库,在系统启动时调用 - """ - global _knowledge_base_cache, _knowledge_base_load_time - - util.log(1, "初始化本地知识库...") - _knowledge_base_cache = load_local_knowledge_base() - _knowledge_base_load_time = time.time() - - # 初始化文件修改时间跟踪 - check_knowledge_base_changes() - - util.log(1, f"知识库初始化完成,共 {len(_knowledge_base_cache)} 个文件") - -def get_knowledge_base(): - """ - 获取知识库,使用缓存机制 - - 返回: - dict: 知识库内容 - """ - global _knowledge_base_cache, _knowledge_base_load_time - - # 如果缓存为空,先初始化 - if _knowledge_base_cache is None: - init_knowledge_base() - return _knowledge_base_cache - - # 检查文件是否有变化 - if check_knowledge_base_changes(): - util.log(1, "检测到知识库文件变化,正在重新加载...") - _knowledge_base_cache = load_local_knowledge_base() - _knowledge_base_load_time = time.time() - util.log(1, f"知识库重新加载完成,共 {len(_knowledge_base_cache)} 个文件") - - return _knowledge_base_cache - - -def question(content, username, observation=None): - """处理用户提问并返回回复。""" - global current_username - current_username = username - full_response_text = "" - accumulated_text = "" - default_punctuations = [",", ".", "!", "?", "\n", "\uFF0C", "\u3002", "\uFF01", "\uFF1F"] - is_first_sentence = True - - from core import stream_manager - sm = stream_manager.new_instance() - conversation_id = sm.get_conversation_id(username) - - # 记忆系统已在全局初始化,无需创建agent - # 直接从配置文件获取人物设定 - agent_desc = { - "first_name": cfg.config["attribute"]["name"], - "last_name": "", - "age": cfg.config["attribute"]["age"], - "sex": cfg.config["attribute"]["gender"], - "additional": cfg.config["attribute"]["additional"], - 
"birthplace": cfg.config["attribute"]["birth"], - "position": cfg.config["attribute"]["position"], - "zodiac": cfg.config["attribute"]["zodiac"], - "constellation": cfg.config["attribute"]["constellation"], - "contact": cfg.config["attribute"]["contact"], - "voice": cfg.config["attribute"]["voice"], - "goal": cfg.config["attribute"]["goal"], - "occupation": cfg.config["attribute"]["job"], - "current_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - } - - # 使用新记忆系统处理用户消息 - # 一次性完成:入库 → 长期检索 → 短期检索 → 生成提示语 - short_term_records = [] - memory_prompt = "" - query_embedding = None - - try: - short_term_records, memory_prompt, query_embedding = memory_system.process_user_message( - content, user_id=username - ) - util.log(1, f"记忆检索成功,获取 {len(short_term_records)} 条相关记录") - except Exception as exc: - util.log(1, f"记忆检索失败: {exc}") - # 失败时使用空值,不影响后续流程 - short_term_records = [] - memory_prompt = "" - query_embedding = None - - knowledge_context = "" - try: - knowledge_base = get_knowledge_base() - if knowledge_base: - knowledge_results = search_knowledge_base(content, knowledge_base, max_results=3) - if knowledge_results: - parts = ["**本地知识库相关信息**:"] - for result in knowledge_results: - parts.append(f"来源文件:{result['file_name']}") - parts.append(result["content"]) - parts.append("") - knowledge_context = "\n".join(parts).strip() - util.log(1, f"找到 {len(knowledge_results)} 条相关知识库信息") - except Exception as exc: - util.log(1, f"搜索知识库时出错: {exc}") - - # 方案B:保留人设信息,补充记忆提示语 - # 1. 构建人设部分 - persona_prompt = f"""\n**角色设定**\n -- 名字:{agent_desc['first_name']} -- 性别:{agent_desc['sex']} -- 年龄:{agent_desc['age']} -- 职业:{agent_desc['occupation']} -- 出生地:{agent_desc['birthplace']} -- 星座:{agent_desc['constellation']} -- 生肖:{agent_desc['zodiac']} -- 联系方式:{agent_desc['contact']} -- 定位:{agent_desc['position']} -- 目标:{agent_desc['goal']} -- 补充信息:{agent_desc['additional']}\n - -""" - - # 2. 
合并人设和记忆提示语 - if memory_prompt: - system_prompt = memory_prompt + persona_prompt - else: - # 如果记忆系统返回空提示语,使用基础提示语 - system_prompt = persona_prompt + "请根据用户的问题,提供有帮助的回答。" - - try: - history_records = content_db.new_instance().get_recent_messages_by_user(username=username, limit=30) - except Exception as exc: - util.log(1, f"加载历史消息失败: {exc}") - history_records = [] - - messages_buffer: List[ConversationMessage] = [] - - def append_to_buffer(role: str, text_value: str) -> None: - if not text_value: - return - messages_buffer.append({"role": role, "content": text_value}) - if len(messages_buffer) > 60: - del messages_buffer[:-60] - - for msg_type, msg_text in history_records: - role = 'assistant' - if msg_type and msg_type.lower() in ('member', 'user'): - role = 'user' - append_to_buffer(role, msg_text) - - if ( - not messages_buffer - or messages_buffer[-1]['role'] != 'user' - or messages_buffer[-1]['content'] != content - ): - append_to_buffer('user', content) - - tool_registry: Dict[str, WorkflowToolSpec] = {} - try: - mcp_tools = get_mcp_tools() - except Exception as exc: - util.log(1, f"获取工具列表失败: {exc}") - mcp_tools = [] - for tool_def in mcp_tools: - spec = _build_workflow_tool_spec(tool_def) - if spec: - tool_registry[spec.name] = spec - - try: - from utils.stream_state_manager import get_state_manager as _get_state_manager - - state_mgr = _get_state_manager() - session_label = "workflow_agent" if tool_registry else "llm_stream" - if not state_mgr.is_session_active(username, conversation_id=conversation_id): - state_mgr.start_new_session(username, session_label, conversation_id=conversation_id) - except Exception: - state_mgr = None - - try: - from utils.stream_text_processor import get_processor - - processor = get_processor() - punctuation_list = getattr(processor, "punctuation_marks", default_punctuations) - except Exception: - processor = None - punctuation_list = default_punctuations - def write_sentence(text: str, *, force_first: bool = False, force_end: 
bool = False) -> None: - if text is None: - text = "" - if not isinstance(text, str): - text = str(text) - if not text and not force_end and not force_first: - return - marked_text = None - if state_mgr is not None: - try: - marked_text, _, _ = state_mgr.prepare_sentence( - username, - text, - force_first=force_first, - force_end=force_end, - conversation_id=conversation_id, - ) - except Exception: - marked_text = None - if marked_text is None: - prefix = "_" if force_first else "" - suffix = "_" if force_end else "" - marked_text = f"{prefix}{text}{suffix}" - stream_manager.new_instance().write_sentence(username, marked_text, conversation_id=conversation_id) - - def stream_response_chunks(chunks, prepend_text: str = "") -> None: - nonlocal accumulated_text, full_response_text, is_first_sentence - if prepend_text: - accumulated_text += prepend_text - full_response_text += prepend_text - for chunk in chunks: - if sm.should_stop_generation(username, conversation_id=conversation_id): - util.log(1, f"检测到停止标志,中断文本生成: {username}") - break - if isinstance(chunk, str): - flush_text = chunk - elif isinstance(chunk, dict): - flush_text = chunk.get("content") - else: - flush_text = getattr(chunk, "content", None) - if isinstance(flush_text, list): - flush_text = "".join(part if isinstance(part, str) else "" for part in flush_text) - if not flush_text: - continue - flush_text = str(flush_text) - accumulated_text += flush_text - full_response_text += flush_text - if len(accumulated_text) >= 20: - while True: - last_punct_pos = -1 - for punct in punctuation_list: - pos = accumulated_text.rfind(punct) - if pos > last_punct_pos: - last_punct_pos = pos - if last_punct_pos > 10: - sentence_text = accumulated_text[: last_punct_pos + 1] - write_sentence(sentence_text, force_first=is_first_sentence) - is_first_sentence = False - accumulated_text = accumulated_text[last_punct_pos + 1 :].lstrip() - else: - break - - def finalize_stream(force_end: bool = False) -> None: - nonlocal 
accumulated_text, is_first_sentence - if accumulated_text: - write_sentence(accumulated_text, force_first=is_first_sentence, force_end=force_end) - is_first_sentence = False - accumulated_text = "" - elif force_end: - if state_mgr is not None: - try: - session_info = state_mgr.get_session_info(username, conversation_id=conversation_id) - except Exception: - session_info = None - if not session_info or not session_info.get("is_end_sent", False): - write_sentence("", force_end=True) - else: - write_sentence("", force_end=True) - - def run_workflow(tool_registry: Dict[str, WorkflowToolSpec]) -> bool: - nonlocal accumulated_text, full_response_text, is_first_sentence, messages_buffer - - initial_state: AgentState = { - "request": content, - "messages": messages_buffer, - "tool_results": [], - "audit_log": [], - "status": "planning", - "max_steps": 30, - "context": { - "system_prompt": system_prompt, - "knowledge_context": knowledge_context, - "observation": observation, - "tool_registry": tool_registry, - }, - } - - config = {"configurable": {"thread_id": f"workflow-{username}-{conversation_id}"}} - workflow_app = _WORKFLOW_APP - is_agent_think_start = False - final_state: Optional[AgentState] = None - final_stream_done = False - - try: - for event in workflow_app.stream(initial_state, config=config, stream_mode="updates"): - if sm.should_stop_generation(username, conversation_id=conversation_id): - util.log(1, f"检测到停止标志,中断工作流生成: {username}") - break - step, state = next(iter(event.items())) - final_state = state - status = state.get("status") - - state_messages = state.get("messages") or [] - if state_messages and len(state_messages) > len(messages_buffer): - messages_buffer.extend(state_messages[len(messages_buffer):]) - if len(messages_buffer) > 60: - del messages_buffer[:-60] - - if step == "plan_next": - if status == "needs_tool": - next_action = state.get("next_action") or {} - tool_name = next_action.get("name") or "unknown_tool" - tool_args = 
next_action.get("args") or {} - audit_log = state.get("audit_log") or [] - decision_note = audit_log[-1] if audit_log else "" - if "->" in decision_note: - decision_note = decision_note.split("->", 1)[1].strip() - args_text = json.dumps(tool_args, ensure_ascii=False) - message_lines = [ - "[PLAN] Planner preparing to call a tool.", - f"[PLAN] Decision: {decision_note}" if decision_note else "[PLAN] Decision: (missing)", - f"[PLAN] Tool: {tool_name}", - f"[PLAN] Args: {args_text}", - ] - message = "\n".join(message_lines) + "\n" - if not is_agent_think_start: - message = "" + message - is_agent_think_start = True - write_sentence(message, force_first=is_first_sentence) - is_first_sentence = False - full_response_text += message - append_to_buffer('assistant', message.strip()) - elif status == "completed" and not final_stream_done: - closing = "" if is_agent_think_start else "" - final_messages = state.get("final_messages") - final_response = state.get("final_response") - success = False - if final_messages: - try: - stream_response_chunks(llm.stream(final_messages), prepend_text=closing) - success = True - except requests.exceptions.RequestException as stream_exc: - util.log(1, f"最终回复流式输出失败: {stream_exc}") - elif final_response: - stream_response_chunks([closing + final_response]) - success = True - elif closing: - accumulated_text += closing - full_response_text += closing - final_stream_done = success - is_agent_think_start = False - elif step == "call_tool": - history = state.get("tool_results") or [] - if history: - last = history[-1] - call_info = last.get("call", {}) or {} - tool_name = call_info.get("name") or "unknown_tool" - success = last.get("success", False) - status_text = "SUCCESS" if success else "FAILED" - args_text = json.dumps(call_info.get("args") or {}, ensure_ascii=False) - message_lines = [ - f"[TOOL] {tool_name} execution {status_text}.", - f"[TOOL] Args: {args_text}", - ] - if last.get("output"): - message_lines.append(f"[TOOL] Output: 
{_truncate_text(last['output'], 120)}") - if last.get("error"): - message_lines.append(f"[TOOL] Error: {last['error']}") - message = "\n".join(message_lines) + "\n" - write_sentence(message, force_first=is_first_sentence) - is_first_sentence = False - full_response_text += message - append_to_buffer('assistant', message.strip()) - elif step == "__end__": - break - except Exception as exc: - util.log(1, f"执行工具工作流时出错: {exc}") - if is_agent_think_start: - closing = "" - accumulated_text += closing - full_response_text += closing - return False - - if final_state is None: - if is_agent_think_start: - closing = "" - accumulated_text += closing - full_response_text += closing - return False - - if not final_stream_done and is_agent_think_start: - closing = "" - accumulated_text += closing - full_response_text += closing - util.log(1, f"工具工作流未能完成,状态: {final_state.get('status')}") - - final_state_messages = final_state.get("messages") if final_state else None - if final_state_messages and len(final_state_messages) > len(messages_buffer): - messages_buffer.extend(final_state_messages[len(messages_buffer):]) - if len(messages_buffer) > 60: - del messages_buffer[:-60] - - return final_stream_done - - def run_direct_llm() -> bool: - nonlocal full_response_text, accumulated_text, is_first_sentence, messages_buffer - try: - # 统一使用 _build_final_messages 构建消息,确保历史对话始终被包含 - summary_state: AgentState = { - "request": content, - "messages": messages_buffer, - "tool_results": [], - "planner_preview": None, - "context": { - "system_prompt": system_prompt, - "knowledge_context": knowledge_context, - "observation": observation, - }, - } - - final_messages = _build_final_messages(summary_state) - stream_response_chunks(llm.stream(final_messages)) - return True - except requests.exceptions.RequestException as exc: - util.log(1, f"请求失败: {exc}") - error_message = "抱歉,我现在太忙了,休息一会,请稍后再试。" - write_sentence(error_message, force_first=is_first_sentence) - is_first_sentence = False - 
full_response_text = error_message - accumulated_text = "" - return False - - workflow_success = False - if tool_registry: - workflow_success = run_workflow(tool_registry) - - if (not tool_registry or not workflow_success) and not sm.should_stop_generation(username, conversation_id=conversation_id): - run_direct_llm() - - if not sm.should_stop_generation(username, conversation_id=conversation_id): - finalize_stream(force_end=True) - - if state_mgr is not None: - try: - state_mgr.end_session(username, conversation_id=conversation_id) - except Exception: - pass - else: - try: - from utils.stream_state_manager import get_state_manager - - get_state_manager().end_session(username, conversation_id=conversation_id) - except Exception: - pass - - final_text = full_response_text.split("")[-1] if full_response_text else "" - - # 使用新记忆系统异步处理agent回复 - try: - import asyncio - - # 创建新的事件循环(在独立线程中运行) - def async_memory_task(): - """在独立线程中运行异步记忆存储""" - try: - # 创建新的事件循环 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # 运行异步任务 - loop.run_until_complete( - memory_system.process_agent_reply_async( - final_text, - user_id=username, - current_user_content=content - ) - ) - - # 关闭循环 - loop.close() - except Exception as e: - util.log(1, f"异步记忆存储失败: {e}") - - # 启动独立线程执行异步任务 - MyThread(target=async_memory_task).start() - util.log(1, f"异步记忆存储任务已启动") - - except Exception as exc: - util.log(1, f"异步记忆处理启动失败: {exc}") - - return final_text -def clear_agent_memory(username=None): - """ - 清除指定用户的记忆(使用新记忆系统) - - Args: - username: 用户名,如果为None则清除当前用户的记忆 - - Returns: - bool: 是否清除成功 - """ - global memory_system, current_username - - try: - # 确定要清除的用户ID - user_id = username if username else current_username - if not user_id: - user_id = "User" # 默认用户 - - util.log(1, f"正在清除用户 {user_id} 的记忆...") - - # 调用新记忆系统的清除方法 - result = memory_system.clear_user_history(user_id=user_id) - - util.log(1, f"用户 {user_id} 的记忆清除完成: {result}") - return True - - except Exception as e: - util.log(1, 
f"清除用户记忆时出错: {str(e)}") - return False - -def get_mcp_tools() -> List[Dict[str, Any]]: - """ - 从共享缓存获取所有可用且已启用的MCP工具列表。 - """ - try: - tools = mcp_tool_registry.get_enabled_tools() - return tools or [] - except Exception as e: - util.log(1, f"获取工具列表出错:{e}") - return [] - - -if __name__ == "__main__": - # 记忆系统已在模块加载时初始化,无需再次调用 - for _ in range(3): - query = "Who is Fay?" - response = question(query, "User") - print(f"Q: {query}") - print(f"A: {response}") - time.sleep(1) +# -*- coding: utf-8 -*- +import os +import json +import time +import threading +import requests +import datetime +import schedule +import textwrap +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Literal, Optional, TypedDict, Tuple +from collections.abc import Mapping, Sequence +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage, SystemMessage +from langgraph.graph import END, START, StateGraph + +# 新增:本地知识库相关导入 +import re +from pathlib import Path +import docx +from docx.document import Document +from docx.oxml.table import CT_Tbl +from docx.oxml.text.paragraph import CT_P +from docx.table import _Cell, Table +from docx.text.paragraph import Paragraph +try: + from pptx import Presentation + PPTX_AVAILABLE = True +except ImportError: + PPTX_AVAILABLE = False + +# 用于处理 .doc 文件的库 +try: + import win32com.client + WIN32COM_AVAILABLE = True +except ImportError: + WIN32COM_AVAILABLE = False + +from utils import util +import utils.config_util as cfg +from urllib3.exceptions import InsecureRequestWarning +from scheduler.thread_manager import MyThread +from core import content_db +from core import stream_manager +from faymcp import tool_registry as mcp_tool_registry + +# 新增:长短期记忆系统相关导入 +from bionicmemory.core.chroma_service import ChromaService +from bionicmemory.core.memory_system import LongShortTermMemorySystem, SourceType + +# 加载配置 +cfg.load_config() + +# 禁用不安全请求警告 
+requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning) + +# 记忆系统全局变量 +chroma_service = None # ChromaDB服务实例 +memory_system = None # 长短期记忆系统实例 +memory_system_lock = threading.RLock() # 保护记忆系统的锁 + +# 当前会话用户名(保留,用于兼容性) +current_username = None + +llm = ChatOpenAI( + model=cfg.gpt_model_engine, + base_url=cfg.gpt_base_url, + api_key=cfg.key_gpt_api_key, + streaming=True + ) + + +def init_memory_system(): + """ + 初始化长短期记忆系统 + + Returns: + bool: 是否初始化成功 + """ + global chroma_service, memory_system + + try: + util.log(1, "正在初始化记忆系统...") + + # 启动时检查并清除数据库(如果存在清除标记) + if ChromaService.check_and_clear_database_on_startup(): + util.log(1, "检测到记忆清除标记,已清除ChromaDB数据库") + + # 初始化ChromaDB服务 + chroma_service = ChromaService() + if not chroma_service: + util.log(1, "ChromaDB服务初始化失败") + return False + + # 初始化长短期记忆系统 + memory_system = LongShortTermMemorySystem( + chroma_service=chroma_service, + summary_threshold=500, + max_retrieval_results=10, + cluster_multiplier=3, + retrieval_multiplier=2 + ) + + util.log(1, "记忆系统初始化成功") + return True + + except Exception as e: + util.log(1, f"记忆系统初始化失败: {e}") + return False + + +# 在模块加载时初始化记忆系统 +init_memory_system() + + +@dataclass +class WorkflowToolSpec: + name: str + description: str + schema: Dict[str, Any] + executor: Callable[[Dict[str, Any], int], Tuple[bool, Optional[str], Optional[str]]] + example_args: Dict[str, Any] + + +class ToolCall(TypedDict): + name: str + args: Dict[str, Any] + + +class ToolResult(TypedDict, total=False): + call: ToolCall + success: bool + output: Optional[str] + error: Optional[str] + attempt: int + + +class ConversationMessage(TypedDict): + role: Literal["user", "assistant"] + content: str + + +class AgentState(TypedDict, total=False): + request: str + messages: List[ConversationMessage] + tool_results: List[ToolResult] + next_action: Optional[ToolCall] + status: Literal["planning", "needs_tool", "completed", "failed"] + final_response: Optional[str] + final_messages: 
Optional[List[SystemMessage | HumanMessage]] + planner_preview: Optional[str] + audit_log: List[str] + context: Dict[str, Any] + error: Optional[str] + max_steps: int + + +def _truncate_text(text: Any, limit: int = 400) -> str: + text_str = "" if text is None else str(text) + if len(text_str) <= limit: + return text_str + return text_str[:limit] + "..." + + +def _extract_text_from_result(value: Any, *, depth: int = 0) -> List[str]: + """Try to pull human-readable text snippets from tool results.""" + if value is None: + return [] + if depth > 6: + return [] + if isinstance(value, (str, int, float, bool)): + text = str(value).strip() + return [text] if text else [] + if isinstance(value, Mapping): + # Prefer explicit text/content fields + if "text" in value and not isinstance(value["text"], (dict, list, tuple)): + text = str(value["text"]).strip() + return [text] if text else [] + if "content" in value: + segments: List[str] = [] + for item in value.get("content", []): + segments.extend(_extract_text_from_result(item, depth=depth + 1)) + if segments: + return segments + segments = [] + for key, item in value.items(): + if key in {"meta", "annotations", "uid", "id", "messageId"}: + continue + segments.extend(_extract_text_from_result(item, depth=depth + 1)) + return segments + if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)): + segments: List[str] = [] + for item in value: + segments.extend(_extract_text_from_result(item, depth=depth + 1)) + return segments + if hasattr(value, "text") and not callable(getattr(value, "text")): + text = str(getattr(value, "text", "")).strip() + return [text] if text else [] + if hasattr(value, "__dict__"): + return _extract_text_from_result(vars(value), depth=depth + 1) + text = str(value).strip() + return [text] if text else [] + + +def _normalize_tool_output(result: Any) -> str: + """Convert structured tool output to a concise human-readable string.""" + if result is None: + return "" + segments = 
_extract_text_from_result(result) + if segments: + cleaned = [segment for segment in segments if segment] + if cleaned: + return "\n".join(dict.fromkeys(cleaned)) + try: + return json.dumps(result, ensure_ascii=False, default=lambda o: getattr(o, "__dict__", str(o))) + except TypeError: + return str(result) + + +def _truncate_history(history: List[ToolResult], limit: int = 6) -> str: + if not history: + return "(暂无)" + lines: List[str] = [] + for item in history[-limit:]: + call = item.get("call", {}) + name = call.get("name", "未知工具") + attempt = item.get("attempt", 0) + success = item.get("success", False) + status = "成功" if success else "失败" + lines.append(f"- {name} 第 {attempt} 次 → {status}") + if item.get("output"): + lines.append(" 输出:" + _truncate_text(item["output"], 200)) + if item.get("error"): + lines.append(" 错误:" + _truncate_text(item["error"], 200)) + return "\n".join(lines) + + +def _format_schema_parameters(schema: Dict[str, Any]) -> List[str]: + if not schema: + return [" - 无参数"] + props = schema.get("properties") or {} + if not props: + return [" - 无参数"] + required = set(schema.get("required") or []) + lines: List[str] = [] + for field, meta in props.items(): + meta = meta or {} + field_type = meta.get("type", "string") + desc = (meta.get("description") or "").strip() + req_label = "必填" if field in required else "可选" + line = f" - {field} ({field_type},{req_label})" + if desc: + line += f":{desc}" + lines.append(line) + return lines or [" - 无参数"] + + +def _generate_example_args(schema: Dict[str, Any]) -> Dict[str, Any]: + example: Dict[str, Any] = {} + if not schema: + return example + props = schema.get("properties") or {} + for field, meta in props.items(): + meta = meta or {} + if "default" in meta: + example[field] = meta["default"] + continue + enum_values = meta.get("enum") or [] + if enum_values: + example[field] = enum_values[0] + continue + field_type = meta.get("type", "string") + if field_type in ("number", "integer"): + example[field] = 
0 + elif field_type == "boolean": + example[field] = True + elif field_type == "array": + example[field] = [] + elif field_type == "object": + example[field] = {} + else: + description_hint = meta.get("description") or "" + example[field] = description_hint or "" + return example + + +def _format_tool_block(spec: WorkflowToolSpec) -> str: + param_lines = _format_schema_parameters(spec.schema) + example = json.dumps(spec.example_args, ensure_ascii=False) if spec.example_args else "{}" + lines = [ + f"- 工具名:{spec.name}", + f" 功能:{spec.description or '暂无描述'}", + " 参数:", + *param_lines, + f" 示例:{example}", + ] + return "\n".join(lines) + + +def _build_workflow_tool_spec(tool_def: Dict[str, Any]) -> Optional[WorkflowToolSpec]: + if not tool_def: + return None + name = tool_def.get("name") + if not name: + return None + description = tool_def.get("description") or tool_def.get("summary") or "" + schema = tool_def.get("inputSchema") or {} + example_args = _generate_example_args(schema) + + def _executor(args: Dict[str, Any], attempt: int) -> Tuple[bool, Optional[str], Optional[str]]: + try: + resp = requests.post( + f"http://127.0.0.1:5010/api/mcp/tools/{name}", + json=args, + timeout=120, + ) + resp.raise_for_status() + data = resp.json() + except Exception as exc: + util.log(1, f"调用工具 {name} 异常: {exc}") + return False, None, str(exc) + + if data.get("success"): + result = data.get("result") + output = _normalize_tool_output(result) + return True, output, None + + error_msg = data.get("error") or "未知错误" + util.log(1, f"调用工具 {name} 失败: {error_msg}") + return False, None, error_msg + + return WorkflowToolSpec( + name=name, + description=description, + schema=schema, + executor=_executor, + example_args=example_args, + ) + + +def _format_tools_for_prompt(tool_specs: Dict[str, WorkflowToolSpec]) -> str: + if not tool_specs: + return "(暂无可用工具)" + return "\n".join(_format_tool_block(spec) for spec in tool_specs.values()) + + +def _build_planner_messages(state: AgentState) -> 
List[SystemMessage | HumanMessage]: + context = state.get("context", {}) or {} + system_prompt = context.get("system_prompt", "") + request = state.get("request", "") + tool_specs = context.get("tool_registry", {}) or {} + planner_preview = state.get("planner_preview") + conversation = state.get("messages", []) or [] + history = state.get("tool_results", []) or [] + observation = context.get("observation", "") + + convo_text = "\n".join(f"{msg['role']}: {msg['content']}" for msg in conversation) or "(暂无对话)" + history_text = _truncate_history(history) + tools_text = _format_tools_for_prompt(tool_specs) + preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else "" + + user_block = textwrap.dedent( + f""" + +**当前请求** +{request} + +{system_prompt} + +**额外观察** +{observation or '(无补充)'} + +**可用工具** +{tools_text} + +**历史工具执行** +{history_text}{preview_section} + +**对话及工具记录** +{convo_text} + +请返回 JSON,格式如下: +- 若需要调用工具: + {{"action": "tool", "tool": "工具名", "args": {{...}}}} +- 若直接回复: + {{"action": "finish_text"}}""" + ).strip() + + return [ + SystemMessage(content="你负责规划下一步行动,请严格输出合法 JSON。"), + HumanMessage(content=user_block), + ] + + +def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessage]: + context = state.get("context", {}) or {} + system_prompt = context.get("system_prompt", "") + request = state.get("request", "") + observation = context.get("observation", "") + conversation = state.get("messages", []) or [] + planner_preview = state.get("planner_preview") + conversation_block = "\n".join(f"{msg['role']}: {msg['content']}" for msg in conversation) or "(暂无对话)" + history_text = _truncate_history(state.get("tool_results", [])) + preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else "" + + user_block = textwrap.dedent( + f""" +**当前请求** +{request} + +{system_prompt} + +**其他观察** +{observation or '(无补充)'} + +**工具执行摘要** +{history_text}{preview_section} + +**对话及工具记录** +{conversation_block}""" + ).strip() + + return 
[ + SystemMessage(content="你是最终回复的口播助手,请用中文自然表达。"), + HumanMessage(content=user_block), + ] + + +def _call_planner_llm(state: AgentState) -> Dict[str, Any]: + response = llm.invoke(_build_planner_messages(state)) + content = getattr(response, "content", None) + if not isinstance(content, str): + raise RuntimeError("规划器返回内容异常,未获得字符串。") + trimmed = content.strip() + try: + decision = json.loads(trimmed) + except json.JSONDecodeError as exc: + raise RuntimeError(f"规划器返回的 JSON 无法解析: {trimmed}") from exc + decision.setdefault("_raw", trimmed) + return decision + + +def _plan_next_action(state: AgentState) -> AgentState: + context = state.get("context", {}) or {} + audit_log = list(state.get("audit_log", [])) + history = state.get("tool_results", []) or [] + max_steps = state.get("max_steps", 12) + if len(history) >= max_steps: + audit_log.append("规划器:超过最大步数,终止流程。") + return { + "status": "failed", + "audit_log": audit_log, + "error": "工具调用步数超限", + "context": context, + } + + decision = _call_planner_llm(state) + audit_log.append(f"规划器:决策 -> {decision.get('_raw', decision)}") + + action = decision.get("action") + if action == "tool": + tool_name = decision.get("tool") + tool_registry: Dict[str, WorkflowToolSpec] = context.get("tool_registry", {}) + if tool_name not in tool_registry: + audit_log.append(f"规划器:未知工具 {tool_name}") + return { + "status": "failed", + "audit_log": audit_log, + "error": f"未知工具 {tool_name}", + "context": context, + } + args = decision.get("args") or {} + + if history: + last_entry = history[-1] + last_call = last_entry.get("call", {}) or {} + if ( + last_entry.get("success") + and last_call.get("name") == tool_name + and (last_call.get("args") or {}) == args + and last_entry.get("output") + ): + recent_attempts = sum( + 1 + for item in reversed(history) + if item.get("call", {}).get("name") == tool_name + ) + if recent_attempts >= 1: + audit_log.append( + "规划器:检测到工具重复调用,使用最新结果产出最终回复。" + ) + final_messages = _build_final_messages(state) + preview = 
last_entry.get("output") + return { + "status": "completed", + "planner_preview": preview, + "final_response": None, + "final_messages": final_messages, + "audit_log": audit_log, + "context": context, + } + return { + "next_action": {"name": tool_name, "args": args}, + "status": "needs_tool", + "audit_log": audit_log, + "context": context, + } + + if action in {"finish", "finish_text"}: + preview = decision.get("message") + final_messages = _build_final_messages(state) + audit_log.append("规划器:任务完成,准备输出最终回复。") + return { + "status": "completed", + "planner_preview": preview, + "final_response": preview if action == "finish" else None, + "final_messages": final_messages, + "audit_log": audit_log, + "context": context, + } + + raise RuntimeError(f"未知的规划器决策: {decision}") + + +def _execute_tool(state: AgentState) -> AgentState: + context = dict(state.get("context", {}) or {}) + action = state.get("next_action") + if not action: + return { + "status": "failed", + "error": "缺少要执行的工具指令", + "context": context, + } + + history = list(state.get("tool_results", []) or []) + audit_log = list(state.get("audit_log", []) or []) + conversation = list(state.get("messages", []) or []) + + name = action.get("name") + args = action.get("args", {}) + tool_registry: Dict[str, WorkflowToolSpec] = context.get("tool_registry", {}) + spec = tool_registry.get(name) + if not spec: + return { + "status": "failed", + "error": f"未知工具 {name}", + "context": context, + } + + attempts = sum(1 for item in history if item.get("call", {}).get("name") == name) + success, output, error = spec.executor(args, attempts) + result: ToolResult = { + "call": {"name": name, "args": args}, + "success": success, + "output": output, + "error": error, + "attempt": attempts + 1, + } + history.append(result) + audit_log.append(f"执行器:{name} 第 {result['attempt']} 次 -> {'成功' if success else '失败'}") + + message_lines = [ + f"[TOOL] {name} {'成功' if success else '失败'}。", + ] + if output: + message_lines.append(f"[TOOL] 
输出:{_truncate_text(output, 200)}") + if error: + message_lines.append(f"[TOOL] 错误:{_truncate_text(error, 200)}") + conversation.append({"role": "assistant", "content": "\n".join(message_lines)}) + + return { + "tool_results": history, + "messages": conversation, + "next_action": None, + "audit_log": audit_log, + "status": "planning", + "error": error if not success else None, + "context": context, + } + + +def _route_decision(state: AgentState) -> str: + return "call_tool" if state.get("status") == "needs_tool" else "end" + + +def _build_workflow_app() -> StateGraph: + graph = StateGraph(AgentState) + graph.add_node("plan_next", _plan_next_action) + graph.add_node("call_tool", _execute_tool) + graph.add_edge(START, "plan_next") + graph.add_conditional_edges( + "plan_next", + _route_decision, + { + "call_tool": "call_tool", + "end": END, + }, + ) + graph.add_edge("call_tool", "plan_next") + return graph.compile() + + +_WORKFLOW_APP = _build_workflow_app() + +# 新增:本地知识库相关函数 +def read_doc_file(file_path): + """ + 读取doc文件内容 + + 参数: + file_path: doc文件路径 + + 返回: + str: 文档内容 + """ + try: + # 方法1: 使用 win32com.client(Windows系统,推荐用于.doc文件) + if WIN32COM_AVAILABLE: + word = None + doc = None + try: + import pythoncom + pythoncom.CoInitialize() # 初始化COM组件 + + word = win32com.client.Dispatch("Word.Application") + word.Visible = False + doc = word.Documents.Open(file_path) + content = doc.Content.Text + + # 先保存内容,再尝试关闭 + if content and content.strip(): + try: + doc.Close() + word.Quit() + except Exception as close_e: + util.log(1, f"关闭Word应用程序时出错: {str(close_e)},但内容已成功提取") + + try: + pythoncom.CoUninitialize() # 清理COM组件 + except: + pass + + return content.strip() + + except Exception as e: + util.log(1, f"使用 win32com 读取 .doc 文件失败: {str(e)}") + finally: + # 确保资源被释放 + try: + if doc: + doc.Close() + except: + pass + try: + if word: + word.Quit() + except: + pass + try: + pythoncom.CoUninitialize() + except: + pass + + # 方法2: 简单的二进制文本提取(备选方案) + try: + with open(file_path, 'rb') as 
f: + raw_data = f.read() + # 尝试提取可打印的文本 + text_parts = [] + current_text = "" + + for byte in raw_data: + char = chr(byte) if 32 <= byte <= 126 or byte in [9, 10, 13] else None + if char: + current_text += char + else: + if len(current_text) > 3: # 只保留长度大于3的文本片段 + text_parts.append(current_text.strip()) + current_text = "" + + if len(current_text) > 3: + text_parts.append(current_text.strip()) + + # 过滤和清理文本 + filtered_parts = [] + for part in text_parts: + # 移除过多的重复字符和无意义的片段 + if (len(part) > 5 and + not part.startswith('Microsoft') and + not all(c in '0123456789-_.' for c in part) and + len(set(part)) > 3): # 字符种类要多样 + filtered_parts.append(part) + + if filtered_parts: + return '\n'.join(filtered_parts) + + except Exception as e: + util.log(1, f"使用二进制方法读取 .doc 文件失败: {str(e)}") + + util.log(1, f"无法读取 .doc 文件 {file_path},建议转换为 .docx 格式") + return "" + + except Exception as e: + util.log(1, f"读取doc文件 {file_path} 时出错: {str(e)}") + return "" + +def read_docx_file(file_path): + """ + 读取docx文件内容 + + 参数: + file_path: docx文件路径 + + 返回: + str: 文档内容 + """ + try: + doc = docx.Document(file_path) + content = [] + + for element in doc.element.body: + if isinstance(element, CT_P): + paragraph = Paragraph(element, doc) + if paragraph.text.strip(): + content.append(paragraph.text.strip()) + elif isinstance(element, CT_Tbl): + table = Table(element, doc) + for row in table.rows: + row_text = [] + for cell in row.cells: + if cell.text.strip(): + row_text.append(cell.text.strip()) + if row_text: + content.append(" | ".join(row_text)) + + return "\n".join(content) + except Exception as e: + util.log(1, f"读取docx文件 {file_path} 时出错: {str(e)}") + return "" + +def read_pptx_file(file_path): + """ + 读取pptx文件内容 + + 参数: + file_path: pptx文件路径 + + 返回: + str: 演示文稿内容 + """ + if not PPTX_AVAILABLE: + util.log(1, "python-pptx 库未安装,无法读取 PowerPoint 文件") + return "" + + try: + prs = Presentation(file_path) + content = [] + + for i, slide in enumerate(prs.slides): + slide_content = [f"第{i+1}页:"] + + for 
shape in slide.shapes: + if hasattr(shape, "text") and shape.text.strip(): + slide_content.append(shape.text.strip()) + + if len(slide_content) > 1: # 有内容才添加 + content.append("\n".join(slide_content)) + + return "\n\n".join(content) + except Exception as e: + util.log(1, f"读取pptx文件 {file_path} 时出错: {str(e)}") + return "" + +def load_local_knowledge_base(): + """ + 加载本地知识库内容 + + 返回: + dict: 文件名到内容的映射 + """ + knowledge_base = {} + + # 获取llm/data目录路径 + current_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.join(current_dir, "data") + + if not os.path.exists(data_dir): + util.log(1, f"知识库目录不存在: {data_dir}") + return knowledge_base + + # 遍历data目录中的文件 + for file_path in Path(data_dir).iterdir(): + if not file_path.is_file(): + continue + + file_name = file_path.name + file_extension = file_path.suffix.lower() + + try: + if file_extension == '.docx': + content = read_docx_file(str(file_path)) + elif file_extension == '.doc': + content = read_doc_file(str(file_path)) + elif file_extension == '.pptx': + content = read_pptx_file(str(file_path)) + else: + # 尝试作为文本文件读取 + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + except UnicodeDecodeError: + try: + with open(file_path, 'r', encoding='gbk') as f: + content = f.read() + except UnicodeDecodeError: + util.log(1, f"无法解码文件: {file_name}") + continue + + if content.strip(): + knowledge_base[file_name] = content + util.log(1, f"成功加载知识库文件: {file_name} ({len(content)} 字符)") + + except Exception as e: + util.log(1, f"加载知识库文件 {file_name} 时出错: {str(e)}") + + return knowledge_base + +def search_knowledge_base(query, knowledge_base, max_results=3): + """ + 在知识库中搜索相关内容 + + 参数: + query: 查询内容 + knowledge_base: 知识库字典 + max_results: 最大返回结果数 + + 返回: + list: 相关内容列表 + """ + if not knowledge_base: + return [] + + results = [] + query_lower = query.lower() + + # 搜索关键词 + query_keywords = re.findall(r'\w+', query_lower) + + for file_name, content in knowledge_base.items(): + content_lower = 
content.lower() + + # 计算匹配度 + score = 0 + matched_sentences = [] + + # 按句子分割内容 + sentences = re.split(r'[。!?\n]', content) + + for sentence in sentences: + if not sentence.strip(): + continue + + sentence_lower = sentence.lower() + sentence_score = 0 + + # 计算关键词匹配度 + for keyword in query_keywords: + if keyword in sentence_lower: + sentence_score += 1 + + # 如果句子有匹配,记录 + if sentence_score > 0: + matched_sentences.append((sentence.strip(), sentence_score)) + score += sentence_score + + # 如果有匹配的内容 + if score > 0: + # 按匹配度排序句子 + matched_sentences.sort(key=lambda x: x[1], reverse=True) + + # 取前几个最相关的句子 + relevant_sentences = [sent[0] for sent in matched_sentences[:5] if sent[0]] + + if relevant_sentences: + results.append({ + 'file_name': file_name, + 'score': score, + 'content': '\n'.join(relevant_sentences) + }) + + # 按匹配度排序 + results.sort(key=lambda x: x['score'], reverse=True) + + return results[:max_results] + +# 全局知识库缓存 +_knowledge_base_cache = None +_knowledge_base_load_time = None +_knowledge_base_file_times = {} # 存储文件的最后修改时间 + +def check_knowledge_base_changes(): + """ + 检查知识库文件是否有变化 + + 返回: + bool: 如果有文件变化返回True,否则返回False + """ + global _knowledge_base_file_times + + # 获取llm/data目录路径 + current_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.join(current_dir, "data") + + if not os.path.exists(data_dir): + return False + + current_file_times = {} + + # 遍历data目录中的文件 + for file_path in Path(data_dir).iterdir(): + if not file_path.is_file(): + continue + + file_name = file_path.name + file_extension = file_path.suffix.lower() + + # 只检查支持的文件格式 + if file_extension in ['.docx', '.doc', '.pptx', '.txt'] or file_extension == '': + try: + mtime = os.path.getmtime(str(file_path)) + current_file_times[file_name] = mtime + except OSError: + continue + + # 检查是否有变化 + if not _knowledge_base_file_times: + # 第一次检查,保存文件时间 + _knowledge_base_file_times = current_file_times + return True + + # 比较文件时间 + if set(current_file_times.keys()) != 
set(_knowledge_base_file_times.keys()): + # 文件数量发生变化 + _knowledge_base_file_times = current_file_times + return True + + for file_name, mtime in current_file_times.items(): + if file_name not in _knowledge_base_file_times or _knowledge_base_file_times[file_name] != mtime: + # 文件被修改 + _knowledge_base_file_times = current_file_times + return True + + return False + +def init_knowledge_base(): + """ + 初始化知识库,在系统启动时调用 + """ + global _knowledge_base_cache, _knowledge_base_load_time + + util.log(1, "初始化本地知识库...") + _knowledge_base_cache = load_local_knowledge_base() + _knowledge_base_load_time = time.time() + + # 初始化文件修改时间跟踪 + check_knowledge_base_changes() + + util.log(1, f"知识库初始化完成,共 {len(_knowledge_base_cache)} 个文件") + +def get_knowledge_base(): + """ + 获取知识库,使用缓存机制 + + 返回: + dict: 知识库内容 + """ + global _knowledge_base_cache, _knowledge_base_load_time + + # 如果缓存为空,先初始化 + if _knowledge_base_cache is None: + init_knowledge_base() + return _knowledge_base_cache + + # 检查文件是否有变化 + if check_knowledge_base_changes(): + util.log(1, "检测到知识库文件变化,正在重新加载...") + _knowledge_base_cache = load_local_knowledge_base() + _knowledge_base_load_time = time.time() + util.log(1, f"知识库重新加载完成,共 {len(_knowledge_base_cache)} 个文件") + + return _knowledge_base_cache + + +def question(content, username, observation=None): + """处理用户提问并返回回复。""" + global current_username + current_username = username + full_response_text = "" + accumulated_text = "" + default_punctuations = [",", ".", "!", "?", "\n", "\uFF0C", "\u3002", "\uFF01", "\uFF1F"] + is_first_sentence = True + + from core import stream_manager + sm = stream_manager.new_instance() + conversation_id = sm.get_conversation_id(username) + + # 记忆系统已在全局初始化,无需创建agent + # 直接从配置文件获取人物设定 + agent_desc = { + "first_name": cfg.config["attribute"]["name"], + "last_name": "", + "age": cfg.config["attribute"]["age"], + "sex": cfg.config["attribute"]["gender"], + "additional": cfg.config["attribute"]["additional"], + "birthplace": 
cfg.config["attribute"]["birth"], + "position": cfg.config["attribute"]["position"], + "zodiac": cfg.config["attribute"]["zodiac"], + "constellation": cfg.config["attribute"]["constellation"], + "contact": cfg.config["attribute"]["contact"], + "voice": cfg.config["attribute"]["voice"], + "goal": cfg.config["attribute"]["goal"], + "occupation": cfg.config["attribute"]["job"], + "current_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + + # 使用新记忆系统处理用户消息 + # 一次性完成:入库 → 长期检索 → 短期检索 → 生成提示语 + short_term_records = [] + memory_prompt = "" + query_embedding = None + + try: + short_term_records, memory_prompt, query_embedding = memory_system.process_user_message( + content, user_id=username + ) + util.log(1, f"记忆检索成功,获取 {len(short_term_records)} 条相关记录") + except Exception as exc: + util.log(1, f"记忆检索失败: {exc}") + # 失败时使用空值,不影响后续流程 + short_term_records = [] + memory_prompt = "" + query_embedding = None + + # 方案B:保留人设信息,补充记忆提示语 + # 1. 构建人设部分 + persona_prompt = f"""\n**角色设定**\n +- 名字:{agent_desc['first_name']} +- 性别:{agent_desc['sex']} +- 年龄:{agent_desc['age']} +- 职业:{agent_desc['occupation']} +- 出生地:{agent_desc['birthplace']} +- 星座:{agent_desc['constellation']} +- 生肖:{agent_desc['zodiac']} +- 联系方式:{agent_desc['contact']} +- 定位:{agent_desc['position']} +- 目标:{agent_desc['goal']} +- 补充信息:{agent_desc['additional']}\n + +""" + + # 2. 
合并人设和记忆提示语 + if memory_prompt: + system_prompt = memory_prompt + persona_prompt + else: + # 如果记忆系统返回空提示语,使用基础提示语 + system_prompt = persona_prompt + "请根据用户的问题,提供有帮助的回答。" + + try: + history_records = content_db.new_instance().get_recent_messages_by_user(username=username, limit=30) + except Exception as exc: + util.log(1, f"加载历史消息失败: {exc}") + history_records = [] + + messages_buffer: List[ConversationMessage] = [] + + def append_to_buffer(role: str, text_value: str) -> None: + if not text_value: + return + messages_buffer.append({"role": role, "content": text_value}) + if len(messages_buffer) > 60: + del messages_buffer[:-60] + + for msg_type, msg_text in history_records: + role = 'assistant' + if msg_type and msg_type.lower() in ('member', 'user'): + role = 'user' + append_to_buffer(role, msg_text) + + if ( + not messages_buffer + or messages_buffer[-1]['role'] != 'user' + or messages_buffer[-1]['content'] != content + ): + append_to_buffer('user', content) + + tool_registry: Dict[str, WorkflowToolSpec] = {} + try: + mcp_tools = get_mcp_tools() + except Exception as exc: + util.log(1, f"获取工具列表失败: {exc}") + mcp_tools = [] + for tool_def in mcp_tools: + spec = _build_workflow_tool_spec(tool_def) + if spec: + tool_registry[spec.name] = spec + + try: + from utils.stream_state_manager import get_state_manager as _get_state_manager + + state_mgr = _get_state_manager() + session_label = "workflow_agent" if tool_registry else "llm_stream" + if not state_mgr.is_session_active(username, conversation_id=conversation_id): + state_mgr.start_new_session(username, session_label, conversation_id=conversation_id) + except Exception: + state_mgr = None + + try: + from utils.stream_text_processor import get_processor + + processor = get_processor() + punctuation_list = getattr(processor, "punctuation_marks", default_punctuations) + except Exception: + processor = None + punctuation_list = default_punctuations + def write_sentence(text: str, *, force_first: bool = False, force_end: 
bool = False) -> None: + if text is None: + text = "" + if not isinstance(text, str): + text = str(text) + if not text and not force_end and not force_first: + return + marked_text = None + if state_mgr is not None: + try: + marked_text, _, _ = state_mgr.prepare_sentence( + username, + text, + force_first=force_first, + force_end=force_end, + conversation_id=conversation_id, + ) + except Exception: + marked_text = None + if marked_text is None: + prefix = "_" if force_first else "" + suffix = "_" if force_end else "" + marked_text = f"{prefix}{text}{suffix}" + stream_manager.new_instance().write_sentence(username, marked_text, conversation_id=conversation_id) + + def stream_response_chunks(chunks, prepend_text: str = "") -> None: + nonlocal accumulated_text, full_response_text, is_first_sentence + if prepend_text: + accumulated_text += prepend_text + full_response_text += prepend_text + for chunk in chunks: + if sm.should_stop_generation(username, conversation_id=conversation_id): + util.log(1, f"检测到停止标志,中断文本生成: {username}") + break + if isinstance(chunk, str): + flush_text = chunk + elif isinstance(chunk, dict): + flush_text = chunk.get("content") + else: + flush_text = getattr(chunk, "content", None) + if isinstance(flush_text, list): + flush_text = "".join(part if isinstance(part, str) else "" for part in flush_text) + if not flush_text: + continue + flush_text = str(flush_text) + accumulated_text += flush_text + full_response_text += flush_text + if len(accumulated_text) >= 20: + while True: + last_punct_pos = -1 + for punct in punctuation_list: + pos = accumulated_text.rfind(punct) + if pos > last_punct_pos: + last_punct_pos = pos + if last_punct_pos > 10: + sentence_text = accumulated_text[: last_punct_pos + 1] + write_sentence(sentence_text, force_first=is_first_sentence) + is_first_sentence = False + accumulated_text = accumulated_text[last_punct_pos + 1 :].lstrip() + else: + break + + def finalize_stream(force_end: bool = False) -> None: + nonlocal 
accumulated_text, is_first_sentence + if accumulated_text: + write_sentence(accumulated_text, force_first=is_first_sentence, force_end=force_end) + is_first_sentence = False + accumulated_text = "" + elif force_end: + if state_mgr is not None: + try: + session_info = state_mgr.get_session_info(username, conversation_id=conversation_id) + except Exception: + session_info = None + if not session_info or not session_info.get("is_end_sent", False): + write_sentence("", force_end=True) + else: + write_sentence("", force_end=True) + + def run_workflow(tool_registry: Dict[str, WorkflowToolSpec]) -> bool: + nonlocal accumulated_text, full_response_text, is_first_sentence, messages_buffer + + initial_state: AgentState = { + "request": content, + "messages": messages_buffer, + "tool_results": [], + "audit_log": [], + "status": "planning", + "max_steps": 30, + "context": { + "system_prompt": system_prompt, + "observation": observation, + "tool_registry": tool_registry, + }, + } + + config = {"configurable": {"thread_id": f"workflow-{username}-{conversation_id}"}} + workflow_app = _WORKFLOW_APP + is_agent_think_start = False + final_state: Optional[AgentState] = None + final_stream_done = False + + try: + for event in workflow_app.stream(initial_state, config=config, stream_mode="updates"): + if sm.should_stop_generation(username, conversation_id=conversation_id): + util.log(1, f"检测到停止标志,中断工作流生成: {username}") + break + step, state = next(iter(event.items())) + final_state = state + status = state.get("status") + + state_messages = state.get("messages") or [] + if state_messages and len(state_messages) > len(messages_buffer): + messages_buffer.extend(state_messages[len(messages_buffer):]) + if len(messages_buffer) > 60: + del messages_buffer[:-60] + + if step == "plan_next": + if status == "needs_tool": + next_action = state.get("next_action") or {} + tool_name = next_action.get("name") or "unknown_tool" + tool_args = next_action.get("args") or {} + audit_log = 
state.get("audit_log") or [] + decision_note = audit_log[-1] if audit_log else "" + if "->" in decision_note: + decision_note = decision_note.split("->", 1)[1].strip() + args_text = json.dumps(tool_args, ensure_ascii=False) + message_lines = [ + "[PLAN] Planner preparing to call a tool.", + f"[PLAN] Decision: {decision_note}" if decision_note else "[PLAN] Decision: (missing)", + f"[PLAN] Tool: {tool_name}", + f"[PLAN] Args: {args_text}", + ] + message = "\n".join(message_lines) + "\n" + if not is_agent_think_start: + message = "" + message + is_agent_think_start = True + write_sentence(message, force_first=is_first_sentence) + is_first_sentence = False + full_response_text += message + append_to_buffer('assistant', message.strip()) + elif status == "completed" and not final_stream_done: + closing = "" if is_agent_think_start else "" + final_messages = state.get("final_messages") + final_response = state.get("final_response") + success = False + if final_messages: + try: + stream_response_chunks(llm.stream(final_messages), prepend_text=closing) + success = True + except requests.exceptions.RequestException as stream_exc: + util.log(1, f"最终回复流式输出失败: {stream_exc}") + elif final_response: + stream_response_chunks([closing + final_response]) + success = True + elif closing: + accumulated_text += closing + full_response_text += closing + final_stream_done = success + is_agent_think_start = False + elif step == "call_tool": + history = state.get("tool_results") or [] + if history: + last = history[-1] + call_info = last.get("call", {}) or {} + tool_name = call_info.get("name") or "unknown_tool" + success = last.get("success", False) + status_text = "SUCCESS" if success else "FAILED" + args_text = json.dumps(call_info.get("args") or {}, ensure_ascii=False) + message_lines = [ + f"[TOOL] {tool_name} execution {status_text}.", + f"[TOOL] Args: {args_text}", + ] + if last.get("output"): + message_lines.append(f"[TOOL] Output: {_truncate_text(last['output'], 120)}") + if 
last.get("error"): + message_lines.append(f"[TOOL] Error: {last['error']}") + message = "\n".join(message_lines) + "\n" + write_sentence(message, force_first=is_first_sentence) + is_first_sentence = False + full_response_text += message + append_to_buffer('assistant', message.strip()) + elif step == "__end__": + break + except Exception as exc: + util.log(1, f"执行工具工作流时出错: {exc}") + if is_agent_think_start: + closing = "" + accumulated_text += closing + full_response_text += closing + return False + + if final_state is None: + if is_agent_think_start: + closing = "" + accumulated_text += closing + full_response_text += closing + return False + + if not final_stream_done and is_agent_think_start: + closing = "" + accumulated_text += closing + full_response_text += closing + util.log(1, f"工具工作流未能完成,状态: {final_state.get('status')}") + + final_state_messages = final_state.get("messages") if final_state else None + if final_state_messages and len(final_state_messages) > len(messages_buffer): + messages_buffer.extend(final_state_messages[len(messages_buffer):]) + if len(messages_buffer) > 60: + del messages_buffer[:-60] + + return final_stream_done + + def run_direct_llm() -> bool: + nonlocal full_response_text, accumulated_text, is_first_sentence, messages_buffer + try: + # 统一使用 _build_final_messages 构建消息,确保历史对话始终被包含 + summary_state: AgentState = { + "request": content, + "messages": messages_buffer, + "tool_results": [], + "planner_preview": None, + "context": { + "system_prompt": system_prompt, + "observation": observation, + }, + } + + final_messages = _build_final_messages(summary_state) + stream_response_chunks(llm.stream(final_messages)) + return True + except requests.exceptions.RequestException as exc: + util.log(1, f"请求失败: {exc}") + error_message = "抱歉,我现在太忙了,休息一会,请稍后再试。" + write_sentence(error_message, force_first=is_first_sentence) + is_first_sentence = False + full_response_text = error_message + accumulated_text = "" + return False + + workflow_success = 
False + if tool_registry: + workflow_success = run_workflow(tool_registry) + + if (not tool_registry or not workflow_success) and not sm.should_stop_generation(username, conversation_id=conversation_id): + run_direct_llm() + + if not sm.should_stop_generation(username, conversation_id=conversation_id): + finalize_stream(force_end=True) + + if state_mgr is not None: + try: + state_mgr.end_session(username, conversation_id=conversation_id) + except Exception: + pass + else: + try: + from utils.stream_state_manager import get_state_manager + + get_state_manager().end_session(username, conversation_id=conversation_id) + except Exception: + pass + + final_text = full_response_text.split("")[-1] if full_response_text else "" + + # 使用新记忆系统异步处理agent回复 + try: + import asyncio + + # 创建新的事件循环(在独立线程中运行) + def async_memory_task(): + """在独立线程中运行异步记忆存储""" + try: + # 创建新的事件循环 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # 运行异步任务 + loop.run_until_complete( + memory_system.process_agent_reply_async( + final_text, + user_id=username, + current_user_content=content + ) + ) + + # 关闭循环 + loop.close() + except Exception as e: + util.log(1, f"异步记忆存储失败: {e}") + + # 启动独立线程执行异步任务 + MyThread(target=async_memory_task).start() + util.log(1, f"异步记忆存储任务已启动") + + except Exception as exc: + util.log(1, f"异步记忆处理启动失败: {exc}") + + return final_text +def clear_agent_memory(username=None): + """ + 清除指定用户的记忆(使用新记忆系统) + + Args: + username: 用户名,如果为None则清除当前用户的记忆 + + Returns: + bool: 是否清除成功 + """ + global memory_system, current_username + + try: + # 确定要清除的用户ID + user_id = username if username else current_username + if not user_id: + user_id = "User" # 默认用户 + + util.log(1, f"正在清除用户 {user_id} 的记忆...") + + # 调用新记忆系统的清除方法 + result = memory_system.clear_user_history(user_id=user_id) + + util.log(1, f"用户 {user_id} 的记忆清除完成: {result}") + return True + + except Exception as e: + util.log(1, f"清除用户记忆时出错: {str(e)}") + return False + +def get_mcp_tools() -> List[Dict[str, Any]]: + """ + 
从共享缓存获取所有可用且已启用的MCP工具列表。 + """ + try: + tools = mcp_tool_registry.get_enabled_tools() + return tools or [] + except Exception as e: + util.log(1, f"获取工具列表出错:{e}") + return [] + + +if __name__ == "__main__": + # 记忆系统已在模块加载时初始化,无需再次调用 + for _ in range(3): + query = "Who is Fay?" + response = question(query, "User") + print(f"Q: {query}") + print(f"A: {response}") + time.sleep(1) diff --git a/llm/nlp_cognitive_stream.py b/llm/nlp_cognitive_stream.py index 5a574e3..28d1418 100644 --- a/llm/nlp_cognitive_stream.py +++ b/llm/nlp_cognitive_stream.py @@ -491,7 +491,6 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess conversation = state.get("messages", []) or [] history = state.get("tool_results", []) or [] memory_context = context.get("memory_context", "") - knowledge_context = context.get("knowledge_context", "") observation = context.get("observation", "") prestart_context = context.get("prestart_context", "") @@ -520,9 +519,6 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess **关联记忆** {memory_context or '(无相关记忆)'} - -**关联知识** -{knowledge_context or '(无相关知识)'} {prestart_section} **可用工具** {tools_text} @@ -540,9 +536,6 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess {{"action": "finish_text"}}""" ).strip() - print("***********************************************************************") - print(user_block) - print("****************************************************************") return [ SystemMessage(content="你负责规划下一步行动,请严格输出合法 JSON。"), HumanMessage(content=user_block), @@ -553,7 +546,6 @@ def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessag context = state.get("context", {}) or {} system_prompt = context.get("system_prompt", "") request = state.get("request", "") - knowledge_context = context.get("knowledge_context", "") memory_context = context.get("memory_context", "") observation = context.get("observation", "") prestart_context = 
context.get("prestart_context", "") @@ -579,9 +571,6 @@ def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessag **关联记忆** {memory_context or '(无相关记忆)'} - -**关联知识** -{knowledge_context or '(无相关知识)'} {prestart_section} **其他观察** {observation or '(无补充)'} @@ -593,9 +582,6 @@ def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessag {conversation_block}""" ).strip() - print("***********************************************************************") - print(user_block) - print("****************************************************************") return [ SystemMessage(content="你是最终回复的口播助手,请用中文自然表达。"), HumanMessage(content=user_block), @@ -1464,22 +1450,6 @@ def question(content, username, observation=None): except Exception as exc: util.log(1, f"获取相关记忆时出错: {exc}") - knowledge_context = "" - try: - knowledge_base = get_knowledge_base() - if knowledge_base: - knowledge_results = search_knowledge_base(content, knowledge_base, max_results=3) - if knowledge_results: - parts = ["**本地知识库相关信息**:"] - for result in knowledge_results: - parts.append(f"来源文件:{result['file_name']}") - parts.append(result["content"]) - parts.append("") - knowledge_context = "\n".join(parts).strip() - util.log(1, f"找到 {len(knowledge_results)} 条相关知识库信息") - except Exception as exc: - util.log(1, f"搜索知识库时出错: {exc}") - prestart_context = "" try: prestart_context = _run_prestart_tools(content) @@ -1703,7 +1673,6 @@ def question(content, username, observation=None): "max_steps": 30, "context": { "system_prompt": system_prompt, - "knowledge_context": knowledge_context, "observation": observation, "memory_context": memory_context, "prestart_context": prestart_context, @@ -1838,7 +1807,6 @@ def question(content, username, observation=None): "planner_preview": None, "context": { "system_prompt": system_prompt, - "knowledge_context": knowledge_context, "observation": observation, "memory_context": memory_context, "prestart_context": prestart_context, diff --git 
a/mcp_servers/window_capture/README.md b/mcp_servers/window_capture/README.md new file mode 100644 index 0000000..622ca52 --- /dev/null +++ b/mcp_servers/window_capture/README.md @@ -0,0 +1,34 @@ +# Window Capture MCP Server + +在 Windows 上按窗口标题(或句柄)截图的 MCP 服务器,提供两个工具: +- `list_windows`:列出当前顶层窗口,可按关键词过滤。 +- `capture_window`:按窗口标题关键字或句柄截屏,返回 PNG(同时保存到本地)。 + +## 准备 +1. 进入本目录:`cd mcp_servers/window_capture` +2. 安装依赖:`pip install -r requirements.txt`(需要 Pillow;仅支持 Windows) + +默认保存目录:`cache_data/window_captures`(相对仓库根目录,可在调用时通过 `save_dir` 自定义)。 + +## 运行 +```bash +python mcp_servers/window_capture/server.py +``` +或在 Fay 的 MCP 管理页面添加一条记录: +- transport: `stdio` +- command: `python` +- args: `["mcp_servers/window_capture/server.py"]` +- cwd: 仓库根目录或留空 + +## 工具参数 +- `list_windows` + - `keyword` (可选): 标题关键字,模糊匹配,不区分大小写。 + - `include_hidden` (可选): 是否包含隐藏/最小化窗口,默认 false。 + - `limit` (可选): 最大返回数量,默认 20,0 表示不限制。 +- `capture_window` + - `window` (必填): 窗口标题关键字,或窗口句柄(十进制/0x16 进制)。 + - `include_hidden` (可选): 允许捕获隐藏/最小化窗口,默认 false。 + - `save_dir` (可选): 自定义保存路径。 + +返回内容: +- 文本摘要(JSON 字符串,包含窗口信息与保存路径)。截图文件保存在 `cache_data/window_captures` 或自定义 `save_dir`。 diff --git a/mcp_servers/window_capture/requirements.txt b/mcp_servers/window_capture/requirements.txt new file mode 100644 index 0000000..b63c575 --- /dev/null +++ b/mcp_servers/window_capture/requirements.txt @@ -0,0 +1 @@ +Pillow>=9.5.0 diff --git a/mcp_servers/window_capture/server.py b/mcp_servers/window_capture/server.py new file mode 100644 index 0000000..de02638 --- /dev/null +++ b/mcp_servers/window_capture/server.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +Window Capture MCP server. + +Tools: +- list_windows: enumerate top-level windows with optional keyword filtering. +- capture_window: take a PNG screenshot of a specific window by title keyword or handle. + +Only Windows is supported because the capture path relies on Win32 APIs and Pillow's ImageGrab. 
+""" + +import asyncio +import ctypes +from ctypes import wintypes +import json +import os +import sys +import time +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +try: + from PIL import ImageGrab +except ImportError: + print("Pillow not installed. Please run: pip install Pillow", file=sys.stderr) + sys.exit(1) + +try: + from mcp.server import Server + from mcp.types import Tool, TextContent + import mcp.server.stdio +except ImportError: + print("MCP library not installed. Please run: pip install mcp", file=sys.stderr) + sys.exit(1) + + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +DEFAULT_SAVE_DIR = os.path.join(PROJECT_ROOT, "cache_data", "window_captures") +os.makedirs(DEFAULT_SAVE_DIR, exist_ok=True) + +if os.name != "nt": + print("window_capture MCP server currently supports Windows only.", file=sys.stderr) + +server = Server("window_capture") + +user32 = ctypes.windll.user32 +SW_RESTORE = 9 + +try: + user32.SetProcessDPIAware() +except Exception: + pass + + +@dataclass +class WindowInfo: + handle: int + title: str + cls: str + rect: Tuple[int, int, int, int] + visible: bool + minimized: bool + + def to_dict(self) -> Dict[str, Any]: + left, top, right, bottom = self.rect + return { + "title": self.title, + "class": self.cls, + "handle": self.handle, + "handle_hex": hex(self.handle), + "visible": self.visible, + "minimized": self.minimized, + "rect": {"left": left, "top": top, "right": right, "bottom": bottom}, + "size": {"width": max(0, right - left), "height": max(0, bottom - top)}, + } + + +class WindowCaptureError(Exception): + pass + + +def _text_content(text: str): + try: + return TextContent(type="text", text=text) + except Exception: + return {"type": "text", "text": text} + + +def _sanitize_filename(name: str) -> str: + cleaned = "".join(ch for ch in name if ch.isalnum() or ch in (" ", "_", "-")) + cleaned = cleaned.strip().replace(" ", 
"_") + return cleaned or "window" + + +def _enum_windows(keyword: Optional[str] = None, include_hidden: bool = False, limit: int = 30) -> List[WindowInfo]: + if os.name != "nt": + raise WindowCaptureError("Window enumeration is only supported on Windows.") + + results: List[WindowInfo] = [] + keyword_l = keyword.lower() if keyword else None + + def callback(hwnd, _lparam): + if not user32.IsWindow(hwnd): + return True + if not include_hidden and not user32.IsWindowVisible(hwnd): + return True + + length = user32.GetWindowTextLengthW(hwnd) + if length == 0: + return True + + title_buf = ctypes.create_unicode_buffer(length + 1) + user32.GetWindowTextW(hwnd, title_buf, length + 1) + title = title_buf.value.strip() + if not title: + return True + + if keyword_l and keyword_l not in title.lower(): + return True + + class_buf = ctypes.create_unicode_buffer(256) + user32.GetClassNameW(hwnd, class_buf, 255) + rect = wintypes.RECT() + if not user32.GetWindowRect(hwnd, ctypes.byref(rect)): + return True + + info = WindowInfo( + handle=int(hwnd), + title=title, + cls=class_buf.value, + rect=(rect.left, rect.top, rect.right, rect.bottom), + visible=bool(user32.IsWindowVisible(hwnd)), + minimized=bool(user32.IsIconic(hwnd)), + ) + results.append(info) + if limit > 0 and len(results) >= limit: + return False + return True + + enum_proc = ctypes.WINFUNCTYPE(ctypes.c_bool, wintypes.HWND, wintypes.LPARAM) + user32.EnumWindows(enum_proc(callback), 0) + + results.sort(key=lambda w: w.title.lower()) + return results + + +def _parse_handle(value: str) -> Optional[int]: + text = value.strip().lower() + if text.startswith("0x"): + try: + return int(text, 16) + except ValueError: + return None + if text.isdigit(): + try: + return int(text) + except ValueError: + return None + return None + + +def _resolve_window(query: str, include_hidden: bool = False) -> WindowInfo: + if not query or not str(query).strip(): + raise WindowCaptureError("Window identifier is required.") + + 
handle_candidate = _parse_handle(str(query)) + if handle_candidate is not None: + windows = _enum_windows(None, include_hidden=include_hidden, limit=0) + for win in windows: + if win.handle == handle_candidate: + return win + raise WindowCaptureError(f"Window handle {handle_candidate} not found.") + + matches = _enum_windows(query, include_hidden=include_hidden, limit=50) + if not matches: + raise WindowCaptureError(f"No window matched keyword '{query}'.") + + exact = [w for w in matches if w.title.lower() == query.lower()] + if len(exact) == 1: + return exact[0] + if len(matches) == 1: + return matches[0] + + names = "; ".join(w.title for w in matches[:6]) + raise WindowCaptureError(f"Multiple windows matched. Please be more specific. Candidates: {names}") + + +def _get_foreground_window() -> int: + """获取当前前台窗口句柄""" + try: + return user32.GetForegroundWindow() + except Exception: + return 0 + + +def _activate_window(hwnd: int) -> bool: + """激活指定窗口,返回是否成功""" + try: + # 如果窗口最小化,先恢复 + if user32.IsIconic(hwnd): + user32.ShowWindow(hwnd, SW_RESTORE) + time.sleep(0.1) + + # 尝试多种方式激活窗口 + # 方法1: 使用 SetForegroundWindow + result = user32.SetForegroundWindow(hwnd) + + if not result: + # 方法2: 使用 keybd_event 模拟 Alt 键来允许切换前台 + ALT_KEY = 0x12 + KEYEVENTF_EXTENDEDKEY = 0x0001 + KEYEVENTF_KEYUP = 0x0002 + user32.keybd_event(ALT_KEY, 0, KEYEVENTF_EXTENDEDKEY, 0) + user32.SetForegroundWindow(hwnd) + user32.keybd_event(ALT_KEY, 0, KEYEVENTF_EXTENDEDKEY | KEYEVENTF_KEYUP, 0) + + time.sleep(0.3) # 等待窗口完全显示 + return True + except Exception: + return False + + +def _get_window_rect(hwnd: int) -> Tuple[int, int, int, int]: + """获取窗口的最新坐标""" + rect = wintypes.RECT() + if user32.GetWindowRect(hwnd, ctypes.byref(rect)): + return (rect.left, rect.top, rect.right, rect.bottom) + return (0, 0, 0, 0) + + +def capture_window(query: str, save_dir: Optional[str] = None, include_hidden: bool = False): + if os.name != "nt": + raise WindowCaptureError("Window capture is only supported on Windows.") + 
+ window = _resolve_window(query, include_hidden=include_hidden) + + # 记录当前前台窗口,截图后恢复 + original_foreground = _get_foreground_window() + + try: + # 激活目标窗口 + _activate_window(window.handle) + + # 激活后重新获取窗口坐标(窗口位置可能在激活/恢复后发生变化) + left, top, right, bottom = _get_window_rect(window.handle) + if right - left <= 0 or bottom - top <= 0: + raise WindowCaptureError("Target window has zero area.") + + # 更新 window 对象的 rect 以便返回正确信息 + window.rect = (left, top, right, bottom) + + try: + img = ImageGrab.grab(bbox=(left, top, right, bottom)) + except Exception as exc: + raise WindowCaptureError(f"ImageGrab failed: {exc}") from exc + + save_dir = save_dir or DEFAULT_SAVE_DIR + os.makedirs(save_dir, exist_ok=True) + filename = f"{_sanitize_filename(window.title)}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" + save_path = os.path.abspath(os.path.join(save_dir, filename)) + + img.save(save_path, format="PNG") + + return save_path + + finally: + # 恢复原来的前台窗口 + if original_foreground and original_foreground != window.handle: + time.sleep(0.1) + _activate_window(original_foreground) + + +TOOLS: List[Tool] = [ + Tool( + name="list_windows", + description="List visible top-level windows. Supports keyword filter and optional hidden windows.", + inputSchema={ + "type": "object", + "properties": { + "keyword": {"type": "string", "description": "Substring to match in the window title (case-insensitive)."}, + "include_hidden": {"type": "boolean", "description": "Include hidden/minimized windows."}, + "limit": {"type": "integer", "description": "Max number of windows to return (0 means no limit)."}, + }, + "required": [], + }, + ), + Tool( + name="capture_window", + description="Capture a PNG screenshot of a specific window by title keyword or numeric/hex handle; returns local file path.", + inputSchema={ + "type": "object", + "properties": { + "window": { + "type": "string", + "description": "Title keyword or window handle (e.g. 
'Notepad', '197324', or '0x2ff3e').", + }, + "include_hidden": {"type": "boolean", "description": "Allow capturing hidden/minimized windows."}, + "save_dir": { + "type": "string", + "description": f"Optional folder to save the PNG. Default: {DEFAULT_SAVE_DIR}", + }, + }, + "required": ["window"], + }, + ), +] + + +@server.list_tools() +async def list_tools() -> List[Tool]: + return TOOLS + + +@server.call_tool() +async def call_tool(name: str, arguments: Optional[Dict[str, Any]]) -> List[Any]: + args = arguments or {} + try: + if name == "list_windows": + keyword = args.get("keyword") + include_hidden = bool(args.get("include_hidden", False)) + limit_raw = args.get("limit", 20) + try: + limit_val = int(limit_raw) + except Exception: + limit_val = 20 + limit_val = max(0, min(limit_val, 200)) + + windows = _enum_windows(keyword, include_hidden=include_hidden, limit=limit_val or 0) + payload = { + "count": len(windows), + "keyword": keyword or "", + "include_hidden": include_hidden, + "windows": [w.to_dict() for w in windows], + } + text = json.dumps(payload, ensure_ascii=False, indent=2) + return [_text_content(text)] + + if name == "capture_window": + query = args.get("window") or args.get("title") + include_hidden = bool(args.get("include_hidden", False)) + save_dir = args.get("save_dir") or None + + save_path = capture_window(query, save_dir=save_dir, include_hidden=include_hidden) + return [_text_content(save_path)] + + return [_text_content(f"Unknown tool: {name}")] + + except Exception as exc: + return [_text_content(f"Error running tool {name}: {exc}")] + + +async def main() -> None: + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + init_opts = server.create_initialization_options() + await server.run(read_stream, write_stream, init_opts) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/mcp_servers/yueshen_rag/README.md b/mcp_servers/yueshen_rag/README.md new file mode 
100644 index 0000000..955b7a7 --- /dev/null +++ b/mcp_servers/yueshen_rag/README.md @@ -0,0 +1,43 @@ +## YueShen RAG MCP Server + +扫描 `新知识库`(或自定义目录)的 pdf/docx,按段落/句子切块写入 Chroma,并提供检索工具。Embedding 配置可通过 MCP 参数或环境变量传入。 + +### 依赖 +```bash +pip install -r requirements.txt +``` +- 若有 `.doc` 请先转换为 `.docx` 再处理;当前依赖仅支持 pdf/docx。 + +### 环境变量(可选) +- `YUESHEN_CORPUS_DIR`:知识库原始文档目录(默认 `新知识库`) +- `YUESHEN_PERSIST_DIR`:Chroma 向量库持久化目录(默认 `cache_data/chromadb_yueshen`) +- `YUESHEN_EMBED_BASE_URL`:Embedding API base url(将拼接 `/embeddings`) +- `YUESHEN_EMBED_API_KEY`:Embedding API key +- `YUESHEN_EMBED_MODEL`:Embedding 模型名(默认 `text-embedding-3-small`) +- `YUESHEN_AUTO_INGEST`:是否启用启动即自动扫描入库(默认 1,设为 0 关闭) +- `YUESHEN_AUTO_INTERVAL`:自动扫描间隔秒(默认 300,最小 30) +- `YUESHEN_AUTO_RESET_ON_START`:启动时是否 reset 后重建索引(默认 0) + +### 运行 +```bash +cd mcp_servers/yueshen_rag +python server.py +``` + +### 添加到 Fay +- MCP 管理页面:新增服务器,transport 选 `stdio`;command 填 Python(如 `python` 或虚拟环境路径);args `["mcp_servers/yueshen_rag/server.py"]`;cwd 指向项目根目录;如需自定义 Embedding,填入 env 的 base url / api key / model。 +- 也可以直接编辑 `faymcp/data/mcp_servers.json` 添加对应项,重启 Fay MCP 服务后生效。 + +### 预启动推荐 +- 在 MCP 页面工具列表为 `query_yueshen` 打开“预启动”,参数示例:`{"query": "{{question}}", "top_k": 4}`,用户提问会替换 `{{question}}`。 +- 若希望启动后自动补扫新文档,可为 `ingest_yueshen` 配置预启动(如 `{"reset": false}` 或指定 `corpus_dir`/`batch_size` 等)。 + +### 工具 +- `ingest_yueshen`:扫描并入库;参数 `corpus_dir`、`reset`、`chunk_size`、`overlap`、`batch_size`、`max_files`,以及可选 `embedding_base_url`/`embedding_api_key`/`embedding_model` 覆盖环境变量。 +- `query_yueshen`:向量检索;参数 `query`,可选 `top_k`、`where`,以及可选 embedding 配置与 ingest 保持一致。 +- `yueshen_stats`:查看向量库状态(持久化目录、集合名、向量数等)。 + +### 默认路径与切块 +- 语料目录:`悦肾e家知识库202511/新知识库` +- 持久化目录:`cache_data/chromadb_yueshen` +- 切块:约 600 字,120 重叠,可按需调整 diff --git a/mcp_servers/yueshen_rag/requirements.txt b/mcp_servers/yueshen_rag/requirements.txt new file mode 100644 index 0000000..fb794d4 --- /dev/null +++ b/mcp_servers/yueshen_rag/requirements.txt @@ -0,0 +1,6 @@ +mcp 
+chromadb +pdfplumber +python-docx +requests +# doc 格式已转成 docx,则无需额外依赖 diff --git a/mcp_servers/yueshen_rag/server.py b/mcp_servers/yueshen_rag/server.py new file mode 100644 index 0000000..c6a3a9a --- /dev/null +++ b/mcp_servers/yueshen_rag/server.py @@ -0,0 +1,687 @@ +#!/usr/bin/env python3 +""" +YueShen Knowledge Base RAG MCP Server + +- Load pdf/docx from a directory, chunk, and write into Chroma. +- Embedding config is provided via MCP tool params or env vars, not system.conf. +- Auto-ingest watcher can run on startup to keep the index fresh. +- Tools: ingest_yueshen, query_yueshen, yueshen_stats. +""" + +import hashlib +import json +import logging +import os +import re +import sys +import time +import threading +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import requests + +# Keep stdout clean for MCP stdio; route logs to stderr and disable Chroma telemetry noise. +os.environ.setdefault("CHROMA_TELEMETRY", "FALSE") + +# Make project root importable (for optional fallback embedding) +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +try: + from mcp.server import Server + from mcp.types import Tool, TextContent + import mcp.server.stdio +except ImportError: + print("MCP library not installed. Please run: pip install mcp", file=sys.stderr, flush=True) + sys.exit(1) + +try: + import chromadb +except ImportError: + print("chromadb not installed. 
Please run: pip install chromadb", file=sys.stderr, flush=True) + sys.exit(1) + +server = Server("yueshen_rag") + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", + stream=sys.stderr, +) +logger = logging.getLogger("yueshen_rag") + +# Defaults (can be overridden via env) +DEFAULT_CORPUS_DIR = os.getenv( + "YUESHEN_CORPUS_DIR", + os.path.join(PROJECT_ROOT, "新知识库"), +) +DEFAULT_PERSIST_DIR = os.getenv( + "YUESHEN_PERSIST_DIR", + os.path.join(PROJECT_ROOT, "cache_data", "chromadb_yueshen"), +) +COLLECTION_NAME = "yueshen_kb" +DEFAULT_EMBED_BASE_URL = os.getenv("YUESHEN_EMBED_BASE_URL") +DEFAULT_EMBED_API_KEY = os.getenv("YUESHEN_EMBED_API_KEY") +DEFAULT_EMBED_MODEL = os.getenv("YUESHEN_EMBED_MODEL", "text-embedding-3-small") +AUTO_INGEST_ENABLED = os.getenv("YUESHEN_AUTO_INGEST", "1") != "0" +AUTO_INGEST_INTERVAL = int(os.getenv("YUESHEN_AUTO_INTERVAL", "300")) +AUTO_RESET_ON_START = os.getenv("YUESHEN_AUTO_RESET_ON_START", "0") != "0" + + +# -------------------- Text chunking -------------------- # +def _len_with_newlines(parts: List[str]) -> int: + if not parts: + return 0 + return sum(len(p) for p in parts) + (len(parts) - 1) + + +def split_into_chunks(text: str, chunk_size: int = 600, overlap: int = 120) -> List[str]: + """Paragraph/ sentence-aware chunking with small overlap.""" + cleaned = re.sub(r"[ \t]+", " ", text.replace("\u00a0", " ")).strip() + paragraphs = [p.strip() for p in re.split(r"\n\s*\n", cleaned) if p.strip()] + segments: List[str] = [] + for para in paragraphs: + if len(para) <= chunk_size: + segments.append(para) + else: + for sent in re.split(r"(?<=[。!?!?…])", para): + s = sent.strip() + if s: + segments.append(s) + + chunks: List[str] = [] + buf: List[str] = [] + for seg in segments: + seg = seg.strip() + if not seg: + continue + current_len = _len_with_newlines(buf) + if current_len + len(seg) + (1 if buf else 0) <= chunk_size: + buf.append(seg) + continue + + if buf: + 
chunks.append("\n".join(buf).strip()) + + # build overlap from previous chunk tail + buf = [] + if overlap > 0 and chunks: + tail: List[str] = [] + tail_len = 0 + for s in reversed(chunks[-1].split("\n")): + tail.insert(0, s) + tail_len += len(s) + if tail_len >= overlap: + break + if tail: + buf.extend(tail) + + buf.append(seg) + + if buf: + chunks.append("\n".join(buf).strip()) + return chunks + + +# -------------------- Document readers -------------------- # +def _extract_docx(path: str) -> str: + from docx import Document + + doc = Document(path) + texts: List[str] = [] + for para in doc.paragraphs: + t = para.text.strip() + if t: + texts.append(t) + + for table in doc.tables: + for row in table.rows: + cells = [cell.text.strip() for cell in row.cells if cell.text.strip()] + if cells: + texts.append(" | ".join(cells)) + + return "\n".join(texts) + + +def _extract_pdf_pages(path: str) -> List[Tuple[int, str]]: + try: + import pdfplumber + except ImportError as exc: + raise RuntimeError("pdfplumber is required for pdf parsing") from exc + + pages: List[Tuple[int, str]] = [] + with pdfplumber.open(path) as pdf: + for idx, page in enumerate(pdf.pages, start=1): + txt = page.extract_text() or "" + pages.append((idx, txt)) + return pages + + +# -------------------- Data models -------------------- # +@dataclass +class Chunk: + text: str + source_path: str + page: Optional[int] + chunk_id: str + metadata: Dict[str, Any] + + +# -------------------- Corpus loader -------------------- # +class CorpusLoader: + def __init__(self, root_dir: str = DEFAULT_CORPUS_DIR, chunk_size: int = 600, overlap: int = 120): + self.root_dir = root_dir + self.chunk_size = chunk_size + self.overlap = overlap + + def _iter_files(self) -> Iterable[str]: + for root, _, files in os.walk(self.root_dir): + for fn in files: + if fn.lower().endswith((".pdf", ".docx")): + yield os.path.join(root, fn) + + def _file_to_chunks(self, path: str) -> List[Chunk]: + ext = os.path.splitext(path)[1].lower() + 
# -------------------- Corpus loader -------------------- #
class CorpusLoader:
    """Walks a corpus directory and turns each pdf/docx file into Chunk records."""

    def __init__(self, root_dir: str = DEFAULT_CORPUS_DIR, chunk_size: int = 600, overlap: int = 120):
        self.root_dir = root_dir
        self.chunk_size = chunk_size
        self.overlap = overlap

    def _iter_files(self) -> Iterable[str]:
        # Recursive walk; only pdf/docx are parsed (convert .doc beforehand).
        for dirpath, _, filenames in os.walk(self.root_dir):
            for filename in filenames:
                if filename.lower().endswith((".pdf", ".docx")):
                    yield os.path.join(dirpath, filename)

    def _file_to_chunks(self, path: str) -> List[Chunk]:
        """Parse one file into chunks; any parse error skips the whole file."""
        extension = os.path.splitext(path)[1].lower()
        rel_path = os.path.relpath(path, self.root_dir)
        produced: List[Chunk] = []

        def emit(raw_text: str, page_num: Optional[int], id_tag: Any) -> None:
            # Chunk one text body; the id folds path, tag, index and content
            # so re-ingesting unchanged content upserts the same ids.
            for idx, piece in enumerate(
                split_into_chunks(raw_text, chunk_size=self.chunk_size, overlap=self.overlap)
            ):
                digest = hashlib.md5(
                    f"{rel_path}|{id_tag}|{idx}|{piece}".encode("utf-8", errors="ignore")
                ).hexdigest()
                if page_num is not None:
                    meta = {"source": rel_path, "page": page_num, "ext": extension}
                else:
                    meta = {"source": rel_path, "ext": extension}
                produced.append(
                    Chunk(
                        text=piece,
                        source_path=rel_path,
                        page=page_num,
                        chunk_id=digest,
                        metadata=meta,
                    )
                )

        try:
            if extension == ".pdf":
                for page_num, page_text in _extract_pdf_pages(path):
                    emit(page_text, page_num, page_num)
            elif extension == ".docx":
                emit(_extract_docx(path), None, "docx")
        except Exception as exc:
            logger.warning("Skip file due to parse error %s: %s", rel_path, exc)

        return produced

    def load(self, max_files: Optional[int] = None) -> List[Chunk]:
        """Chunk every file under root_dir, optionally limited to *max_files*."""
        collected: List[Chunk] = []
        for count, file_path in enumerate(self._iter_files(), start=1):
            if max_files and count > max_files:
                break
            collected.extend(self._file_to_chunks(file_path))
        return collected
# -------------------- Embedding backend -------------------- #
class EmbeddingBackend:
    """Embedding client with API config, falling back to project encoder if needed.

    Results are memoised in a bounded in-memory cache so repeated ingests of
    unchanged chunks do not re-call the API while memory stays bounded.
    """

    # Upper bound on cached embeddings; oldest entries are evicted first.
    _CACHE_LIMIT = 4096

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
    ):
        self.base_url = base_url or DEFAULT_EMBED_BASE_URL
        self.api_key = api_key or DEFAULT_EMBED_API_KEY
        self.model = model or DEFAULT_EMBED_MODEL
        self._cache: Dict[str, List[float]] = {}
        self._fallback_encoder = None
        try:
            # Optional project-local encoder, used only when no API config is set.
            from simulation_engine.gpt_structure import get_text_embedding as _fallback

            self._fallback_encoder = _fallback
        except Exception as exc:
            logger.info("Fallback embedding not available: %s", exc)

    def _call_api(self, text: str) -> List[float]:
        """POST to <base_url>/embeddings (OpenAI-compatible) and return the vector.

        Raises RuntimeError on missing config, non-200 responses, or a
        malformed/empty response body.
        """
        if not self.base_url or not self.api_key:
            raise RuntimeError("Embedding API config missing (base_url/api_key)")
        url = self.base_url.rstrip("/") + "/embeddings"
        payload = {"input": text, "model": self.model}
        headers = {"Authorization": f"Bearer {self.api_key}"}
        resp = requests.post(url, json=payload, headers=headers, timeout=30)
        if resp.status_code != 200:
            raise RuntimeError(f"Embedding API error: {resp.status_code} {resp.text}")
        data = resp.json()
        # Fix: an empty "data" list previously raised IndexError here; treat it
        # the same as a missing embedding so callers see the intended error.
        items = data.get("data") or []
        embedding = items[0].get("embedding") if items else None
        if embedding is None:
            raise RuntimeError("Embedding API response missing embedding")
        return embedding

    def encode(self, text: str) -> List[float]:
        """Return the embedding for *text*, preferring API, then fallback encoder.

        Raises RuntimeError when neither an API config nor a fallback exists.
        """
        cache_key = hashlib.md5(f"{self.model}|{self.base_url}|{text}".encode("utf-8", errors="ignore")).hexdigest()
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        if self.base_url and self.api_key:
            embedding: List[float] = self._call_api(text)
        elif self._fallback_encoder:
            embedding = self._fallback_encoder(text)
        else:
            raise RuntimeError("No embedding method available (provide base_url/api_key or enable fallback).")

        if not isinstance(embedding, list):
            embedding = list(embedding)
        # Fix: the cache used to grow without bound over long ingest runs;
        # evict the oldest entry (insertion order) once the cap is reached.
        if len(self._cache) >= self._CACHE_LIMIT:
            self._cache.pop(next(iter(self._cache)))
        self._cache[cache_key] = embedding
        return embedding
# -------------------- Chroma store -------------------- #
class ChromaStore:
    """Thin wrapper around a persistent Chroma collection plus an embedder."""

    def __init__(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        collection_name: str = COLLECTION_NAME,
        embedder: Optional[EmbeddingBackend] = None,
    ):
        os.makedirs(persist_dir, exist_ok=True)
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embedder = embedder or EmbeddingBackend()
        self.client = chromadb.PersistentClient(path=persist_dir)
        self.collection = self.client.get_or_create_collection(collection_name)

    def reset(self):
        """Drop and recreate the collection (all vectors are discarded)."""
        self.client.delete_collection(self.collection_name)
        self.collection = self.client.get_or_create_collection(self.collection_name)

    def upsert_chunks(self, chunks: List[Chunk], batch_size: int = 32) -> Dict[str, Any]:
        """Embed and upsert *chunks* in batches.

        Chunks whose embedding fails are logged and skipped; returns the count
        actually inserted and the wall-clock seconds taken.
        """
        started = time.time()
        inserted = 0
        batch: List[Tuple[str, str, Dict[str, Any], List[float]]] = []

        def _flush_batch() -> int:
            if not batch:
                return 0
            self.collection.upsert(
                ids=[entry[0] for entry in batch],
                documents=[entry[1] for entry in batch],
                metadatas=[entry[2] for entry in batch],
                embeddings=[entry[3] for entry in batch],
            )
            flushed = len(batch)
            batch.clear()
            return flushed

        for chunk in chunks:
            try:
                vector = self.embedder.encode(chunk.text)
            except Exception as exc:
                logger.error("Embedding failed, skip id=%s: %s", chunk.chunk_id, exc)
                continue
            batch.append((chunk.chunk_id, chunk.text, chunk.metadata, vector))
            if len(batch) >= batch_size:
                inserted += _flush_batch()

        inserted += _flush_batch()
        return {"inserted": inserted, "seconds": round(time.time() - started, 2)}

    def query(self, query: str, top_k: int = 5, where: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Embed the query text and return the top_k nearest chunks."""
        vector = self.embedder.encode(query)
        raw = self.collection.query(query_embeddings=[vector], n_results=top_k, where=where if where else None)
        ids = raw.get("ids", [[]])[0]
        docs = raw.get("documents", [[]])[0]
        metas = raw.get("metadatas", [[]])[0]
        dists = raw.get("distances", [[]])[0] if "distances" in raw else [None] * len(ids)
        hits = [
            {"id": ids[i], "document": docs[i], "metadata": metas[i], "distance": dists[i]}
            for i in range(len(ids))
        ]
        return {"results": hits, "count": len(hits)}

    def stats(self) -> Dict[str, Any]:
        """Basic store info; vector count is None when counting fails."""
        try:
            vectors = self.collection.count()
        except Exception:
            vectors = None
        return {"persist_dir": self.persist_dir, "collection": self.collection_name, "vectors": vectors}
# -------------------- Knowledge manager -------------------- #
class KnowledgeManager:
    """Facade wiring the corpus loader, embedder, and Chroma store together."""

    def __init__(
        self,
        corpus_dir: str = DEFAULT_CORPUS_DIR,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        embedder: Optional[EmbeddingBackend] = None,
    ):
        self.corpus_dir = corpus_dir
        self.persist_dir = persist_dir
        self.embedder = embedder or EmbeddingBackend()
        self.store = ChromaStore(persist_dir=persist_dir, collection_name=COLLECTION_NAME, embedder=self.embedder)

    def _refresh_embedder(self, base_url: Optional[str], api_key: Optional[str], model: Optional[str]):
        """Swap in a new embedder (and store) when any override is supplied."""
        if not (base_url or api_key or model):
            return
        self.embedder = EmbeddingBackend(base_url=base_url, api_key=api_key, model=model)
        self.store = ChromaStore(
            persist_dir=self.persist_dir, collection_name=COLLECTION_NAME, embedder=self.embedder
        )

    def ingest(
        self,
        corpus_dir: Optional[str] = None,
        reset: bool = False,
        chunk_size: int = 600,
        overlap: int = 120,
        batch_size: int = 32,
        max_files: Optional[int] = None,
        embedding_base_url: Optional[str] = None,
        embedding_api_key: Optional[str] = None,
        embedding_model: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Scan *corpus_dir* (default: configured dir), chunk, embed, upsert."""
        self._refresh_embedder(embedding_base_url, embedding_api_key, embedding_model)
        target_dir = corpus_dir or self.corpus_dir
        if reset:
            self.store.reset()

        loader = CorpusLoader(root_dir=target_dir, chunk_size=chunk_size, overlap=overlap)
        chunks = loader.load(max_files=max_files)
        logger.info("Loaded %d chunks from %s, start upsert...", len(chunks), target_dir)
        upsert_res = self.store.upsert_chunks(chunks, batch_size=batch_size)
        return {
            "success": True,
            "message": "ingest completed",
            "chunks": len(chunks),
            "inserted": upsert_res.get("inserted", 0),
            "seconds": upsert_res.get("seconds"),
            "persist_dir": self.persist_dir,
            "collection": COLLECTION_NAME,
            "corpus_dir": target_dir,
            "embedding_base_url": self.embedder.base_url,
            "embedding_model": self.embedder.model,
        }

    def query(
        self,
        query: str,
        top_k: int = 5,
        where: Optional[Dict[str, Any]] = None,
        embedding_base_url: Optional[str] = None,
        embedding_api_key: Optional[str] = None,
        embedding_model: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Vector search, optionally with per-call embedding overrides."""
        self._refresh_embedder(embedding_base_url, embedding_api_key, embedding_model)
        return self.store.query(query=query, top_k=top_k, where=where)

    def stats(self) -> Dict[str, Any]:
        """Store stats plus the current corpus dir and embedding config."""
        info = self.store.stats()
        info.update(
            {
                "default_corpus_dir": self.corpus_dir,
                "embedding_base_url": self.embedder.base_url,
                "embedding_model": self.embedder.model,
            }
        )
        return info
manager = KnowledgeManager()


# -------------------- Auto ingest watcher -------------------- #
class AutoIngestor:
    """Polling-based watcher that re-ingests the corpus when files change."""

    def __init__(
        self,
        km: KnowledgeManager,
        interval_sec: int = AUTO_INGEST_INTERVAL,
        reset_on_start: bool = AUTO_RESET_ON_START,
        enabled: bool = AUTO_INGEST_ENABLED,
    ):
        self.km = km
        self.interval = max(30, interval_sec)  # never poll faster than 30s
        self.reset_on_start = reset_on_start
        self.enabled = enabled
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._snapshot: Dict[str, Tuple[float, int]] = {}

    def _take_snapshot(self) -> Dict[str, Tuple[float, int]]:
        """Map every pdf/docx path under the corpus dir to (mtime, size)."""
        snapshot: Dict[str, Tuple[float, int]] = {}
        for dirpath, _, filenames in os.walk(self.km.corpus_dir):
            for filename in filenames:
                if not filename.lower().endswith((".pdf", ".docx")):
                    continue
                full_path = os.path.join(dirpath, filename)
                try:
                    info = os.stat(full_path)
                except OSError:
                    continue  # file vanished mid-walk; ignore
                snapshot[full_path] = (info.st_mtime, info.st_size)
        return snapshot

    def _has_changes(self) -> bool:
        """Compare against the last snapshot; refresh it when different."""
        current = self._take_snapshot()
        if current == self._snapshot:
            return False
        self._snapshot = current
        return True

    def _ingest_once(self, reset: bool = False):
        """One ingest pass with the manager's current embedding config."""
        try:
            result = self.km.ingest(
                corpus_dir=self.km.corpus_dir,
                reset=reset,
                embedding_base_url=self.km.embedder.base_url,
                embedding_api_key=self.km.embedder.api_key,
                embedding_model=self.km.embedder.model,
            )
            logger.info("Auto-ingest done: %s", json.dumps(result, ensure_ascii=False))
        except Exception as exc:
            logger.error("Auto-ingest failed: %s", exc)

    def _loop(self):
        # Take an initial snapshot, optionally run a first ingest, then poll.
        self._snapshot = self._take_snapshot()
        if self.reset_on_start:
            logger.info("Auto-ingest on start (reset=%s)...", self.reset_on_start)
            self._ingest_once(reset=True)
        elif self.enabled:
            logger.info("Auto-ingest initial run...")
            self._ingest_once(reset=False)

        while not self._stop.wait(self.interval):
            if self._has_changes():
                logger.info("Detected corpus change, auto-ingest...")
                self._ingest_once(reset=False)

    def start(self):
        """Start the background watcher thread (no-op if disabled/running)."""
        if not self.enabled:
            logger.info("Auto-ingest disabled via env (YUESHEN_AUTO_INGEST=0)")
            return
        if self._thread and self._thread.is_alive():
            return
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()
        logger.info("Auto-ingest watcher started, interval=%ss", self.interval)

    def stop(self):
        """Signal the loop to exit and wait briefly for the thread."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=2)
r'^好嘞[!!]?$', r'^得嘞[!!]?$', + # 疑问简单回复 + r'^啥[??]?$', r'^什么[??]?$', r'^嗯[??]$', r'^哈[??]$', + # 单字或极短 + r'^[.。,,!!??~~]+$', +] + +def _is_trivial_query(query: str) -> bool: + """Check if query is a trivial greeting or simple response that doesn't need KB search.""" + if not query: + return True + q = query.strip().lower() + if len(q) <= 2: + return True + for pattern in SKIP_QUERY_PATTERNS: + if re.match(pattern, q, re.IGNORECASE): + return True + return False + + +# -------------------- MCP tools -------------------- # +@server.list_tools() +async def list_tools() -> List[Tool]: + return [ + Tool( + name="ingest_yueshen", + description="Scan directory (pdf/docx/doc), chunk and write to Chroma", + inputSchema={ + "type": "object", + "properties": { + "corpus_dir": {"type": "string", "description": "Optional corpus directory override"}, + "reset": {"type": "boolean", "description": "Recreate collection before ingest", "default": False}, + "chunk_size": {"type": "integer", "description": "Chunk length (chars)", "default": 600}, + "overlap": {"type": "integer", "description": "Chunk overlap (chars)", "default": 120}, + "batch_size": {"type": "integer", "description": "Upsert batch size", "default": 32}, + "max_files": {"type": "integer", "description": "Optional limit for quick test"}, + "embedding_base_url": {"type": "string", "description": "Embedding API base url"}, + "embedding_api_key": {"type": "string", "description": "Embedding API key"}, + "embedding_model": {"type": "string", "description": "Embedding model name"}, + }, + }, + ), + Tool( + name="query_yueshen", + description="Vector search in YueShen KB", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "User query"}, + "top_k": {"type": "integer", "description": "Number of results", "default": 5}, + "where": {"type": "object", "description": "Optional metadata filter (Chroma where)"}, + "embedding_base_url": {"type": "string", "description": "Embedding API 
base url"}, + "embedding_api_key": {"type": "string", "description": "Embedding API key"}, + "embedding_model": {"type": "string", "description": "Embedding model name"}, + }, + "required": ["query"], + }, + ), + Tool( + name="yueshen_stats", + description="Show current vector store stats", + inputSchema={"type": "object", "properties": {}}, + ), + ] + + +@server.call_tool() +async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: + try: + if name == "ingest_yueshen": + res = manager.ingest( + corpus_dir=arguments.get("corpus_dir"), + reset=bool(arguments.get("reset", False)), + chunk_size=int(arguments.get("chunk_size", 600)), + overlap=int(arguments.get("overlap", 120)), + batch_size=int(arguments.get("batch_size", 32)), + max_files=arguments.get("max_files"), + embedding_base_url=arguments.get("embedding_base_url"), + embedding_api_key=arguments.get("embedding_api_key"), + embedding_model=arguments.get("embedding_model"), + ) + return [TextContent(type="text", text=json.dumps(res, ensure_ascii=False, indent=2))] + + if name == "query_yueshen": + query_text = arguments.get("query", "") + # 跳过常见问候和简单回复,不进行知识库查询 + if _is_trivial_query(query_text): + return [TextContent(type="text", text=json.dumps({ + "results": [], + "count": 0, + "skipped": True, + "reason": "trivial query (greeting or simple response)" + }, ensure_ascii=False, indent=2))] + res = manager.query( + query=query_text, + top_k=int(arguments.get("top_k", 5)), + where=arguments.get("where"), + embedding_base_url=arguments.get("embedding_base_url"), + embedding_api_key=arguments.get("embedding_api_key"), + embedding_model=arguments.get("embedding_model"), + ) + return [TextContent(type="text", text=json.dumps(res, ensure_ascii=False, indent=2))] + + if name == "yueshen_stats": + res = manager.stats() + return [TextContent(type="text", text=json.dumps(res, ensure_ascii=False, indent=2))] + + return [ + TextContent( + type="text", + text=json.dumps({"success": False, "message": 
f"unknown tool: {name}"}, ensure_ascii=False), + ) + ] + except Exception as exc: + return [ + TextContent( + type="text", + text=json.dumps({"success": False, "message": f"exception: {exc}"}, ensure_ascii=False), + ) + ] + + +async def main(): + auto = AutoIngestor(manager) + auto.start() + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + init_opts = server.create_initialization_options() + await server.run(read_stream, write_stream, init_opts) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main())