diff --git a/README.md b/README.md index 82ed998..3d6f7c3 100644 --- a/README.md +++ b/README.md @@ -24,15 +24,16 @@ - 支持数字人自动播报模式(虚拟教师、虚拟主播、新闻播报) - 支持任意终端使用:单片机、app、网站、大屏、成熟系统接入等 - 支持多用户多路并发 -- 提供文字沟通接口、声音沟通接口、数字人模型接口、管理控制接口、自动播报接口、意图接口 +- 提供文字交互接口、语音交互接口、数字人驱动接口、管理控制接口、自动播报接口、意图接口 - 支持语音指令灵活配置执行 - 支持自定义知识库、自定义问答对、自定义人设信息 - 支持唤醒及打断对话 - 支持服务器及单机模式 - 支持机器人表情输出 -- 支持react agent自主决策执行、主动对话 +- 支持react agent自主决策执行、主动对话(准备升级到MCP协议) - 支持后台静默启动 - 支持deepseek +- 设计独特的认知模型 ### @@ -51,7 +52,7 @@ ### **环境** -- Python 3.9、3.10、3.11、3.12 +- Python 3.12 - Windows、macos、linux ### **安装依赖** diff --git a/config.json b/config.json index 5b87399..618aff4 100644 --- a/config.json +++ b/config.json @@ -6,11 +6,11 @@ "constellation": "\u6c34\u74f6\u5ea7", "contact": "qq467665317", "gender": "\u5973", - "goal": "\u89e3\u51b3\u95ee\u9898", + "goal": "\u5de5\u4f5c\u534f\u52a9", "hobby": "\u53d1\u5446", "job": "\u52a9\u7406", "name": "\u83f2\u83f2", - "position": "\u5ba2\u670d", + "position": "\u966a\u4f34", "voice": "\u6653\u6653(edge)", "zodiac": "\u86c7" }, diff --git a/core/authorize_tb.py b/core/authorize_tb.py index 55dabf3..49813a8 100644 --- a/core/authorize_tb.py +++ b/core/authorize_tb.py @@ -17,7 +17,7 @@ class Authorize_Tb: #初始化 def init_tb(self): - conn = sqlite3.connect('fay.db') + conn = sqlite3.connect('memory/fay.db') c = conn.cursor() c.execute(''' CREATE TABLE IF NOT EXISTS T_Authorize @@ -34,7 +34,7 @@ class Authorize_Tb: @synchronized def add(self,userid,accesstoken,expirestime): self.init_tb() - conn = sqlite3.connect("fay.db") + conn = sqlite3.connect("memory/fay.db") cur = conn.cursor() cur.execute("insert into T_Authorize (userid,accesstoken,expirestime,createtime) values (?,?,?,?)",(userid,accesstoken,expirestime,int(time.time()))) @@ -46,7 +46,7 @@ class Authorize_Tb: @synchronized def find_by_userid(self,userid): self.init_tb() - conn = sqlite3.connect("fay.db") + conn = sqlite3.connect("memory/fay.db") cur = conn.cursor() cur.execute("select accesstoken,expirestime 
from T_Authorize where userid = ? order by id desc limit 1",(userid,)) info = cur.fetchone() @@ -57,7 +57,7 @@ class Authorize_Tb: @synchronized def update_by_userid(self, userid, new_accesstoken, new_expirestime): self.init_tb() - conn = sqlite3.connect("fay.db") + conn = sqlite3.connect("memory/fay.db") cur = conn.cursor() cur.execute("UPDATE T_Authorize SET accesstoken = ?, expirestime = ? WHERE userid = ?", (new_accesstoken, new_expirestime, userid)) diff --git a/core/content_db.py b/core/content_db.py index 9e18323..b36baf6 100644 --- a/core/content_db.py +++ b/core/content_db.py @@ -26,7 +26,8 @@ class Content_Db: # 初始化数据库 def init_db(self): - conn = sqlite3.connect('fay.db') + conn = sqlite3.connect('memory/fay.db') + conn.text_factory = str c = conn.cursor() c.execute('''CREATE TABLE IF NOT EXISTS T_Msg (id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -48,7 +49,8 @@ class Content_Db: # 添加对话 @synchronized def add_content(self, type, way, content, username='User', uid=0): - conn = sqlite3.connect("fay.db") + conn = sqlite3.connect("memory/fay.db") + conn.text_factory = str cur = conn.cursor() try: cur.execute("INSERT INTO T_Msg (type, way, content, createtime, username, uid) VALUES (?, ?, ?, ?, ?, ?)", @@ -65,7 +67,8 @@ class Content_Db: # 根据ID查询对话记录 @synchronized def get_content_by_id(self, msg_id): - conn = sqlite3.connect("fay.db") + conn = sqlite3.connect("memory/fay.db") + conn.text_factory = str cur = conn.cursor() cur.execute("SELECT * FROM T_Msg WHERE id = ?", (msg_id,)) record = cur.fetchone() @@ -75,7 +78,8 @@ class Content_Db: # 添加对话采纳记录 @synchronized def adopted_message(self, msg_id): - conn = sqlite3.connect('fay.db') + conn = sqlite3.connect('memory/fay.db') + conn.text_factory = str cur = conn.cursor() # 检查消息ID是否存在 cur.execute("SELECT 1 FROM T_Msg WHERE id = ?", (msg_id,)) @@ -96,7 +100,8 @@ class Content_Db: # 获取对话内容 @synchronized def get_list(self, way, order, limit, uid=0): - conn = sqlite3.connect("fay.db") + conn = 
sqlite3.connect("memory/fay.db") + conn.text_factory = str cur = conn.cursor() where_uid = "" if int(uid) != 0: @@ -126,7 +131,7 @@ class Content_Db: @synchronized def get_previous_user_message(self, msg_id): - conn = sqlite3.connect("fay.db") + conn = sqlite3.connect("memory/fay.db") cur = conn.cursor() cur.execute(""" SELECT id, type, way, content, createtime, datetime(createtime, 'unixepoch', 'localtime') AS timetext, username diff --git a/core/fay_core.py b/core/fay_core.py index cb1d9ad..c294dc1 100644 --- a/core/fay_core.py +++ b/core/fay_core.py @@ -31,6 +31,7 @@ from llm import nlp_coze from llm.agent import fay_agent from llm import nlp_qingliu from llm import nlp_gpt_stream +from llm import nlp_cognitive_stream from core import member_db import threading @@ -64,8 +65,8 @@ modules = { "nlp_coze": nlp_coze, "nlp_agent": fay_agent, "nlp_qingliu": nlp_qingliu, - "nlp_gpt_stream": nlp_gpt_stream - + "nlp_gpt_stream": nlp_gpt_stream, + "nlp_cognitive_stream": nlp_cognitive_stream } #大语言模型回复 @@ -83,7 +84,7 @@ def handle_chat_message(msg, username='User', observation='', cache=None): if cfg.key_chat_module == 'rasa': textlist = selected_module.question(msg) text = textlist[0]['text'] - elif cfg.key_chat_module == 'gpt_stream' and cache is not None:#TODO 好像是多余了 + elif cfg.key_chat_module.endswith('_stream') and cache is not None:#支持所有流式输出模块 uid = member_db.new_instance().find_user(username) text = selected_module.question(msg, uid, observation, cache) else: @@ -165,9 +166,9 @@ class FeiFei: textlist = [] if answer is None: if wsa_server.get_web_instance().is_connected(username): - wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : username, 'robot': f'http://{cfg.fay_url}:5000/robot/Thinking.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}) if wsa_server.get_instance().is_connected(username): - content = {'Topic': 'human', 'Data': {'Key': 'log', 
'Value': "思考中..."}, 'Username' : username, 'robot': f'http://{cfg.fay_url}:5000/robot/Thinking.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : username, 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'} wsa_server.get_instance().add_cmd(content) text,textlist = handle_chat_message(interact.data["msg"], username, interact.data.get("observation", "")) @@ -177,8 +178,8 @@ class FeiFei: #记录回复并输出到各个终端 self.__process_text_output(text, textlist, username, uid, type) - #声音输出(gpt_stream在stream_manager.py中调用了say函数) - if type == 'qa' or cfg.key_chat_module != 'gpt_stream': + #声音输出(支持流式输出的模块在stream_manager.py中调用了say函数) + if type == 'qa' or not cfg.key_chat_module.endswith('_stream'): if "" in text: text = text.split("")[1] interact.data['isfirst'] = True @@ -219,7 +220,7 @@ class FeiFei: username = interact.data.get("user", "User") if member_db.new_instance().is_username_exist(username) == "notexists": member_db.new_instance().add_user(username) - if cfg.key_chat_module == "gpt_stream": + if cfg.key_chat_module.endswith('_stream'): MyThread(target=self.__process_interact, args=[interact]).start() return None return self.__process_interact(interact) @@ -314,9 +315,9 @@ class FeiFei: if self.think_mode_users.get(uid, False) and is_start_think: if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : interact.data.get('user'), 'robot': f'http://{cfg.fay_url}:5000/robot/Thinking.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "思考中...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'}) if wsa_server.get_instance().is_connected(interact.data.get("user")): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 'Username' : interact.data.get('user'), 'robot': f'http://{cfg.fay_url}:5000/robot/Thinking.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "思考中..."}, 
'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Thinking.jpg'} wsa_server.get_instance().add_cmd(content) # 如果用户在think模式中,不进行语音合成 @@ -336,7 +337,7 @@ class FeiFei: util.printInfo(1, interact.data.get("user"), "合成音频完成. 耗时: {} ms 文件:{}".format(math.floor((time.time() - tm) * 1000), result)) else: if is_end and wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) if result is not None or is_end: MyThread(target=self.__process_output_audio, args=[result, interact, text]).start() @@ -392,8 +393,18 @@ class FeiFei: is_end = True util.printInfo(1, interact.data.get('user'), '播放音频...') self.speaking = True + + #自动播报关闭 + global auto_play_lock + global can_auto_play + with auto_play_lock: + if self.timer is not None: + self.timer.cancel() + self.timer = None + can_auto_play = False + if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'http://{cfg.fay_url}:5000/robot/Speaking.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}) if file_url is not None: pygame.mixer.music.load(file_url) @@ -409,7 +420,7 @@ class FeiFei: if is_end: self.play_end(interact) if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': 
f'{cfg.fay_url}/robot/Normal.jpg'}) # 播放完毕后通知 if wsa_server.get_web_instance().is_connected(interact.data.get("user")): wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username': interact.data.get('user')}) @@ -463,21 +474,12 @@ class FeiFei: except Exception as e: audio_length = 3 - #自动播报关闭 - global auto_play_lock - global can_auto_play - with auto_play_lock: - if self.timer is not None: - self.timer.cancel() - self.timer = None - can_auto_play = False - #推送远程音频 MyThread(target=self.__send_remote_device_audio, args=[file_url, interact]).start() #发送音频给数字人接口 if file_url is not None and wsa_server.get_instance().is_connected(interact.data.get("user")): - content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'http://{cfg.fay_url}:5000/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver}, 'Username' : interact.data.get('user'), 'robot': f'http://{cfg.fay_url}:5000/robot/Speaking.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'} #计算lips if platform.system() == "Windows": try: @@ -497,7 +499,7 @@ class FeiFei: self.sound_query.put((file_url, audio_length, interact)) else: if wsa_server.get_web_instance().is_connected(interact.data.get('user')): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) except Exception as e: print(e) @@ -610,6 +612,9 @@ class FeiFei: :param uid: 用户ID :param type: 消息类型 """ + if text: + text = text.strip() + # 记录主回复 content_id = 
self.__record_response(text, username, uid) diff --git a/core/member_db.py b/core/member_db.py index 53a3c84..380700d 100644 --- a/core/member_db.py +++ b/core/member_db.py @@ -27,7 +27,7 @@ class Member_Db: #初始化 def init_db(self): - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('''CREATE TABLE IF NOT EXISTS T_Member (id INTEGER PRIMARY KEY autoincrement, @@ -39,7 +39,7 @@ class Member_Db: @synchronized def add_user(self, username): if self.is_username_exist(username) == "notexists": - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('INSERT INTO T_Member (username) VALUES (?)', (username,)) conn.commit() @@ -52,7 +52,7 @@ class Member_Db: @synchronized def update_user(self, username, new_username): if self.is_username_exist(new_username) == "notexists": - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('UPDATE T_Member SET username = ? 
WHERE username = ?', (new_username, username)) conn.commit() @@ -64,7 +64,7 @@ class Member_Db: # 删除用户 @synchronized def delete_user(self, username): - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('DELETE FROM T_Member WHERE username = ?', (username,)) conn.commit() @@ -73,7 +73,7 @@ class Member_Db: # 检查用户名是否已存在 def is_username_exist(self, username): - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('SELECT COUNT(*) FROM T_Member WHERE username = ?', (username,)) result = c.fetchone()[0] @@ -85,7 +85,7 @@ class Member_Db: #根据username查询uid def find_user(self, username): - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('SELECT * FROM T_Member WHERE username = ?', (username,)) result = c.fetchone() @@ -97,7 +97,7 @@ class Member_Db: #根据uid查询username def find_username_by_uid(self, uid): - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('SELECT username FROM T_Member WHERE id = ?', (uid,)) result = c.fetchone() @@ -112,7 +112,7 @@ class Member_Db: @synchronized def query(self, sql): try: - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute(sql) results = c.fetchall() @@ -126,7 +126,7 @@ class Member_Db: # 获取所有用户 @synchronized def get_all_users(self): - conn = sqlite3.connect('user_profiles.db') + conn = sqlite3.connect('memory/user_profiles.db') c = conn.cursor() c.execute('SELECT * FROM T_Member') results = c.fetchall() diff --git a/core/recorder.py b/core/recorder.py index 3e8e703..e2d3c18 100644 --- a/core/recorder.py +++ b/core/recorder.py @@ -120,9 +120,9 @@ class Recorder: if wake_up: util.printInfo(1, self.username, "唤醒成功!") if wsa_server.get_web_instance().is_connected(self.username): - 
wsa_server.get_web_instance().add_cmd({"panelMsg": "唤醒成功!", "Username" : self.username , 'robot': f'http://{cfg.fay_url}:5000/robot/Listening.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "唤醒成功!", "Username" : self.username , 'robot': f'{cfg.fay_url}/robot/Listening.jpg'}) if wsa_server.get_instance().is_connected(self.username): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "唤醒成功!"}, 'Username' : self.username, 'robot': f'http://{cfg.fay_url}:5000/robot/Listening.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "唤醒成功!"}, 'Username' : self.username, 'robot': f'{cfg.fay_url}/robot/Listening.jpg'} wsa_server.get_instance().add_cmd(content) self.wakeup_matched = True # 唤醒成功 with fay_core.auto_play_lock: @@ -136,9 +136,9 @@ class Recorder: else: util.printInfo(1, self.username, "[!] 待唤醒!") if wsa_server.get_web_instance().is_connected(self.username): - wsa_server.get_web_instance().add_cmd({"panelMsg": "[!] 待唤醒!", "Username" : self.username , 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "[!] 待唤醒!", "Username" : self.username , 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) if wsa_server.get_instance().is_connected(self.username): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "[!] 待唤醒!"}, 'Username' : self.username, 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "[!] 
待唤醒!"}, 'Username' : self.username, 'robot': f'{cfg.fay_url}/robot/Normal.jpg'} wsa_server.get_instance().add_cmd(content) else: self.on_speaking(text) @@ -160,9 +160,9 @@ class Recorder: if wake_up: util.printInfo(1, self.username, "唤醒成功!") if wsa_server.get_web_instance().is_connected(self.username): - wsa_server.get_web_instance().add_cmd({"panelMsg": "唤醒成功!", "Username" : self.username , 'robot': f'http://{cfg.fay_url}:5000/robot/Listening.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "唤醒成功!", "Username" : self.username , 'robot': f'{cfg.fay_url}/robot/Listening.jpg'}) if wsa_server.get_instance().is_connected(self.username): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "唤醒成功!"}, 'Username' : self.username, 'robot': f'http://{cfg.fay_url}:5000/robot/Listening.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "唤醒成功!"}, 'Username' : self.username, 'robot': f'{cfg.fay_url}/robot/Listening.jpg'} wsa_server.get_instance().add_cmd(content) #去除唤醒词后语句 question = text#[len(wake_up_word):].lstrip() @@ -173,9 +173,9 @@ class Recorder: else: util.printInfo(1, self.username, "[!] 待唤醒!") if wsa_server.get_web_instance().is_connected(self.username): - wsa_server.get_web_instance().add_cmd({"panelMsg": "[!] 待唤醒!", "Username" : self.username , 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "[!] 待唤醒!", "Username" : self.username , 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) if wsa_server.get_instance().is_connected(self.username): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "[!] 待唤醒!"}, 'Username' : self.username, 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': "[!] 待唤醒!"}, 'Username' : self.username, 'robot': f'{cfg.fay_url}/robot/Normal.jpg'} wsa_server.get_instance().add_cmd(content) #非唤醒模式 @@ -190,9 +190,9 @@ class Recorder: util.printInfo(1, self.username, "[!] 
语音未检测到内容!") self.dynamic_threshold = self.__get_history_percentage(30) if wsa_server.get_web_instance().is_connected(self.username): - wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : self.username, 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'}) + wsa_server.get_web_instance().add_cmd({"panelMsg": "", 'Username' : self.username, 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) if wsa_server.get_instance().is_connected(self.username): - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': ""}, 'Username' : self.username, 'robot': f'http://{cfg.fay_url}:5000/robot/Normal.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': ""}, 'Username' : self.username, 'robot': f'{cfg.fay_url}/robot/Normal.jpg'} wsa_server.get_instance().add_cmd(content) def __record(self): @@ -313,14 +313,14 @@ class Recorder: wsa_server.get_web_instance().add_cmd({ "panelMsg": "聆听中...", 'Username': self.username, - 'robot': f'http://{cfg.fay_url}:5000/robot/Listening.jpg' + 'robot': f'{cfg.fay_url}/robot/Listening.jpg' }) if wsa_server.get_instance().is_connected(self.username): content = { 'Topic': 'human', 'Data': {'Key': 'log', 'Value': "聆听中..."}, 'Username': self.username, - 'robot': f'http://{cfg.fay_url}:5000/robot/Listening.jpg' + 'robot': f'{cfg.fay_url}/robot/Listening.jpg' } wsa_server.get_instance().add_cmd(content) except Exception as e: diff --git a/core/stream_manager.py b/core/stream_manager.py index db6ab9a..662490b 100644 --- a/core/stream_manager.py +++ b/core/stream_manager.py @@ -114,10 +114,10 @@ class StreamManager: fay_core = fay_booter.feiFei # 处理普通消息,区分是否是会话的第一句 if sentence.endswith('_'): - sentence = sentence[:-len('_')] + sentence = sentence[:-len('_')].strip() interact = Interact("stream", 1, {'user': username, 'msg': sentence, 'isfirst': True}) elif sentence.endswith('_'): - sentence = sentence[:-len('_')] + sentence = sentence[:-len('_')].strip() interact = Interact("stream", 1, {'user': username, 'msg': sentence, 
'isend': True}) else: interact = Interact("stream", 1, {'user': username, 'msg': sentence}) diff --git a/docker/Dockerfile b/docker/Dockerfile index ba4f666..1e1fbfc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,10 +1,11 @@ -FROM docker.m.daocloud.io/python:3.10 - -COPY install_deps.sh /usr/local/bin/install_deps.sh -RUN chmod +x /usr/local/bin/install_deps.sh && /usr/local/bin/install_deps.sh -RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple/ -COPY requirements.txt /app/ +FROM docker.m.daocloud.io/python:3.12 +#FROM python:3.12 ---> Nick +COPY app /app +RUN chmod +x /app/docker/install_deps.sh \ +# && mv /app/docker/sources.list /etc/apt/sources.list \ ---> 添加对应的sources list可以提升apt install 效率 + && /app/docker/install_deps.sh \ + && pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple/ \ + && pip install --no-cache-dir -r /app/docker/requirements.txt WORKDIR /app -RUN pip install --no-cache-dir -r requirements.txt -COPY ./ /app CMD ["python", "main.py"] + diff --git a/docker/environment.yml b/docker/environment.yml deleted file mode 100644 index c8f1da7..0000000 --- a/docker/environment.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: fay -channels: - - defaults -dependencies: - - python=3.10 - - requests - - numpy - - pyaudio=0.2.11 - - websockets=10.2 - - ws4py=0.5.1 - - pyqt=5.15.6 - - flask=3.0.0 - - openpyxl=3.0.9 - - flask-cors=3.0.10 - - pyqtwebengine=5.15.5 - - eyed3=0.9.6 - - websocket-client - - azure-cognitiveservices-speech - - aliyun-python-sdk-core - - simhash - - pytz - - gevent=22.10.1 - - edge-tts=6.1.3 - - ultralytics=8.0.2 - - pydub - - cemotion - - langchain=0.0.336 - - chromadb - - tenacity=8.2.3 - - pygame - - scipy - - pip diff --git a/docker/install_deps.sh b/docker/install_deps.sh deleted file mode 100644 index 037d72a..0000000 --- a/docker/install_deps.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -# 检测 Debian 系统(如 Ubuntu) -if grep -qEi "(debian|ubuntu)" /etc/*release; then - 
apt-get update -yq --fix-missing && \ - DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends \ - pkg-config \ - wget \ - cmake \ - curl \ - git \ - vim \ - build-essential \ - libgl1-mesa-glx \ - portaudio19-dev \ - libnss3 \ - libxcomposite1 \ - libxrender1 \ - libxrandr2 \ - libqt5webkit5-dev \ - libxdamage1 \ - libxtst6 && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* -# 检测 CentOS 系统 -elif grep -qEi "(centos|fedora|rhel)" /etc/*release; then - yum update -y && \ - yum install -y \ - pkgconfig \ - wget \ - cmake \ - curl \ - git \ - vim-enhanced \ - gcc \ - gcc-c++ \ - mesa-libGL \ - portaudio \ - nss \ - libXcomposite \ - libXrender \ - libXrandr \ - qt5-qtwebkit-devel \ - libXdamage \ - libXtst && \ - yum clean all -else - echo "Unsupported OS" - exit 1 -fi diff --git a/docker/requirements.txt b/docker/requirements.txt index 5802f41..ad80429 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -1,28 +1,31 @@ requests numpy pyaudio~=0.2.11 -websockets~=10.2 +websockets~=10.4 ws4py~=0.5.1 -pyqt5~=5.15.6 +PyQt5==5.15.10 +PyQt5-sip==12.13.0 +PyQtWebEngine==5.15.6 flask~=3.0.0 openpyxl~=3.0.9 flask_cors~=3.0.10 -PyQtWebEngine~=5.15.5 -eyed3~=0.9.6 websocket-client azure-cognitiveservices-speech aliyun-python-sdk-core simhash pytz gevent -edge_tts~=6.1.3 -eyed3 -ultralytics~=8.0.2 +edge_tts pydub -cemotion -langchain==0.0.336 chromadb tenacity==8.2.3 pygame scipy -flask-httpauth \ No newline at end of file +flask-httpauth +opencv-python +psutil +langchain +langchain_openai +langgraph +bs4 +schedule \ No newline at end of file diff --git a/fay_booter.py b/fay_booter.py index 6f1ac8e..c84d19d 100644 --- a/fay_booter.py +++ b/fay_booter.py @@ -16,6 +16,7 @@ from core import wsa_server from core import socket_bridge_service from llm.agent import agent_service import subprocess +from llm.nlp_cognitive_stream import save_agent_memory # 全局变量声明 feiFei = None @@ -294,6 +295,16 @@ def stop(): util.log(1, '正在关闭服务...') __running = 
False + + # 保存代理记忆 + if config_util.key_chat_module == 'cognitive_stream': + util.log(1, '正在保存代理记忆...') + try: + save_agent_memory() + util.log(1, '代理记忆保存成功') + except Exception as e: + util.log(1, f'保存代理记忆失败: {str(e)}') + if recorderListener is not None: util.log(1, '正在关闭录音服务...') recorderListener.stop() @@ -347,6 +358,12 @@ def start(): from llm import nlp_privategpt nlp_privategpt.save_all() + #初始化定时保存记忆的任务 + if config_util.key_chat_module == 'cognitive_stream': + util.log(1, '初始化定时保存记忆的任务...') + from llm.nlp_cognitive_stream import init_memory_scheduler + init_memory_scheduler() + #开启录音服务 record = config_util.config['source']['record'] if record['enabled']: diff --git a/genagents/genagents.py b/genagents/genagents.py new file mode 100644 index 0000000..5fb2660 --- /dev/null +++ b/genagents/genagents.py @@ -0,0 +1,150 @@ +import uuid + +from genagents.modules.interaction import * +from genagents.modules.memory_stream import * + + +# ############################################################################ +# ### GENERATIVE AGENT CLASS ### +# ############################################################################ + +class GenerativeAgent: + def __init__(self, agent_folder=None): + if agent_folder: + # We stop the process if the agent storage folder already exists. + if not check_if_file_exists(f"{agent_folder}/scratch.json"): + print ("Generative agent does not exist in the current location.") + return + + # Loading the agent's memories. 
+ try: + with open(f"{agent_folder}/scratch.json", 'r', encoding='utf-8') as json_file: + scratch = json.load(json_file) + with open(f"{agent_folder}/memory_stream/embeddings.json", 'r', encoding='utf-8') as json_file: + embeddings = json.load(json_file) + with open(f"{agent_folder}/memory_stream/nodes.json", 'r', encoding='utf-8') as json_file: + nodes = json.load(json_file) + except Exception as e: + print(f"加载代理记忆时出错: {str(e)}") + # 如果加载失败,创建空的记忆 + scratch = {} + embeddings = {} + nodes = [] + + self.id = uuid.uuid4() + self.scratch = scratch + self.memory_stream = MemoryStream(nodes, embeddings) + + else: + self.id = uuid.uuid4() + self.scratch = {} + self.memory_stream = MemoryStream([], {}) + + + def update_scratch(self, update): + self.scratch.update(update) + + + def package(self): + """ + Packaging the agent's meta info for saving. + + Parameters: + None + Returns: + packaged dictionary + """ + return {"id": str(self.id)} + + + def save(self, save_directory): + """ + Given a save_code, save the agents' state in the storage. Right now, the + save directory works as follows: + 'storage//' + + As you grow different versions of the agent, save the new agent state in + a different save code location. Remember that 'init' is the originally + initialized agent directory. + + Parameters: + save_code: str + Returns: + None + """ + try: + # Name of the agent and the current save location. + storage = save_directory + create_folder_if_not_there(f"{storage}/memory_stream") + + # 确保embeddings不为None + if self.memory_stream.embeddings is None: + self.memory_stream.embeddings = {} + + # Saving the agent's memory stream. This includes saving the embeddings + # as well as the nodes. 
+ with open(f"{storage}/memory_stream/embeddings.json", "w", encoding='utf-8') as json_file: + json.dump(self.memory_stream.embeddings, + json_file, ensure_ascii=False, indent=2) + with open(f"{storage}/memory_stream/nodes.json", "w", encoding='utf-8') as json_file: + json.dump([node.package() for node in self.memory_stream.seq_nodes], + json_file, ensure_ascii=False, indent=2) + + # Saving the agent's scratch memories. + with open(f"{storage}/scratch.json", "w", encoding='utf-8') as json_file: + json.dump(self.scratch, json_file, ensure_ascii=False, indent=2) + + # Saving the agent's meta information. + with open(f"{storage}/meta.json", "w", encoding='utf-8') as json_file: + json.dump(self.package(), json_file, ensure_ascii=False, indent=2) + except Exception as e: + print(f"保存代理记忆时出错: {str(e)}") + + + def get_fullname(self): + if "first_name" in self.scratch and "last_name" in self.scratch: + return f"{self.scratch['first_name']} {self.scratch['last_name']}" + else: + return "" + + def get_self_description(self): + return str(self.scratch) + + def remember(self, content, time_step=0): + """ + Add a new observation to the memory stream. + + Parameters: + content: The content of the current memory record that we are adding to + the agent's memory stream. + Returns: + None + """ + self.memory_stream.remember(content, time_step) + + + def reflect(self, anchor, time_step=0): + """ + Add a new reflection to the memory stream. 
+ + Parameters: + anchor: str reflection anchor + Returns: + None + """ + self.memory_stream.reflect(anchor, time_step) + + + def categorical_resp(self, questions): + ret = categorical_resp(self, questions) + return ret + + + def numerical_resp(self, questions, float_resp=False): + ret = numerical_resp(self, questions, float_resp) + return ret + + + def utterance(self, curr_dialogue, context=""): + ret = utterance(self, curr_dialogue, context) + return ret diff --git a/genagents/genagents_flask.py b/genagents/genagents_flask.py new file mode 100644 index 0000000..27e9a17 --- /dev/null +++ b/genagents/genagents_flask.py @@ -0,0 +1,278 @@ +from flask import Flask, render_template, request, jsonify, send_from_directory +import os +import uuid +import json +import sys +import threading +import time +from genagents.genagents import GenerativeAgent +from utils import util + +# 添加项目根目录到sys.path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, project_root) + +# 导入项目中的模块 +from llm.nlp_cognitive_stream import save_agent_memory, create_agent, set_memory_cleared_flag + +# 创建Flask应用 +app = Flask(__name__) + +# 全局变量 +instruction = "" +genagents_port = 5001 +genagents_host = "0.0.0.0" +genagents_debug = True +server_thread = None +shutdown_flag = False +fay_agent = None + +# 确保模板和静态文件目录存在 +def setup_directories(): + os.makedirs(os.path.join(os.path.dirname(__file__), 'templates'), exist_ok=True) + os.makedirs(os.path.join(os.path.dirname(__file__), 'static'), exist_ok=True) + +# 读取指令文件 +def load_instruction(): + global instruction + instruction_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'instruction.json') + if os.path.exists(instruction_file): + try: + with open(instruction_file, 'r', encoding='utf-8') as f: + data = json.load(f) + instruction = data.get('instruction', '') + # 读取后删除文件,防止重复使用 + os.remove(instruction_file) + except Exception as e: + print(f"读取指令文件出错: {str(e)}") + +@app.route('/') +def index(): 
+ """提供主页HTML""" + return render_template('decision_interview.html', instruction=instruction) + +# 关闭服务器的函数 +def shutdown_server(): + global shutdown_flag + shutdown_flag = True + # 不再直接访问request对象,而是设置标志让服务器自行关闭 + print("服务器将在处理完当前请求后关闭...") + +# 清除记忆API +@app.route('/api/clear-memory', methods=['POST']) +def api_clear_memory(): + try: + # 获取memory目录路径 + memory_dir = os.path.join(os.getcwd(), "memory") + + # 检查目录是否存在 + if not os.path.exists(memory_dir): + return jsonify({'success': False, 'message': '记忆目录不存在'}), 400 + + # 清空memory目录下的所有文件(保留目录结构) + for root, dirs, files in os.walk(memory_dir): + for file in files: + file_path = os.path.join(root, file) + try: + if os.path.isfile(file_path): + os.remove(file_path) + util.log(1, f"已删除文件: {file_path}") + except Exception as e: + util.log(1, f"删除文件时出错: {file_path}, 错误: {str(e)}") + + # 删除memory_stream目录(如果存在) + memory_stream_dir = os.path.join(memory_dir, "memory_stream") + if os.path.exists(memory_stream_dir): + import shutil + try: + shutil.rmtree(memory_stream_dir) + util.log(1, f"已删除目录: {memory_stream_dir}") + except Exception as e: + util.log(1, f"删除目录时出错: {memory_stream_dir}, 错误: {str(e)}") + + # 创建一个标记文件,表示记忆已被清除,防止退出时重新保存 + with open(os.path.join(memory_dir, ".memory_cleared"), "w") as f: + f.write("Memory has been cleared. 
Do not save on exit.") + + # 设置记忆清除标记 + try: + # 导入并修改nlp_cognitive_stream模块中的保存函数 + from llm.nlp_cognitive_stream import set_memory_cleared_flag, clear_agent_memory + + # 设置记忆清除标记 + set_memory_cleared_flag(True) + + # 清除内存中已加载的记忆 + clear_agent_memory() + + util.log(1, "已同时清除文件存储和内存中的记忆") + except Exception as e: + util.log(1, f"清除内存中记忆时出错: {str(e)}") + + util.log(1, "记忆已清除,需要重启应用才能生效") + return jsonify({'success': True, 'message': '记忆已清除,请重启应用使更改生效'}), 200 + except Exception as e: + util.log(1, f"清除记忆时出错: {str(e)}") + return jsonify({'success': False, 'message': f'清除记忆时出错: {str(e)}'}), 500 + +@app.route('/api/submit', methods=['POST']) +def submit_data(): + """处理提交的表单数据并将其添加到Agent的记忆中""" + try: + # 接收JSON格式的表单数据 + data = request.json + + if not data or 'dimensions' not in data: + return jsonify({'status': 'error', 'message': '数据格式不正确'}), 400 + + # 导入需要的函数 + from llm.nlp_cognitive_stream import get_current_time_step, save_agent_memory, create_agent + + # 确保Fay的agent已经初始化 + global fay_agent + if fay_agent is None: + fay_agent = create_agent() + + # 确保embeddings不为None + if fay_agent.memory_stream.embeddings is None: + fay_agent.memory_stream.embeddings = {} + + # 使用全局函数获取时间步 + time_step = get_current_time_step() + 1 + + # 处理各维度数据 + for dimension_name, dimension_qa in data['dimensions'].items(): + # 为每个维度创建一个摘要记忆 + dimension_summary = f"决策分析维度: {dimension_name}\n" + + for qa_pair in dimension_qa: + question = qa_pair.get('问题', '') + answer = qa_pair.get('回答', '') + dimension_summary += f"问题: {question}\n回答: {answer}\n\n" + + # 将维度摘要添加到Agent的记忆中 + fay_agent.remember(dimension_summary, time_step=time_step) + time_step += 1 + + # 添加一个总结记忆 + global instruction # 明确声明使用全局变量 + summary = f"[系统指令] 基于以上决策分析,你的人格已被重新定义。" + if 'instruction' in globals() and instruction: + summary += f" 你需要遵循以下指令:{instruction}" + + fay_agent.remember(summary, time_step=time_step) + + # 保存记忆 + save_agent_memory() + + # 设置关闭标志,让服务器在响应后关闭 + global shutdown_flag + shutdown_flag = True + + # 
返回响应,添加自动关闭窗口的JavaScript代码 + return jsonify({ + 'status': 'success', + 'message': '决策分析数据已克隆到记忆中,请关闭页面并重启Fay', + 'redirect': 'http://localhost:8080/setting', + 'closeWindow': True # 添加标志,指示前端关闭窗口 + }) + except Exception as e: + import traceback + error_details = traceback.format_exc() + print(f"处理决策分析数据时出错: {str(e)}\n{error_details}") + return jsonify({'status': 'error', 'message': f'处理数据时出错: {str(e)}'}), 500 + +@app.route('/api/shutdown', methods=['POST']) +def shutdown(): + """手动关闭服务器的API""" + shutdown_server() + return jsonify({'status': 'success', 'message': '服务器正在关闭'}) + +@app.route('/static/') +def serve_static(filename): + # 提供静态文件 + return send_from_directory('static', filename) + +@app.route('/templates/') +def serve_template(filename): + # 提供模板文件(仅用于调试) + return send_from_directory('templates', filename) + +# 全局变量,用于控制服务器关闭 +shutdown_flag = False + +# 检查是否请求关闭服务器 +def is_shutdown_requested(): + global shutdown_flag + return shutdown_flag + +# 设置应用程序,复制必要的文件到正确的位置 +def setup(): + setup_directories() + + # 确保decision_interview.html存在于templates目录 + template_source = os.path.join(os.path.dirname(__file__), 'decision_interview.html') + template_dest = os.path.join(os.path.dirname(__file__), 'templates', 'decision_interview.html') + + if os.path.exists(template_source) and not os.path.exists(template_dest): + import shutil + shutil.copy2(template_source, template_dest) + +# 启动决策分析服务 +def start_genagents_server(instruction_text="", port=None, host=None, debug=None): + global instruction, genagents_port, genagents_host, genagents_debug, shutdown_flag + + # 重置关闭标志 + shutdown_flag = False + + # 设置指令 + if instruction_text: + instruction = instruction_text + else: + load_instruction() + + # 设置服务器参数 + if port is not None: + genagents_port = port + if host is not None: + genagents_host = host + if debug is not None: + genagents_debug = debug + + # 设置应用 + setup() + + # 返回应用实例,但不启动 + return app + +# 直接运行时启动服务器 +if __name__ == '__main__': + setup() # 确保所有必要的目录和文件都存在 + 
load_instruction() # 加载指令 + print(f"启动Flask服务器,请访问 http://127.0.0.1:{genagents_port}/ 打开页面") + + # 使用Werkzeug的服务器,并添加关闭检查 + from werkzeug.serving import make_server + + # 创建服务器 + server = make_server(genagents_host, genagents_port, app) + + # 启动服务器,但在单独的线程中运行,以便我们可以检查shutdown_flag + import threading + + def run_server(): + server.serve_forever() + + server_thread = threading.Thread(target=run_server) + server_thread.daemon = True + server_thread.start() + + # 主线程检查shutdown_flag + try: + while not is_shutdown_requested(): + time.sleep(1) + except KeyboardInterrupt: + print("接收到键盘中断,正在关闭服务器...") + finally: + print("正在关闭服务器...") + server.shutdown() \ No newline at end of file diff --git a/genagents/instruction.json b/genagents/instruction.json new file mode 100644 index 0000000..ec4c740 --- /dev/null +++ b/genagents/instruction.json @@ -0,0 +1 @@ +{"instruction": "你是一个严肃认真的程序员"} \ No newline at end of file diff --git a/genagents/modules/interaction.py b/genagents/modules/interaction.py new file mode 100644 index 0000000..68f5e9b --- /dev/null +++ b/genagents/modules/interaction.py @@ -0,0 +1,279 @@ +import math +import sys +import datetime +import random +import string +import re +import os + +from numpy import dot +from numpy.linalg import norm + +from simulation_engine.settings import * +from simulation_engine.global_methods import * +from simulation_engine.gpt_structure import * +from simulation_engine.llm_json_parser import * +from utils import util + +def _main_agent_desc(agent, anchor): + agent_desc = "" + agent_desc += f"Self description: {agent.get_self_description()}\n==\n" + agent_desc += f"Other observations about the subject:\n\n" + + retrieved = agent.memory_stream.retrieve([anchor], 0, n_count=120) + if len(retrieved) == 0: + return agent_desc + nodes = list(retrieved.values())[0] + for node in nodes: + agent_desc += f"{node.content}\n" + return agent_desc + + +def _utterance_agent_desc(agent, anchor): + agent_desc = "" + agent_desc += f"Self 
def run_gpt_generate_categorical_resp(
        agent_desc,
        questions,
        prompt_version="1",
        gpt_version="GPT4o",
        verbose=False):
    """Ask the LLM to answer categorical (multiple-choice) questions in character.

    Parameters:
      agent_desc: str description of the agent (persona + memories).
      questions: dict mapping question text -> option list.
      prompt_version: prompt template version (currently unused beyond "1").
      gpt_version: which model alias to use.
      verbose: pass-through verbosity flag for generation.
    Returns:
      (output, [output, prompt, prompt_input, fail_safe]) where output is a
      dict with "responses" and "reasonings" lists, or the fail-safe (None).
    """
    def _build_prompt_input(desc, qs):
        lines = []
        for question, options in qs.items():
            lines.append(f"Q: {question}\n")
            lines.append(f"Option: {options}\n\n")
        return [desc, "".join(lines).strip()]

    def _clean_up(gpt_response, prompt=""):
        responses, reasonings = extract_first_json_dict_categorical(gpt_response)
        return {"responses": responses, "reasonings": reasonings}

    # Batch template for several questions, singular template for one.
    variant = "batch_v1" if len(questions) > 1 else "singular_v1"
    prompt_lib_file = (f"{LLM_PROMPT_DIR}/generative_agent/interaction/"
                       f"categorical_resp/{variant}.txt")

    fail_safe = None
    output, prompt, prompt_input, fail_safe = chat_safe_generate(
        _build_prompt_input(agent_desc, questions), prompt_lib_file,
        gpt_version, 1, fail_safe, _clean_up, verbose)

    return output, [output, prompt, prompt_input, fail_safe]
questions.items(): + str_questions += f"Q: {key}\n" + str_questions += f"Range: {str(val)}\n\n" + str_questions = str_questions.strip() + + if float_resp: + resp_type = "float" + else: + resp_type = "integer" + return [agent_desc, str_questions, resp_type] + + def _func_clean_up(gpt_response, prompt=""): + responses, reasonings = extract_first_json_dict_numerical(gpt_response) + ret = {"responses": responses, "reasonings": reasonings} + return ret + + def _get_fail_safe(): + return None + + if len(questions) > 1: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/interaction/numerical_resp/batch_v1.txt" + else: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/interaction/numerical_resp/singular_v1.txt" + + prompt_input = create_prompt_input(agent_desc, questions, float_resp) + fail_safe = _get_fail_safe() + + output, prompt, prompt_input, fail_safe = chat_safe_generate( + prompt_input, prompt_lib_file, gpt_version, 1, fail_safe, + _func_clean_up, verbose) + + if float_resp: + output["responses"] = [float(i) for i in output["responses"]] + else: + output["responses"] = [int(i) for i in output["responses"]] + + return output, [output, prompt, prompt_input, fail_safe] + + +def numerical_resp(agent, questions, float_resp): + anchor = " ".join(list(questions.keys())) + agent_desc = _main_agent_desc(agent, anchor) + return run_gpt_generate_numerical_resp( + agent_desc, questions, float_resp, "1", LLM_VERS)[0] + + +def run_gpt_generate_utterance( + agent_desc, + str_dialogue, + context, + prompt_version="1", + gpt_version="GPT4o", + verbose=False): + """ + 运行GPT生成对话回复 + + 参数: + agent_desc: 代理描述 + str_dialogue: 对话字符串 + context: 上下文 + prompt_version: 提示版本,默认为"1" + gpt_version: GPT版本,默认为"GPT4o" + verbose: 是否输出详细信息,默认为False + + 返回: + output: 生成的回复 + 详细信息: [output, prompt, prompt_input, fail_safe] + """ + def create_prompt_input(agent_desc, str_dialogue, context): + return [agent_desc, context, str_dialogue] + + def _func_clean_up(gpt_response, prompt=""): + try: + 
# 确保gpt_response是字符串类型 + if not isinstance(gpt_response, str): + util.log(1, f"GPT响应不是字符串类型: {type(gpt_response)}") + return "抱歉,我现在太忙了,休息一会,请稍后再试。" + + # 提取JSON字典 + json_dict = extract_first_json_dict(gpt_response) + if json_dict is None or "utterance" not in json_dict: + util.log(1, f"无法从GPT响应中提取有效的JSON或缺少utterance字段: {gpt_response[:100]}...") + return "抱歉,我现在太忙了,休息一会,请稍后再试。" + + # 返回utterance字段 + return json_dict["utterance"] + except Exception as e: + util.log(1, f"处理GPT响应时出错: {str(e)}") + return "抱歉,我现在太忙了,休息一会,请稍后再试。" + + def _get_fail_safe(): + return "对不起,我现在无法回答这个问题。" + + # 确保模板文件路径正确 + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/interaction/utternace/utterance_v1.txt" + if not os.path.exists(prompt_lib_file): + util.log(1, f"模板文件不存在: {prompt_lib_file}") + return "抱歉,我现在太忙了,休息一会,请稍后再试。", ["抱歉,我现在太忙了,休息一会,请稍后再试。", "", [], ""] + + prompt_input = create_prompt_input(agent_desc, str_dialogue, context) + fail_safe = _get_fail_safe() + + # 调用chat_safe_generate函数生成回复 + try: + output, prompt, prompt_input, fail_safe = chat_safe_generate( + prompt_input, prompt_lib_file, gpt_version, 1, fail_safe, + _func_clean_up, verbose) + + # 确保输出是字符串类型 + if output is None: + util.log(1, "GPT生成的输出为None") + output = fail_safe + except Exception as e: + util.log(1, f"调用chat_safe_generate时出错: {str(e)}") + output = fail_safe + prompt = "" + prompt_input = [] + + return output, [output, prompt, prompt_input, fail_safe] + + +def utterance(agent, curr_dialogue, context): + str_dialogue = "" + for row in curr_dialogue: + str_dialogue += f"[{row[0]}]: {row[1]}\n" + str_dialogue += f"[{agent.get_fullname()}]: [Fill in]\n" + + anchor = str_dialogue + agent_desc = _utterance_agent_desc(agent, anchor) + return run_gpt_generate_utterance( + agent_desc, str_dialogue, context, "1", LLM_VERS, False)[0] + +## Ask function. 
def run_gpt_generate_ask(
        agent_desc,
        questions,
        prompt_version="1",
        gpt_version="GPT4o",
        verbose=False):
    """Pose a mixed-type question battery to the LLM in character.

    Parameters:
      agent_desc: str description of the agent (persona + memories).
      questions: list of dicts, each with 'question', 'response-type'
        ('categorical' | 'int' | 'float' | 'open') and the matching
        'response-options' / 'response-scale' / 'response-char-limit' keys.
      prompt_version: prompt template version (currently unused beyond "1").
      gpt_version: which model alias to use.
      verbose: pass-through verbosity flag for generation.
    Returns:
      (output, [output, prompt, prompt_input, fail_safe]) where output is the
      parsed JSON response dict, or the fail-safe (None).
    """
    def _format_question(index, q):
        kind = q['response-type']
        lines = [f"Q{index}: {q['question']}\n", f"Type: {kind}\n"]
        if kind == 'categorical':
            lines.append(f"Options: {', '.join(q['response-options'])}\n")
        elif kind in ('int', 'float'):
            lines.append(f"Range: {q['response-scale']}\n")
        elif kind == 'open':
            # Default character limit mirrors the form's default of 200.
            lines.append(f"Character Limit: {q.get('response-char-limit', 200)}\n")
        lines.append("\n")
        return "".join(lines)

    def _clean_up(gpt_response, prompt=""):
        return extract_first_json_dict(gpt_response)

    str_questions = "".join(
        _format_question(i, q) for i, q in enumerate(questions, start=1))
    prompt_input = [agent_desc, str_questions.strip()]

    prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/interaction/ask/batch_v1.txt"
    fail_safe = None

    output, prompt, prompt_input, fail_safe = chat_safe_generate(
        prompt_input, prompt_lib_file, gpt_version, 1, fail_safe,
        _clean_up, verbose)

    return output, [output, prompt, prompt_input, fail_safe]
records_str = "" + for count, r in enumerate(records): + records_str += f"Item {str(count+1)}:\n" + records_str += f"{r}\n" + return [records_str] + + def _func_clean_up(gpt_response, prompt=""): + gpt_response = extract_first_json_dict(gpt_response) + # 处理gpt_response为None的情况 + if gpt_response is None: + print("警告: extract_first_json_dict返回None,使用默认值") + return [50] # 返回默认重要性分数 + return list(gpt_response.values()) + + def _get_fail_safe(): + return 25 + + if len(records) > 1: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/batch_v1.txt" + else: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/importance_score/singular_v1.txt" + + prompt_input = create_prompt_input(records) + fail_safe = _get_fail_safe() + + output, prompt, prompt_input, fail_safe = chat_safe_generate( + prompt_input, prompt_lib_file, gpt_version, 1, fail_safe, + _func_clean_up, verbose) + + return output, [output, prompt, prompt_input, fail_safe] + + +def generate_importance_score(records): + return run_gpt_generate_importance(records, "1", LLM_VERS)[0] + + +def run_gpt_generate_reflection( + records, + anchor, + reflection_count, + prompt_version="1", + gpt_version="GPT4o", + verbose=False): + + def create_prompt_input(records, anchor, reflection_count): + records_str = "" + for count, r in enumerate(records): + records_str += f"Item {str(count+1)}:\n" + records_str += f"{r}\n" + return [records_str, reflection_count, anchor] + + def _func_clean_up(gpt_response, prompt=""): + return extract_first_json_dict(gpt_response)["reflection"] + + def _get_fail_safe(): + return [] + + if reflection_count > 1: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/batch_v1.txt" + else: + prompt_lib_file = f"{LLM_PROMPT_DIR}/generative_agent/memory_stream/reflection/singular_v1.txt" + + prompt_input = create_prompt_input(records, anchor, reflection_count) + fail_safe = _get_fail_safe() + + output, prompt, prompt_input, 
def get_random_str(length):
    """Return a random alphanumeric string of *length* characters.

    Characters are drawn independently (with replacement) from ASCII
    letters and digits.

    Parameters:
      length (int): desired length of the result.
    Returns:
      str: random string of exactly *length* characters.

    Example:
      >>> len(get_random_str(8))
      8
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))
def normalize_dict_floats(d, target_min, target_max):
    """Linearly rescale the values of *d* into [target_min, target_max].

    The relative proportions between the original values are preserved.
    When every value is identical (zero spread), all entries map to the
    midpoint of the target range. The dictionary is modified in place and
    also returned.

    Parameters:
      d: dict of numeric values; may be None or empty.
      target_min: lower bound of the target range.
      target_max: upper bound of the target range.
    Returns:
      The same dict with rescaled values, or {} when *d* is None or empty.
      On unexpected errors the dict is returned unchanged (best effort).
    """
    if d is None:
        print("警告: normalize_dict_floats接收到None字典")
        return {}
    if not d:
        print("警告: normalize_dict_floats接收到空字典")
        return {}

    try:
        min_val = min(d.values())
        max_val = max(d.values())
        spread = max_val - min_val

        if spread == 0:
            # All values equal: map everything to the midpoint of the target
            # range. (Bug fix: was (target_max - target_min) / 2, which is
            # the midpoint only when target_min == 0.)
            midpoint = (target_max + target_min) / 2
            for key in d:
                d[key] = midpoint
        else:
            scale = (target_max - target_min) / spread
            for key, val in d.items():
                d[key] = (val - min_val) * scale + target_min
        return d
    except Exception as e:
        print(f"normalize_dict_floats处理字典时出错: {str(e)}")
        # Best effort: hand back the original dict rather than failing.
        return d
def extract_recency(seq_nodes):
    """Score each node by how recently it was retrieved.

    Recency decays exponentially (factor 0.99) with the distance between a
    node's last_retrieved time step and the most recent time step among
    *seq_nodes*.

    Parameters:
      seq_nodes: list of ConceptNode-like objects in chronological order;
        entries may be None or missing attributes and are handled defensively.
    Returns:
      dict mapping node.node_id -> float recency score in (0, 1];
      {} for None/empty input.
    """
    if seq_nodes is None:
        print("警告: extract_recency接收到None节点列表")
        return {}
    if not seq_nodes:
        print("警告: extract_recency接收到空节点列表")
        return {}

    def _timestep(node):
        # Coerce last_retrieved to an int; malformed or missing values fall
        # back to 0 so a single bad node cannot break scoring.
        if not hasattr(node, 'last_retrieved'):
            return 0
        raw = node.last_retrieved
        if isinstance(raw, str):
            try:
                return int(raw)
            except ValueError:
                return 0
        return raw

    try:
        # Keep (node, timestep) pairs aligned. Bug fix: the previous
        # implementation built a parallel timestamp list that skipped None
        # nodes without a placeholder, while the scoring loop indexed it by
        # the full-enumeration index — any None entry shifted the timestamps
        # of every later node.
        scored = [(node, _timestep(node)) for node in seq_nodes
                  if node is not None and hasattr(node, 'node_id')]
        if not scored:
            return {}

        max_timestep = max(ts for _, ts in scored)
        recency_decay = 0.99
        recency_out = {}
        for node, ts in scored:
            recency_out[node.node_id] = recency_decay ** (max_timestep - ts)
        return recency_out
    except Exception as e:
        print(f"extract_recency处理节点列表时出错: {str(e)}")
        # Fallback: give every identifiable node full recency.
        return {node.node_id: 1.0 for node in seq_nodes
                if node is not None and hasattr(node, 'node_id')}
focal_pt): + """ + Gets the current Persona object, a list of seq_nodes that are in a + chronological order, and the focal_pt string and outputs a dictionary + that has the relevance score calculated. + + Parameters: + seq_nodes: A list of Node object in a chronological order. + focal_pt: A string describing the current thought of revent of focus. + Returns: + relevance_out: A dictionary whose keys are the node.node_id and whose + values are the float that represents the relevance score. + """ + # 确保embeddings不为None + if embeddings is None: + print("警告: embeddings为None,使用空字典代替") + embeddings = {} + + try: + focal_embedding = get_text_embedding(focal_pt) + except Exception as e: + print(f"获取焦点嵌入向量时出错: {str(e)}") + # 如果无法获取嵌入向量,返回默认值 + return {node.node_id: 0.5 for node in seq_nodes} + + relevance_out = dict() + for count, node in enumerate(seq_nodes): + try: + # 检查节点内容是否在embeddings中 + if node.content in embeddings: + node_embedding = embeddings[node.content] + # 计算余弦相似度 + relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding) + else: + # 如果没有对应的嵌入向量,使用默认值 + relevance_out[node.node_id] = 0.5 + except Exception as e: + print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}") + # 如果计算过程中出错,使用默认值 + relevance_out[node.node_id] = 0.5 + + return relevance_out + + +# ############################################################################## +# ### CONCEPT NODE ### +# ############################################################################## + +class ConceptNode: + def __init__(self, node_dict): + # Loading the content of a memory node in the memory stream. 
+ self.node_id = node_dict["node_id"] + self.node_type = node_dict["node_type"] + self.content = node_dict["content"] + self.importance = node_dict["importance"] + # 确保created是整数类型 + self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0 + # 确保last_retrieved是整数类型 + self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0 + self.pointer_id = node_dict["pointer_id"] + + + def package(self): + """ + Packaging the ConceptNode + + Parameters: + None + Returns: + packaged dictionary + """ + curr_package = {} + curr_package["node_id"] = self.node_id + curr_package["node_type"] = self.node_type + curr_package["content"] = self.content + curr_package["importance"] = self.importance + curr_package["created"] = self.created + curr_package["last_retrieved"] = self.last_retrieved + curr_package["pointer_id"] = self.pointer_id + + return curr_package + + +# ############################################################################## +# ### MEMORY STREAM ### +# ############################################################################## + +class MemoryStream: + def __init__(self, nodes, embeddings): + # Loading the memory stream for the agent. + self.seq_nodes = [] + self.id_to_node = dict() + for node in nodes: + new_node = ConceptNode(node) + self.seq_nodes += [new_node] + self.id_to_node[new_node.node_id] = new_node + + self.embeddings = embeddings + + + def count_observations(self): + """ + Counting the number of observations (basically, the number of all nodes in + memory stream except for the reflections) + + Parameters: + None + Returns: + Count + """ + count = 0 + for i in self.seq_nodes: + if i.node_type == "observation": + count += 1 + return count + + + def retrieve(self, focal_points, time_step, n_count=120, curr_filter="all", + hp=[0, 1, 0.5], stateless=False, verbose=False): + """ + Retrieve elements from the memory stream. + + Parameters: + focal_points: This is the query sentence. 
def retrieve(self, focal_points, time_step, n_count=120, curr_filter="all",
             hp=(0, 1, 0.5), stateless=False, verbose=False):
    """Retrieve the memory nodes most relevant to each focal point.

    Each focal point is scored against every (filtered) node as a weighted
    sum of normalized recency, relevance, and importance; the top *n_count*
    nodes are returned per focal point, ordered chronologically by creation
    step.

    Parameters:
      focal_points: list of query strings.
      time_step: current time step; recorded on returned nodes unless
        *stateless* is True.
      n_count: maximum number of nodes to return per focal point.
      curr_filter: node type to consider: 'all', 'reflection', 'observation'.
      hp: (recency_w, relevance_w, importance_w) score weights. A tuple
        default is used so the shared default cannot be mutated (was a
        mutable list default).
      stateless: when True, do not update nodes' last_retrieved.
      verbose: print per-node score breakdowns.
    Returns:
      dict mapping each focal point string to its list of retrieved nodes.
    """
    if len(self.seq_nodes) == 0:
        return dict()

    # Filter to the requested node type.
    if curr_filter == "all":
        curr_nodes = self.seq_nodes
    else:
        curr_nodes = [n for n in self.seq_nodes
                      if n.node_type == curr_filter]

    if self.embeddings is None:
        print("警告: 在retrieve方法中,embeddings为None,初始化为空字典")
        self.embeddings = {}

    recency_w, relevance_w, importance_w = hp
    retrieved = dict()
    for focal_pt in focal_points:
        # Component scores, each normalized to [0, 1].
        recency_out = normalize_dict_floats(
            extract_recency(curr_nodes), 0, 1)
        importance_out = normalize_dict_floats(
            extract_importance(curr_nodes), 0, 1)
        relevance_out = normalize_dict_floats(
            extract_relevance(curr_nodes, self.embeddings, focal_pt), 0, 1)

        # Weighted combination of the three component scores.
        master_out = dict()
        for key in recency_out.keys():
            master_out[key] = (recency_w * recency_out[key]
                               + relevance_w * relevance_out[key]
                               + importance_w * importance_out[key])

        if verbose:
            master_out = top_highest_x_values(master_out,
                                              len(master_out.keys()))
            for key, val in master_out.items():
                print(self.id_to_node[key].content, val)
                print(recency_w * recency_out[key] * 1,
                      relevance_w * relevance_out[key] * 1,
                      importance_w * importance_out[key] * 1)

        # Keep the n_count best-scoring nodes, then order them for the
        # caller chronologically (ascending creation step).
        master_out = top_highest_x_values(master_out, n_count)
        master_nodes = [self.id_to_node[key] for key in master_out.keys()]
        master_nodes.sort(key=lambda node: node.created)

        # Record the access so future recency scoring sees it, unless we
        # are stateless. (Bug fix: this previously wrote to an unused
        # attribute 'retrieved_time_step'; extract_recency reads
        # 'last_retrieved'.)
        if not stateless:
            for n in master_nodes:
                n.last_retrieved = time_step

        retrieved[focal_pt] = master_nodes

    return retrieved
+ """ + node_dict = dict() + node_dict["node_id"] = len(self.seq_nodes) + node_dict["node_type"] = node_type + node_dict["content"] = content + node_dict["importance"] = importance + node_dict["created"] = time_step + node_dict["last_retrieved"] = time_step + node_dict["pointer_id"] = pointer_id + new_node = ConceptNode(node_dict) + + self.seq_nodes += [new_node] + self.id_to_node[new_node.node_id] = new_node + + # 确保embeddings不为None + if self.embeddings is None: + self.embeddings = {} + + try: + self.embeddings[content] = get_text_embedding(content) + except Exception as e: + print(f"获取文本嵌入时出错: {str(e)}") + # 如果获取嵌入失败,使用空列表代替 + self.embeddings[content] = [] + + + def remember(self, content, time_step=0): + score = generate_importance_score([content])[0] + self._add_node(time_step, "observation", content, score, None) + + + def reflect(self, anchor, reflection_count=5, + retrieval_count=120, time_step=0): + records = self.retrieve([anchor], time_step, retrieval_count)[anchor] + record_ids = [i.node_id for i in records] + reflections = generate_reflection(records, anchor, reflection_count) + scores = generate_importance_score(reflections) + + for count, reflection in enumerate(reflections): + self._add_node(time_step, "reflection", reflections[count], + scores[count], record_ids) diff --git a/genagents/templates/decision_interview.html b/genagents/templates/decision_interview.html new file mode 100644 index 0000000..9982a16 --- /dev/null +++ b/genagents/templates/decision_interview.html @@ -0,0 +1,627 @@ + + + + + + 价值观及经历克隆 + + + + + +
+
+
+
+

价值观及经历克隆

+

本采访将从多个维度分析你的决策风格和价值观

+

回答越详细,创建的代理就越接近你的真实决策模式

+
+ + +
+

克隆要求

+

{{ instruction }}

+
+ +
+ +
+
+

核心价值观维度

+

这个维度分析你在核心价值观方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

风险态度维度

+

这个维度分析你在风险态度方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

时间偏好维度

+

这个维度分析你在时间偏好方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

自主性与集体性维度

+

这个维度分析你在自主性与集体性方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

理性与情感维度

+

这个维度分析你在理性与情感方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

资源分配维度

+

这个维度分析你在资源分配方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

道德伦理维度

+

这个维度分析你在道德伦理方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

决策过程维度

+

这个维度分析你在决策过程方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+

成长与变化维度

+

这个维度分析你在成长与变化方面的特点和偏好

+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+
+
+
+
+
+
+
+ 0% 完成 +
+
+ +
+
+
+
+
+
+
+
+
+
+ + + + + + + + + + + + \ No newline at end of file diff --git a/gui/flask_server.py b/gui/flask_server.py index 2f3ddca..50e62d2 100644 --- a/gui/flask_server.py +++ b/gui/flask_server.py @@ -27,6 +27,11 @@ from flask_httpauth import HTTPBasicAuth from core import qa_service from core import stream_manager +# 全局变量,用于跟踪当前的genagents服务器 +genagents_server = None +genagents_thread = None +monitor_thread = None + __app = Flask(__name__) # 禁用 Flask 默认日志 __app.logger.disabled = True @@ -327,7 +332,7 @@ def api_send_v1_chat_completions(): util.printInfo(1, username, '[文字沟通接口]{}'.format(interact.data["msg"]), time.time()) text = fay_booter.feiFei.on_interact(interact) - if config_util.key_chat_module == 'gpt_stream': + if config_util.key_chat_module.endswith('_stream'): return gpt_stream_response(member_db.new_instance().find_user(username)) elif model == 'fay-streaming': return stream_response(text) @@ -346,7 +351,7 @@ def api_get_Member_list(): except Exception as e: return jsonify({'list': [], 'message': f'获取成员列表时出错: {e}'}), 500 -@__app.route('/api/get_run_status', methods=['post']) +@__app.route('/api/get-run-status', methods=['post']) def api_get_run_status(): # 获取运行状态 try: @@ -355,7 +360,7 @@ def api_get_run_status(): except Exception as e: return jsonify({'status': False, 'message': f'获取运行状态时出错: {e}'}), 500 -@__app.route('/api/adopt_msg', methods=['POST']) +@__app.route('/api/adopt-msg', methods=['POST']) def adopt_msg(): # 采纳消息 data = request.get_json() @@ -531,7 +536,7 @@ def serve_gif(filename): return jsonify({'error': '文件未找到'}), 404 #打招呼 -@__app.route('/to_greet', methods=['POST']) +@__app.route('/to-greet', methods=['POST']) def to_greet(): data = request.get_json() username = data.get('username', 'User') @@ -541,7 +546,7 @@ def to_greet(): return jsonify({'status': 'success', 'data': text, 'msg': '已进行打招呼'}), 200 #唤醒:在普通唤醒模式,进行大屏交互才有意义 -@__app.route('/to_wake', methods=['POST']) +@__app.route('/to-wake', methods=['POST']) def to_wake(): data = 
request.get_json() username = data.get('username', 'User') @@ -550,7 +555,7 @@ def to_wake(): return jsonify({'status': 'success', 'msg': '已唤醒'}), 200 #打断 -@__app.route('/to_stop_talking', methods=['POST']) +@__app.route('/to-stop-talking', methods=['POST']) def to_stop_talking(): try: data = request.get_json() @@ -572,7 +577,7 @@ def to_stop_talking(): #消息透传接口 -@__app.route('/transparent_pass', methods=['post']) +@__app.route('/transparent-pass', methods=['post']) def transparent_pass(): try: data = request.form.get('data') @@ -592,6 +597,152 @@ def transparent_pass(): except Exception as e: return jsonify({'code': 500, 'message': f'出错: {e}'}), 500 +# 清除记忆API +@__app.route('/api/clear-memory', methods=['POST']) +def api_clear_memory(): + try: + # 获取memory目录路径 + memory_dir = os.path.join(os.getcwd(), "memory") + + # 检查目录是否存在 + if not os.path.exists(memory_dir): + return jsonify({'success': False, 'message': '记忆目录不存在'}), 400 + + # 清空memory目录下的所有文件(保留目录结构) + for root, dirs, files in os.walk(memory_dir): + for file in files: + file_path = os.path.join(root, file) + try: + if os.path.isfile(file_path): + os.remove(file_path) + util.log(1, f"已删除文件: {file_path}") + except Exception as e: + util.log(1, f"删除文件时出错: {file_path}, 错误: {str(e)}") + + # 删除memory_stream目录(如果存在) + memory_stream_dir = os.path.join(memory_dir, "memory_stream") + if os.path.exists(memory_stream_dir): + import shutil + try: + shutil.rmtree(memory_stream_dir) + util.log(1, f"已删除目录: {memory_stream_dir}") + except Exception as e: + util.log(1, f"删除目录时出错: {memory_stream_dir}, 错误: {str(e)}") + + # 创建一个标记文件,表示记忆已被清除,防止退出时重新保存 + with open(os.path.join(memory_dir, ".memory_cleared"), "w") as f: + f.write("Memory has been cleared. 
Do not save on exit.") + + # 修改fay_booter.py中的保存记忆逻辑 + try: + # 导入并修改nlp_cognitive_stream模块中的保存函数 + from llm.nlp_cognitive_stream import set_memory_cleared_flag + set_memory_cleared_flag(True) + except Exception as e: + util.log(1, f"设置记忆清除标记时出错: {str(e)}") + + util.log(1, "记忆已清除,需要重启应用才能生效") + return jsonify({'success': True, 'message': '记忆已清除,请重启应用使更改生效'}), 200 + except Exception as e: + util.log(1, f"清除记忆时出错: {str(e)}") + return jsonify({'success': False, 'message': f'清除记忆时出错: {str(e)}'}), 500 + +# 启动genagents_flask.py的API +@__app.route('/api/start-genagents', methods=['POST']) +def api_start_genagents(): + try: + # 只有在数字人启动后才能克隆人格 + if not fay_booter.is_running(): + return jsonify({'success': False, 'message': 'Fay未启动,无法启动决策分析'}), 400 + + # 获取克隆要求 + data = request.get_json() + if not data or 'instruction' not in data: + return jsonify({'success': False, 'message': '缺少克隆要求参数'}), 400 + + instruction = data['instruction'] + + # 保存指令到临时文件,供genagents_flask.py读取 + instruction_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'genagents', 'instruction.json') + with open(instruction_file, 'w', encoding='utf-8') as f: + json.dump({'instruction': instruction}, f, ensure_ascii=False) + + # 导入genagents_flask模块 + import sys + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) + from genagents.genagents_flask import start_genagents_server, is_shutdown_requested + from werkzeug.serving import make_server + + # 关闭之前的genagents服务器(如果存在) + global genagents_server, genagents_thread, monitor_thread + if genagents_server is not None: + try: + # 主动关闭之前的服务器 + util.log(1, "关闭之前的决策分析服务...") + genagents_server.shutdown() + # 等待线程结束 + if genagents_thread and genagents_thread.is_alive(): + genagents_thread.join(timeout=2) + if monitor_thread and monitor_thread.is_alive(): + monitor_thread.join(timeout=2) + except Exception as e: + util.log(1, f"关闭之前的决策分析服务时出错: {str(e)}") + + # 清除之前的记忆,确保只保留最新的决策分析 + try: + from llm.nlp_cognitive_stream 
import clear_agent_memory + util.log(1, "已清除之前的决策分析记忆") + except Exception as e: + util.log(1, f"清除之前的决策分析记忆时出错: {str(e)}") + + # 启动决策分析服务(不启动单独进程,而是返回Flask应用实例) + genagents_app = start_genagents_server(instruction_text=instruction) + + # 创建服务器 + genagents_server = make_server('0.0.0.0', 5001, genagents_app) + + # 在后台线程中启动Flask服务 + import threading + def run_genagents_app(): + try: + # 使用serve_forever而不是app.run + genagents_server.serve_forever() + except Exception as e: + util.log(1, f"决策分析服务运行出错: {str(e)}") + finally: + util.log(1, f"决策分析服务已关闭") + + # 启动监控线程,检查是否需要关闭服务器 + def monitor_shutdown(): + try: + while not is_shutdown_requested(): + time.sleep(1) + util.log(1, f"检测到关闭请求,正在关闭决策分析服务...") + genagents_server.shutdown() + except Exception as e: + util.log(1, f"监控决策分析服务时出错: {str(e)}") + + # 启动服务器线程 + genagents_thread = threading.Thread(target=run_genagents_app) + genagents_thread.daemon = True + genagents_thread.start() + + # 启动监控线程 + monitor_thread = threading.Thread(target=monitor_shutdown) + monitor_thread.daemon = True + monitor_thread.start() + + util.log(1, f"已启动决策分析页面,指令: {instruction}") + + # 返回决策分析页面的URL + return jsonify({ + 'success': True, + 'message': '已启动决策分析页面', + 'url': 'http://127.0.0.1:5001/' + }), 200 + except Exception as e: + util.log(1, f"启动决策分析页面时出错: {str(e)}") + return jsonify({'success': False, 'message': f'启动决策分析页面时出错: {str(e)}'}), 500 def run(): class NullLogHandler: diff --git a/gui/static/js/index.js b/gui/static/js/index.js index 0ef7412..231209f 100644 --- a/gui/static/js/index.js +++ b/gui/static/js/index.js @@ -95,7 +95,7 @@ class FayInterface { } getRunStatus() { - return this.fetchData(`${this.baseApiUrl}/api/get_run_status`, { + return this.fetchData(`${this.baseApiUrl}/api/get-run-status`, { method: 'POST' }); } @@ -446,18 +446,24 @@ new Vue({ this.fayService.getUserList().then((response) => { if (response && response.list) { if (response.list.length == 0){ - info = []; - info[0] = 1; - info[1] = 'User'; - 
this.userList.push(info) - this.selectUser(info); - }else{ - this.userList = response.list; - if (!this.selectedUser) { - this.selectUser(this.userList[0]); + // 检查是否已经有默认用户 + const defaultUserExists = this.userList.some(user => user[1] === 'User'); + if (!defaultUserExists) { + // 只有在不存在默认用户时才添加 + const info = []; + info[0] = 1; + info[1] = 'User'; + this.userList.push(info); + this.selectUser(info); + console.log('添加默认用户: User'); + } + } else { + this.userList = response.list; + if (!this.selectedUser) { + this.selectUser(this.userList[0]); + } } } - } }); }, startUserListTimer() { @@ -521,12 +527,12 @@ new Vue({ } , adoptText(id) { // 调用采纳接口 -this.fayService.fetchData(`${this.base_url}/api/adopt_msg`, { +this.fayService.fetchData(`${this.base_url}/api/adopt-msg`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ id }) // 发送采纳请求 + body: JSON.stringify({ id }) }) .then((response) => { if (response && response.status === 'success') { diff --git a/gui/static/js/setting.js b/gui/static/js/setting.js index 4b9352e..8b4ff9b 100644 --- a/gui/static/js/setting.js +++ b/gui/static/js/setting.js @@ -74,7 +74,7 @@ class FayInterface { } getRunStatus() { - return this.fetchData(`${this.baseApiUrl}/api/get_run_status`, { + return this.fetchData(`${this.baseApiUrl}/api/get-run-status`, { method: 'POST' }); } @@ -312,7 +312,7 @@ new Vue({ } } } - this.sendSuccessMsg("配置已保存!") + this.sendSuccessMsg("配置已保存!"); }, startLive() { this.liveState = 2 @@ -335,5 +335,93 @@ new Vue({ type: 'success', }); }, + clearMemory() { + this.$confirm('清除记忆操作将删除Fay的所有对话记忆,清除后需要重启应用才能生效,确认继续吗?', '提示', { + confirmButtonText: '确定', + cancelButtonText: '取消', + type: 'warning' + }).then(() => { + // 发送清除记忆请求 + fetch(`${this.host_url}/api/clear-memory`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' } + }) + .then(response => response.json()) + .then(data => { + if (data.success) { + this.sendSuccessMsg(data.message || "记忆已清除,请重启应用使更改生效"); + 
} else { + this.$notify({ + title: '错误', + message: data.message || '清除记忆失败', + type: 'error' + }); + } + }) + .catch(error => { + this.$notify({ + title: '错误', + message: '清除记忆请求失败', + type: 'error' + }); + }); + }).catch(() => { + // 用户取消操作 + }); + }, + clonePersonality() { + if (this.liveState === 1) { + this.$prompt('请输入克隆要求', '克隆人格', { + confirmButtonText: '确定', + cancelButtonText: '取消', + inputPlaceholder: '请输入克隆要求,例如:你现在是一个活泼开朗的助手...' + }).then(({ value }) => { + if (!value) { + this.$notify({ + title: '提示', + message: '克隆要求不能为空', + type: 'warning' + }); + return; + } + + // 直接启动genagents_flask.py并打开decision_interview.html页面 + fetch(`${this.host_url}/api/start-genagents`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ instruction: value }) + }) + .then(response => response.json()) + .then(data => { + if (data.success) { + // 弹出提示,显示克隆地址,不自动打开 + this.$alert(`决策分析页面已启动,请复制以下链接在新窗口中打开:

${data.url}`, '克隆人格', { + confirmButtonText: '确定', + dangerouslyUseHTMLString: true + }); + } else { + this.$notify({ + title: '错误', + message: data.message || '启动决策分析页面失败', + type: 'error' + }); + } + }) + .catch(error => { + this.$notify({ + title: '错误', + message: '启动决策分析页面请求失败', + type: 'error' + }); + }); + }); + } else { + this.$notify({ + title: '提示', + message: '请先开Fay后再执行此操作', + type: 'warning' + }); + } + }, }, }); diff --git a/gui/templates/setting.html b/gui/templates/setting.html index fb402b5..f676b66 100644 --- a/gui/templates/setting.html +++ b/gui/templates/setting.html @@ -131,20 +131,32 @@ -
- 保存配置 - +
+
+ 保存配置 +
- - 关闭(运行中) - 正在开启... - 正在关闭... - 开启 +
+ + 关闭(运行中) + 正在开启... + 正在关闭... + 开启 +
+ +
diff --git a/llm/agent/agent_service.py b/llm/agent/agent_service.py index b357a7f..45646cc 100644 --- a/llm/agent/agent_service.py +++ b/llm/agent/agent_service.py @@ -15,7 +15,7 @@ agent_running = False # 数据库初始化 def init_db(): - conn = sqlite3.connect('timer.db') + conn = sqlite3.connect('memory/timer.db') cursor = conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS timer ( @@ -33,7 +33,7 @@ def init_db(): # 插入测试数据 def insert_test_data(): - conn = sqlite3.connect('timer.db') + conn = sqlite3.connect('memory/timer.db') cursor = conn.cursor() cursor.execute("INSERT INTO timer (time, repeat_rule, content) VALUES (?, ?, ?)", ('16:20', '1010001', 'Meeting Reminder')) conn.commit() @@ -66,7 +66,7 @@ def execute_task(task_time, id, content, uid): if text is not None and id in scheduled_tasks: del scheduled_tasks[id] # 如果不重复,执行后删除记录 - conn = sqlite3.connect('timer.db') + conn = sqlite3.connect('memory/timer.db') cursor = conn.cursor() cursor.execute("DELETE FROM timer WHERE repeat_rule = '0000000' AND id = ?", (id,)) conn.commit() @@ -76,7 +76,7 @@ def execute_task(task_time, id, content, uid): # 30秒扫描一次数据库,当扫描到今天的不存在于定时任务列表的记录,则添加到定时任务列表。执行完的记录从定时任务列表中清除。 def check_and_execute(): while agent_running: - conn = sqlite3.connect('timer.db') + conn = sqlite3.connect('memory/timer.db') cursor = conn.cursor() cursor.execute("SELECT * FROM timer") rows = cursor.fetchall() @@ -99,7 +99,7 @@ def agent_start(): agent_running = True #初始计划 - if not os.path.exists("./timer.db"): + if not os.path.exists("./memory/timer.db"): init_db() content ="""执行任务-> 你是一个数字人,你的责任是陪伴主人生活、工作: diff --git a/llm/agent/tools/DeleteTimer.py b/llm/agent/tools/DeleteTimer.py index d92c01c..739adb7 100644 --- a/llm/agent/tools/DeleteTimer.py +++ b/llm/agent/tools/DeleteTimer.py @@ -20,7 +20,7 @@ class DeleteTimer(BaseTool): return "输入的 ID 无效,必须是数字。" try: - with sqlite3.connect('timer.db') as conn: + with sqlite3.connect('memory/timer.db') as conn: cursor = conn.cursor() cursor.execute("DELETE FROM 
timer WHERE id = ?", (id,)) conn.commit() diff --git a/llm/agent/tools/MyTimer.py b/llm/agent/tools/MyTimer.py index 64bdd7b..a784e00 100644 --- a/llm/agent/tools/MyTimer.py +++ b/llm/agent/tools/MyTimer.py @@ -38,7 +38,7 @@ class MyTimer(BaseTool, abc.ABC): return "事项内容必须为非空字符串。" # 数据库操作 - conn = sqlite3.connect('timer.db') + conn = sqlite3.connect('memory/timer.db') cursor = conn.cursor() try: cursor.execute("INSERT INTO timer (time, repeat_rule, content, uid) VALUES (?, ?, ?, ?)", (time, repeat_rule, content, self.uid)) diff --git a/llm/agent/tools/QueryTimerDB.py b/llm/agent/tools/QueryTimerDB.py index 3018488..be7286e 100644 --- a/llm/agent/tools/QueryTimerDB.py +++ b/llm/agent/tools/QueryTimerDB.py @@ -20,7 +20,7 @@ class QueryTimerDB(BaseTool, abc.ABC): def _run(self, para) -> str: - conn = sqlite3.connect('timer.db') + conn = sqlite3.connect('memory/timer.db') cursor = conn.cursor() # 执行查询 cursor.execute("SELECT * FROM timer") diff --git a/llm/nlp_cognitive_stream.py b/llm/nlp_cognitive_stream.py new file mode 100644 index 0000000..2c9bdcf --- /dev/null +++ b/llm/nlp_cognitive_stream.py @@ -0,0 +1,634 @@ +import os +import json +import time +import threading +import requests +from utils import util +import utils.config_util as cfg +from genagents.genagents import GenerativeAgent +from genagents.modules.memory_stream import ConceptNode +from core import member_db +from urllib3.exceptions import InsecureRequestWarning +import schedule +from scheduler.thread_manager import MyThread +import datetime +from core import stream_manager + +# 加载配置 +cfg.load_config() + +# 禁用不安全请求警告 +requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning) + +agent = None +memory_dir = None +memory_stream_dir = None +agent_lock = threading.RLock() # 使用可重入锁保护agent对象 +memory_cleared = False # 添加记忆清除标记 + +def get_current_time_step(): + """ + 获取当前时间作为time_step + + 返回: + int: 当前时间的时间戳(秒) + """ + global agent + try: + if agent and agent.memory_stream and 
agent.memory_stream.seq_nodes: + # 如果有记忆节点,则使用最后一个节点的created属性加1 + return int(agent.memory_stream.seq_nodes[-1].created) + 1 + else: + # 如果没有记忆节点或agent未初始化,则使用0 + return 0 + except Exception as e: + util.log(1, f"获取time_step时出错: {str(e)},使用0代替") + return 0 + +# 定时保存记忆的线程 +def memory_scheduler_thread(): + """ + 定时任务线程,运行schedule调度器 + """ + while True: + schedule.run_pending() + time.sleep(60) # 每分钟检查一次是否有定时任务需要执行 + +# 初始化定时保存记忆的任务 +def init_memory_scheduler(): + """ + 初始化定时保存记忆的任务 + """ + global agent + + # 确保agent已经创建 + if agent is None: + util.log(1, '创建代理实例...') + agent = create_agent() + + # 设置每天0点保存记忆 + schedule.every().day.at("00:00").do(save_agent_memory) + + # 设置每天晚上11点执行反思 + schedule.every().day.at("23:30").do(perform_daily_reflection) + + # 启动定时任务线程 + scheduler_thread = MyThread(target=memory_scheduler_thread) + scheduler_thread.start() + + util.log(1, '定时任务已启动:每天0点保存记忆,每天23点执行反思') + +def check_memory_files(): + """ + 检查memory目录及其必要文件是否存在 + + 返回: + memory_dir: memory目录路径 + is_complete: 是否已经存在完整的memory目录结构 + """ + global memory_dir + global memory_stream_dir + + # 获取memory目录路径 + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + memory_dir = os.path.join(base_dir, "memory") + + # 检查memory目录是否存在,不存在则创建 + if not os.path.exists(memory_dir): + os.makedirs(memory_dir) + util.log(1, f"创建memory目录: {memory_dir}") + + # 删除.memory_cleared标记文件(如果存在) + memory_cleared_flag_file = os.path.join(memory_dir, ".memory_cleared") + if os.path.exists(memory_cleared_flag_file): + try: + os.remove(memory_cleared_flag_file) + util.log(1, f"删除记忆清除标记文件: {memory_cleared_flag_file}") + # 重置记忆清除标记 + global memory_cleared + memory_cleared = False + except Exception as e: + util.log(1, f"删除记忆清除标记文件时出错: {str(e)}") + + # 检查meta.json是否存在 + meta_file = os.path.join(memory_dir, "meta.json") + is_complete = os.path.exists(meta_file) + + # 检查memory_stream目录是否存在,不存在则创建 + memory_stream_dir = os.path.join(memory_dir, "memory_stream") + if not 
os.path.exists(memory_stream_dir): + os.makedirs(memory_stream_dir) + util.log(1, f"创建memory_stream目录: {memory_stream_dir}") + + # 检查必要的文件是否存在 + scratch_path = os.path.join(memory_dir, "scratch.json") + embeddings_path = os.path.join(memory_stream_dir, "embeddings.json") + nodes_path = os.path.join(memory_stream_dir, "nodes.json") + + # 检查文件是否存在且不为空 + is_complete = (os.path.exists(scratch_path) and os.path.getsize(scratch_path) > 2 and + os.path.exists(embeddings_path) and os.path.getsize(embeddings_path) > 2 and + os.path.exists(nodes_path) and os.path.getsize(nodes_path) > 2) + + # 如果文件不存在,创建空的JSON文件 + if not os.path.exists(scratch_path): + with open(scratch_path, 'w', encoding='utf-8') as f: + f.write('{}') + + if not os.path.exists(embeddings_path): + with open(embeddings_path, 'w', encoding='utf-8') as f: + f.write('{}') + + if not os.path.exists(nodes_path): + with open(nodes_path, 'w', encoding='utf-8') as f: + f.write('[]') + + return memory_dir, is_complete + +def create_agent(): + """ + 创建一个GenerativeAgent实例 + + 返回: + agent: GenerativeAgent对象 + """ + global agent + + # 创建代理 + with agent_lock: + if agent is None: + memory_dir, is_exist = check_memory_files() + agent = GenerativeAgent(memory_dir) + + # 检查是否有scratch属性,如果没有则添加 + if not hasattr(agent, 'scratch'): + agent.scratch = {} + + # 如果memory目录不存在或为空,则初始化代理 + if not is_exist: + # 初始化代理的scratch数据 + scratch_data = { + "first_name": cfg.config["attribute"]["name"], + "last_name": "", + "age": cfg.config["attribute"]["age"], + "gender": cfg.config["attribute"]["gender"], + "traits": cfg.config["attribute"]["additional"], + "status": "active", + "location": "home", + "occupation": cfg.config["attribute"]["job"], + "interests": [], + "current_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + agent.scratch = scratch_data + else: + # 加载之前保存的记忆 + load_agent_memory(agent) + + return agent + +def load_agent_memory(agent): + """ + 从文件加载代理的记忆 + + 参数: + agent: GenerativeAgent对象 + """ + try: + # 
获取memory目录路径 + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + memory_dir = os.path.join(base_dir, "memory") + memory_stream_dir = os.path.join(memory_dir, "memory_stream") + + # 加载scratch.json + scratch_path = os.path.join(memory_dir, "scratch.json") + if os.path.exists(scratch_path) and os.path.getsize(scratch_path) > 2: # 文件存在且不为空 + with open(scratch_path, 'r', encoding='utf-8') as f: + scratch_data = json.load(f) + agent.scratch = scratch_data + + # 加载nodes.json + nodes_path = os.path.join(memory_stream_dir, "nodes.json") + if os.path.exists(nodes_path) and os.path.getsize(nodes_path) > 2: # 文件存在且不为空 + with open(nodes_path, 'r', encoding='utf-8') as f: + nodes_data = json.load(f) + + # 清空当前的seq_nodes + agent.memory_stream.seq_nodes = [] + agent.memory_stream.id_to_node = {} + + # 重新创建节点 + for node_dict in nodes_data: + new_node = ConceptNode(node_dict) + agent.memory_stream.seq_nodes.append(new_node) + agent.memory_stream.id_to_node[new_node.node_id] = new_node + + # 加载embeddings.json + embeddings_path = os.path.join(memory_stream_dir, "embeddings.json") + if os.path.exists(embeddings_path) and os.path.getsize(embeddings_path) > 2: # 文件存在且不为空 + with open(embeddings_path, 'r', encoding='utf-8') as f: + embeddings_data = json.load(f) + agent.memory_stream.embeddings = embeddings_data + + util.log(1, f"已加载代理记忆") + except Exception as e: + util.log(1, f"加载代理记忆失败: {str(e)}") + +# 记忆对话内容的线程函数 +def remember_conversation_thread(username, content, response_text): + """ + 在单独线程中记录对话内容到代理记忆 + + 参数: + username: 用户名 + content: 用户问题内容 + response_text: 代理回答内容 + """ + global agent + try: + with agent_lock: + # 获取当前时间作为time_step + time_step = get_current_time_step() + + # 记录对话内容 + memory_content = f"在对话中,我回答了用户{username}的问题:{content}\n,我的回答是:{response_text}" + agent.remember(memory_content, time_step) + except Exception as e: + util.log(1, f"记忆对话内容出错: {str(e)}") + +def question(content, uid=0, observation=""): + """ + 处理用户问题并返回回答 + + 参数: + content: 
用户问题内容 + uid: 用户ID,默认为0 + observation: 额外的观察信息,默认为空 + + 返回: + response_text: 回答内容 + """ + global agent + # 获取用户名 + username = member_db.new_instance().find_username_by_uid(uid) if uid != 0 else "User" + + # 创建代理 + agent = create_agent() + + # 获取对话历史 + history_messages = [] + history_messages.append([username, content]) + + # 构建提示信息 + str_dialogue = "" + for row in history_messages: + str_dialogue += f"[{row[0]}]: {row[1]}\n" + str_dialogue += f"[Fay]: [Fill in]\n" + + # 构建代理描述 + agent_desc = { + "first_name": agent.scratch.get("first_name", "Fay"), + "last_name": agent.scratch.get("last_name", ""), + "age": agent.scratch.get("age", "25"), + "gender": agent.scratch.get("gender", "女"), + "traits": agent.scratch.get("traits", "友好、乐于助人"), + "status": agent.scratch.get("status", "active"), + "location": agent.scratch.get("location", "home"), + "occupation": agent.scratch.get("occupation", "助手"), + } + + # 获取相关记忆作为上下文 + context = "" + if agent.memory_stream and len(agent.memory_stream.seq_nodes) > 0: + # 获取当前时间步 + current_time_step = get_current_time_step() + + # 使用retrieve方法获取相关记忆 + try: + related_memories = agent.memory_stream.retrieve( + [content], # 查询句子列表 + current_time_step, # 当前时间步 + n_count=5, # 获取5条相关记忆 + curr_filter="all", # 获取所有类型的记忆 + hp=[0, 1, 0.5] # 权重:[时间近度权重recency_w, 相关性权重relevance_w, 重要性权重importance_w] + ) + + if related_memories and content in related_memories: + memory_nodes = related_memories[content] + if memory_nodes: + context = "以下是相关的记忆:\n" + for node in memory_nodes: + context += f"- {node.content}\n" + except Exception as e: + util.log(1, f"获取相关记忆时出错: {str(e)}") + + # 使用流式请求获取回答 + session = requests.Session() + session.verify = False + httpproxy = cfg.proxy_config + if httpproxy: + session.proxies = { + "http": f"http://{httpproxy}", + "https": f"https://{httpproxy}" + } + + # 构建消息 + prompt = f"""你是我的数字人,你名字是:{agent_desc['first_name']},你性别为{agent_desc['gender']}, + 你年龄为{agent_desc['age']},你职业为{agent_desc['occupation']}, + 
{agent_desc['traits']}。 + 你有以下记忆和上下文信息:{context} + 回答之前请一步一步想清楚。对于大部分问题,请直接回答并提供有用和准确的信息。 + 所有回复请尽量控制在20字内。 + """ + + messages = [{"role": "system", "content": prompt}] + messages.append({"role": "user", "content": content}) + + # 构建请求数据 + data = { + "model": cfg.gpt_model_engine, + "messages": messages, + "temperature": 0.3, + "max_tokens": 4096, + "user": f"user_{uid}" + } + + # 开启流式传输 + data["stream"] = True + + url = cfg.gpt_base_url + "/chat/completions" + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {cfg.key_gpt_api_key}' + } + + try: + response = session.post(url, json=data, headers=headers, stream=True) + response.raise_for_status() + + full_response_text = "" + accumulated_text = "" + punctuation_marks = ["。", "!", "?", ".", "!", "?", "\n"] + is_first_sentence = True + for raw_line in response.iter_lines(decode_unicode=False): + line = raw_line.decode('utf-8', errors='ignore') + if not line or line.strip() == "": + continue + + if line.startswith("data: "): + chunk = line[len("data: "):].strip() + try: + json_data = json.loads(chunk) + finish_reason = json_data["choices"][0].get("finish_reason") + if finish_reason is not None: + if finish_reason == "stop": + # 确保最后一段文本也被发送 + if accumulated_text: + if is_first_sentence: + accumulated_text += "_" + is_first_sentence = False + stream_manager.new_instance().write_sentence(uid, accumulated_text) + # 发送结束标记 + stream_manager.new_instance().write_sentence(uid, "_") + break + + flush_text = json_data["choices"][0]["delta"].get("content", "") + if not flush_text: + continue + + accumulated_text += flush_text + + for mark in punctuation_marks: + if mark in accumulated_text: + last_punct_pos = max(accumulated_text.rfind(p) for p in punctuation_marks if p in accumulated_text) + if last_punct_pos != -1: + to_write = accumulated_text[:last_punct_pos + 1] + accumulated_text = accumulated_text[last_punct_pos + 1:] + if is_first_sentence: + to_write += "_" + is_first_sentence = False + 
stream_manager.new_instance().write_sentence(uid, to_write) + break + + full_response_text += flush_text + except json.JSONDecodeError: + continue + + # 在单独线程中记忆对话内容 + from scheduler.thread_manager import MyThread + MyThread(target=remember_conversation_thread, args=(username, content, full_response_text)).start() + + return full_response_text + + except requests.exceptions.RequestException as e: + util.log(1, f"请求失败: {e}") + error_message = "抱歉,我现在太忙了,休息一会,请稍后再试。" + stream_manager.new_instance().write_sentence(uid, error_message) + return error_message + +def set_memory_cleared_flag(flag=True): + """ + 设置记忆清除标记 + + 参数: + flag: 是否清除记忆,默认为True + """ + global memory_cleared + memory_cleared = flag + if not flag: + # 删除.memory_cleared标记文件(如果存在) + memory_cleared_flag_file = os.path.join(memory_dir, ".memory_cleared") + if os.path.exists(memory_cleared_flag_file): + try: + os.remove(memory_cleared_flag_file) + util.log(1, f"删除记忆清除标记文件: {memory_cleared_flag_file}") + except Exception as e: + util.log(1, f"删除记忆清除标记文件时出错: {str(e)}") + +def clear_agent_memory(): + """ + 清除已加载的agent记忆,但不删除文件 + + 该方法仅清除内存中已加载的记忆,不影响持久化存储。 + 如果需要同时清除文件存储,请使用genagents_flask.py中的api_clear_memory方法。 + """ + global agent + + try: + with agent_lock: + if agent is None: + util.log(1, "代理未初始化,无需清除记忆") + return + + # 清除记忆流中的节点 + agent = None + + # 设置记忆清除标记,防止在退出时保存空记忆 + set_memory_cleared_flag(True) + + util.log(1, "已成功清除代理在内存中的记忆") + + return True + except Exception as e: + util.log(1, f"清除代理记忆时出错: {str(e)}") + return False + +# 反思 +def perform_daily_reflection(): + # 获取当前时间作为time_step + current_time_step = get_current_time_step() + + # 获取今天的日期,用于确定反思主题 + today = datetime.datetime.now().weekday() + + # 根据星期几选择不同反思主题 + reflection_topics = [ + "我与用户的关系发展,以及我如何更好地理解和服务他们", + "我的知识库如何得到扩展,哪些概念需要进一步理解", + "我的情感响应模式以及它们如何反映我的核心价值观", + "我的沟通方式如何影响互动质量,哪些模式最有效", + "我的行为如何体现我的核心特质,我的自我认知有何变化", + "今天的经历如何与我的过往记忆建立联系,形成什么样的模式", + "本周的整体经历与学习" + ] + + # 选择今天的主题(可以按星期轮换或其他逻辑) + topic = reflection_topics[today % 
len(reflection_topics)] + + # 执行反思 + agent.reflect(topic) + + # 记录反思执行情况 + util.log(1, f"反思主题: {topic}") + +def save_agent_memory(): + """ + 保存代理的记忆到文件 + """ + global agent + + # 检查记忆清除标记,如果已清除则不保存 + global memory_cleared + if memory_cleared: + util.log(1, "检测到记忆已被清除,跳过保存操作") + return + + try: + with agent_lock: + # 检查.memory_cleared标记文件是否存在 + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + memory_dir = os.path.join(base_dir, "memory") + memory_cleared_flag_file = os.path.join(memory_dir, ".memory_cleared") + + if os.path.exists(memory_cleared_flag_file): + util.log(1, "检测到记忆清除标记文件,跳过保存操作") + return + + # 确保agent和memory_stream已初始化 + if agent is None: + util.log(1, "代理未初始化,无法保存记忆") + return + + if agent.memory_stream is None: + util.log(1, "代理记忆流未初始化,无法保存记忆") + return + + # 确保embeddings不为None + if agent.memory_stream.embeddings is None: + util.log(1, "代理embeddings为None,初始化为空字典") + agent.memory_stream.embeddings = {} + + # 确保seq_nodes不为None + if agent.memory_stream.seq_nodes is None: + util.log(1, "代理seq_nodes为None,初始化为空列表") + agent.memory_stream.seq_nodes = [] + + # 确保id_to_node不为None + if agent.memory_stream.id_to_node is None: + util.log(1, "代理id_to_node为None,初始化为空字典") + agent.memory_stream.id_to_node = {} + + # 确保scratch不为None + if agent.scratch is None: + util.log(1, "代理scratch为None,初始化为空字典") + agent.scratch = {} + + # 保存记忆前进行完整性检查 + try: + # 检查seq_nodes中的每个节点是否有效 + valid_nodes = [] + for node in agent.memory_stream.seq_nodes: + if node is None: + util.log(1, "发现无效节点(None),跳过") + continue + + if not hasattr(node, 'node_id') or not hasattr(node, 'content'): + util.log(1, f"发现无效节点(缺少必要属性),跳过") + continue + + valid_nodes.append(node) + + # 更新seq_nodes为有效节点列表 + agent.memory_stream.seq_nodes = valid_nodes + + # 重建id_to_node字典 + agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')} + except Exception as e: + util.log(1, f"检查记忆完整性时出错: {str(e)}") + + # 保存记忆 + try: + agent.save(memory_dir) + 
except Exception as e: + util.log(1, f"调用agent.save()时出错: {str(e)}") + # 尝试手动保存关键数据 + try: + # 创建必要的目录 + memory_stream_dir = os.path.join(memory_dir, "memory_stream") + os.makedirs(memory_stream_dir, exist_ok=True) + + # 保存embeddings + with open(os.path.join(memory_stream_dir, "embeddings.json"), "w", encoding='utf-8') as f: + json.dump(agent.memory_stream.embeddings or {}, f, ensure_ascii=False, indent=2) + + # 保存nodes + with open(os.path.join(memory_stream_dir, "nodes.json"), "w", encoding='utf-8') as f: + nodes_data = [] + for node in agent.memory_stream.seq_nodes: + if node is not None and hasattr(node, 'package'): + try: + nodes_data.append(node.package()) + except Exception as node_e: + util.log(1, f"打包节点时出错: {str(node_e)}") + json.dump(nodes_data, f, ensure_ascii=False, indent=2) + + # 保存scratch + with open(os.path.join(memory_dir, "scratch.json"), "w", encoding='utf-8') as f: + json.dump(agent.scratch or {}, f, ensure_ascii=False, indent=2) + + # 保存meta + with open(os.path.join(memory_dir, "meta.json"), "w", encoding='utf-8') as f: + meta_data = {"id": str(agent.id)} if hasattr(agent, 'id') else {} + json.dump(meta_data, f, ensure_ascii=False, indent=2) + + util.log(1, "通过备用方法成功保存记忆") + except Exception as backup_e: + util.log(1, f"备用保存方法也失败: {str(backup_e)}") + + # 更新scratch中的时间 + try: + agent.scratch["current_time"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + except Exception as e: + util.log(1, f"更新时间时出错: {str(e)}") + + util.log(1, f"已保存代理记忆") + except Exception as e: + util.log(1, f"保存代理记忆失败: {str(e)}") + +if __name__ == "__main__": + init_memory_scheduler() + for _ in range(3): + query = "Fay是什么" + response = question(query) + print(f"Q: {query}") + print(f"A: {response}") + time.sleep(1) diff --git a/main.py b/main.py index 105026f..641ab22 100644 --- a/main.py +++ b/main.py @@ -40,6 +40,10 @@ def __clear_logs(): if file_name.endswith('.log'): os.remove('./logs/' + file_name) +def __create_memory(): + if not os.path.exists("./memory"): + 
os.mkdir("./memory") + def kill_process_by_port(port): for conn in psutil.net_connections(kind='inet'): if conn.laddr.port == port and conn.pid: @@ -108,6 +112,7 @@ def console_listener(): if __name__ == '__main__': __clear_samples() + __create_memory() __clear_logs() #init_db diff --git a/requirements.txt b/requirements.txt index 2540366..ad80429 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,4 +27,5 @@ psutil langchain langchain_openai langgraph -bs4 \ No newline at end of file +bs4 +schedule \ No newline at end of file diff --git a/samples/sample-1735661618567.mp3 b/samples/sample-1735661618567.mp3 deleted file mode 100644 index 7dd01fc..0000000 Binary files a/samples/sample-1735661618567.mp3 and /dev/null differ diff --git a/shell/brew.sh b/shell/brew.sh deleted file mode 100644 index c529c5e..0000000 --- a/shell/brew.sh +++ /dev/null @@ -1,816 +0,0 @@ -#HomeBrew自动安装脚本 -#qq 467665317 -#brew brew brew brew - -#获取硬件信息 判断inter还是苹果M -UNAME_MACHINE="$(uname -m)" -#在X86电脑上测试arm电脑 -# UNAME_MACHINE="arm64" - -# 判断是Linux还是Mac os -OS="$(uname)" -if [[ "$OS" == "Linux" ]]; then - HOMEBREW_ON_LINUX=1 -elif [[ "$OS" != "Darwin" ]]; then - echo "Homebrew 只运行在 Mac OS 或 Linux." 
-fi - -# 字符串染色程序 -if [[ -t 1 ]]; then - tty_escape() { printf "\033[%sm" "$1"; } -else - tty_escape() { :; } -fi -tty_universal() { tty_escape "0;$1"; } #正常显示 -tty_mkbold() { tty_escape "1;$1"; } #设置高亮 -tty_underline="$(tty_escape "4;39")" #下划线 -tty_blue="$(tty_universal 34)" #蓝色 -tty_red="$(tty_universal 31)" #红色 -tty_green="$(tty_universal 32)" #绿色 -tty_yellow="$(tty_universal 33)" #黄色 -tty_bold="$(tty_universal 39)" #加黑 -tty_cyan="$(tty_universal 36)" #青色 -tty_reset="$(tty_escape 0)" #去除颜色 - -#用户输入极速安装speed,git克隆只取最近新版本 -#但是update会出错,提示需要下载全部数据 -GIT_SPEED="" - -if [[ $0 == "speed" ]]; then - GIT_SPEED="--depth=1" -else - for dir in $@; do - echo $dir - if [[ $dir == "speed" ]]; then - GIT_SPEED="--depth=1" - fi - done -fi - -if [[ $GIT_SPEED != "" ]]; then -echo "${tty_red} - 检测到参数speed,只拉取最新数据,可以正常install使用! - 腾讯和阿里不支持speed拉取,需要腾讯阿里需要完全模式。 - 但是以后brew update的时候会报错,运行报错提示的两句命令即可修复 - ${tty_reset}" -fi - -#获取前面两个.的数据 -major_minor() { - echo "${1%%.*}.$(x="${1#*.}"; echo "${x%%.*}")" -} - -#设置一些平台地址 -if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - #Mac - if [[ "$UNAME_MACHINE" == "arm64" ]]; then - #M1 - HOMEBREW_PREFIX="/opt/homebrew" - HOMEBREW_REPOSITORY="${HOMEBREW_PREFIX}" - else - #Inter - HOMEBREW_PREFIX="/usr/local" - HOMEBREW_REPOSITORY="${HOMEBREW_PREFIX}/Homebrew" - fi - - HOMEBREW_CACHE="${HOME}/Library/Caches/Homebrew" - HOMEBREW_LOGS="${HOME}/Library/Logs/Homebrew" - - #国内没有homebrew-services,手动在gitee创建了一个,有少数人用到。 - USER_SERVICES_GIT=https://gitee.com/cunkai/homebrew-services.git - - STAT="stat -f" - CHOWN="/usr/sbin/chown" - CHGRP="/usr/bin/chgrp" - GROUP="admin" - TOUCH="/usr/bin/touch" - - #获取Mac系统版本 - macos_version="$(major_minor "$(/usr/bin/sw_vers -productVersion)")" -else - #Linux - UNAME_MACHINE="$(uname -m)" - - HOMEBREW_PREFIX="/home/linuxbrew/.linuxbrew" - HOMEBREW_REPOSITORY="${HOMEBREW_PREFIX}/Homebrew" - - HOMEBREW_CACHE="${HOME}/.cache/Homebrew" - HOMEBREW_LOGS="${HOME}/.logs/Homebrew" - - STAT="stat --printf" - CHOWN="/bin/chown" - 
CHGRP="/bin/chgrp" - GROUP="$(id -gn)" - TOUCH="/bin/touch" -fi - - - -#获取系统时间 -TIME=$(date "+%Y-%m-%d %H:%M:%S") - -JudgeSuccess() -{ - if [ $? -ne 0 ];then - echo "${tty_red}此步骤失败 '$1'${tty_reset}" - if [[ "$2" == 'out' ]]; then - exit 0 - fi - else - echo "${tty_green}此步骤成功${tty_reset}" - - fi -} -# 判断是否有系统权限 -have_sudo_access() { - if [[ -z "${HAVE_SUDO_ACCESS-}" ]]; then - /usr/bin/sudo -l mkdir &>/dev/null - HAVE_SUDO_ACCESS="$?" - fi - - if [[ "$HAVE_SUDO_ACCESS" -ne 0 ]]; then - echo "${tty_red}开机密码输入错误,获取权限失败!${tty_reset}" - fi - - return "$HAVE_SUDO_ACCESS" -} - - -abort() { - printf "%s\n" "$1" - # exit 1 -} - -shell_join() { - local arg - printf "%s" "$1" - shift - for arg in "$@"; do - printf " " - printf "%s" "${arg// /\ }" - done -} - -execute() { - if ! "$@"; then - abort "$(printf "${tty_red}此命令运行失败: %s${tty_reset}" "$(shell_join "$@")")" - fi -} - - - -ohai() { - printf "${tty_blue}运行代码 ==>${tty_bold} %s${tty_reset}\n" "$(shell_join "$@")" -} - -# 管理员运行 -execute_sudo() -{ - - local -a args=("$@") - if have_sudo_access; then - if [[ -n "${SUDO_ASKPASS-}" ]]; then - args=("-A" "${args[@]}") - fi - ohai "/usr/bin/sudo" "${args[@]}" - execute "/usr/bin/sudo" "${args[@]}" - else - ohai "${args[@]}" - execute "${args[@]}" - fi -} -#添加文件夹权限 -AddPermission() -{ - execute_sudo "/bin/chmod" "-R" "a+rwx" "$1" - execute_sudo "$CHOWN" "$USER" "$1" - execute_sudo "$CHGRP" "$GROUP" "$1" -} -#创建文件夹 -CreateFolder() -{ - echo '-> 创建文件夹' $1 - execute_sudo "/bin/mkdir" "-p" "$1" - JudgeSuccess - AddPermission $1 -} - -RmAndCopy() -{ - if [[ -d $1 ]]; then - echo " ---备份要删除的$1到系统桌面...." - if ! [[ -d $HOME/Desktop/Old_Homebrew/$TIME/$1 ]]; then - sudo mkdir -p "$HOME/Desktop/Old_Homebrew/$TIME/$1" - fi - sudo cp -rf $1 "$HOME/Desktop/Old_Homebrew/$TIME/$1" - echo " ---$1 备份完成" - fi - sudo rm -rf $1 -} - -RmCreate() -{ - RmAndCopy $1 - CreateFolder $1 -} - -#判断文件夹存在但不可写 -exists_but_not_writable() { - [[ -e "$1" ]] && ! 
[[ -r "$1" && -w "$1" && -x "$1" ]] -} -#文件所有者 -get_owner() { - $(shell_join "$STAT %u $1" ) -} -#文件本人无权限 -file_not_owned() { - [[ "$(get_owner "$1")" != "$(id -u)" ]] -} -#获取所属的组 -get_group() { - $(shell_join "$STAT %g $1" ) -} -#不在所属组 -file_not_grpowned() { - [[ " $(id -G "$USER") " != *" $(get_group "$1") "* ]] -} -#获得当前文件夹权限 例如777 -get_permission() { - $(shell_join "$STAT %A $1" ) -} -#授权当前用户权限 -user_only_chmod() { - [[ -d "$1" ]] && [[ "$(get_permission "$1")" != "755" ]] -} - - -#创建brew需要的目录 直接复制于国外版本,同步 -CreateBrewLinkFolder() -{ - echo "--创建Brew所需要的目录" - directories=(bin etc include lib sbin share opt var - Frameworks - etc/bash_completion.d lib/pkgconfig - share/aclocal share/doc share/info share/locale share/man - share/man/man1 share/man/man2 share/man/man3 share/man/man4 - share/man/man5 share/man/man6 share/man/man7 share/man/man8 - var/log var/homebrew var/homebrew/linked - bin/brew) - group_chmods=() - for dir in "${directories[@]}"; do - if exists_but_not_writable "${HOMEBREW_PREFIX}/${dir}"; then - group_chmods+=("${HOMEBREW_PREFIX}/${dir}") - fi - done - - directories=(share/zsh share/zsh/site-functions) - zsh_dirs=() - for dir in "${directories[@]}"; do - zsh_dirs+=("${HOMEBREW_PREFIX}/${dir}") - done - - directories=(bin etc include lib sbin share var opt - share/zsh share/zsh/site-functions - var/homebrew var/homebrew/linked - Cellar Caskroom Frameworks) - mkdirs=() - for dir in "${directories[@]}"; do - if ! 
[[ -d "${HOMEBREW_PREFIX}/${dir}" ]]; then - mkdirs+=("${HOMEBREW_PREFIX}/${dir}") - fi - done - - user_chmods=() - if [[ "${#zsh_dirs[@]}" -gt 0 ]]; then - for dir in "${zsh_dirs[@]}"; do - if user_only_chmod "${dir}"; then - user_chmods+=("${dir}") - fi - done - fi - - chmods=() - if [[ "${#group_chmods[@]}" -gt 0 ]]; then - chmods+=("${group_chmods[@]}") - fi - if [[ "${#user_chmods[@]}" -gt 0 ]]; then - chmods+=("${user_chmods[@]}") - fi - - chowns=() - chgrps=() - if [[ "${#chmods[@]}" -gt 0 ]]; then - for dir in "${chmods[@]}"; do - if file_not_owned "${dir}"; then - chowns+=("${dir}") - fi - if file_not_grpowned "${dir}"; then - chgrps+=("${dir}") - fi - done - fi - - if [[ -d "${HOMEBREW_PREFIX}" ]]; then - if [[ "${#chmods[@]}" -gt 0 ]]; then - execute_sudo "/bin/chmod" "u+rwx" "${chmods[@]}" - fi - if [[ "${#group_chmods[@]}" -gt 0 ]]; then - execute_sudo "/bin/chmod" "g+rwx" "${group_chmods[@]}" - fi - if [[ "${#user_chmods[@]}" -gt 0 ]]; then - execute_sudo "/bin/chmod" "755" "${user_chmods[@]}" - fi - if [[ "${#chowns[@]}" -gt 0 ]]; then - execute_sudo "$CHOWN" "$USER" "${chowns[@]}" - fi - if [[ "${#chgrps[@]}" -gt 0 ]]; then - execute_sudo "$CHGRP" "$GROUP" "${chgrps[@]}" - fi - else - execute_sudo "/bin/mkdir" "-p" "${HOMEBREW_PREFIX}" - if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - execute_sudo "$CHOWN" "root:wheel" "${HOMEBREW_PREFIX}" - else - execute_sudo "$CHOWN" "$USER:$GROUP" "${HOMEBREW_PREFIX}" - fi - fi - - if [[ "${#mkdirs[@]}" -gt 0 ]]; then - execute_sudo "/bin/mkdir" "-p" "${mkdirs[@]}" - execute_sudo "/bin/chmod" "g+rwx" "${mkdirs[@]}" - execute_sudo "$CHOWN" "$USER" "${mkdirs[@]}" - execute_sudo "$CHGRP" "$GROUP" "${mkdirs[@]}" - fi - - if ! [[ -d "${HOMEBREW_REPOSITORY}" ]]; then - execute_sudo "/bin/mkdir" "-p" "${HOMEBREW_REPOSITORY}" - fi - execute_sudo "$CHOWN" "-R" "$USER:$GROUP" "${HOMEBREW_REPOSITORY}" - - if ! 
[[ -d "${HOMEBREW_CACHE}" ]]; then - if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - execute_sudo "/bin/mkdir" "-p" "${HOMEBREW_CACHE}" - else - execute "/bin/mkdir" "-p" "${HOMEBREW_CACHE}" - fi - fi - if exists_but_not_writable "${HOMEBREW_CACHE}"; then - execute_sudo "/bin/chmod" "g+rwx" "${HOMEBREW_CACHE}" - fi - if file_not_owned "${HOMEBREW_CACHE}"; then - execute_sudo "$CHOWN" "-R" "$USER" "${HOMEBREW_CACHE}" - fi - if file_not_grpowned "${HOMEBREW_CACHE}"; then - execute_sudo "$CHGRP" "-R" "$GROUP" "${HOMEBREW_CACHE}" - fi - if [[ -d "${HOMEBREW_CACHE}" ]]; then - execute "$TOUCH" "${HOMEBREW_CACHE}/.cleaned" - fi - echo "--依赖目录脚本运行完成" -} - -#git提交 -git_commit(){ - git add . - git commit -m "your del" -} - -#version_gt 判断$1是否大于$2 -version_gt() { - [[ "${1%.*}" -gt "${2%.*}" ]] || [[ "${1%.*}" -eq "${2%.*}" && "${1#*.}" -gt "${2#*.}" ]] -} -#version_ge 判断$1是否大于等于$2 -version_ge() { - [[ "${1%.*}" -gt "${2%.*}" ]] || [[ "${1%.*}" -eq "${2%.*}" && "${1#*.}" -ge "${2#*.}" ]] -} -#version_lt 判断$1是否小于$2 -version_lt() { - [[ "${1%.*}" -lt "${2%.*}" ]] || [[ "${1%.*}" -eq "${2%.*}" && "${1#*.}" -lt "${2#*.}" ]] -} - -#发现错误 关闭脚本 提示如何解决 -error_game_over(){ - echo " - ${tty_red}失败$MY_DOWN_NUM 右键下面地址查看常见错误解决办法 - https://github.com/TheRamU/Fay - 如果没有解决,把全部运行过程截图发到 467665317@qq.com ${tty_reset} - " - - exit 0 -} - -#一些警告判断 -warning_if(){ - git_https_proxy=$(git config --global https.proxy) - git_http_proxy=$(git config --global http.proxy) - if [[ -z "$git_https_proxy" && -z "$git_http_proxy" ]]; then - echo "未发现Git代理(属于正常状态)" - else - echo "${tty_yellow} - 提示:发现你电脑设置了Git代理,如果Git报错,请运行下面两句话: - - git config --global --unset https.proxy - - git config --global --unset http.proxy${tty_reset} - " - fi -} - -echo " - ${tty_green} 开始执行Brew自动安装程序 ${tty_reset} - ${tty_cyan} [467665317@qq.com] ${tty_reset} - ['$TIME']['$macos_version'] - ${tty_cyan} https://github.com/TheRamU/Fay${tty_reset} -" -#选择一个brew下载源 -echo -n "${tty_green} -请选择一个下载brew本体的序号,例如中科大,输入1回车。 
-源有时候不稳定,如果git克隆报错重新运行脚本选择源。 -1、中科大下载源 -2、清华大学下载源 -3、北京外国语大学下载源 ${tty_reset}" -if [[ $GIT_SPEED == "" ]]; then - echo -n "${tty_green} -4、腾讯下载源 -5、阿里巴巴下载源 ${tty_reset}" -fi -echo -n " -${tty_blue}请输入序号: " -read MY_DOWN_NUM -echo "${tty_reset}" -case $MY_DOWN_NUM in -"2") - echo " - 你选择了清华大学brew本体下载源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.tuna.tsinghua.edu.cn/homebrew-bottles/ - #HomeBrew基础框架 - USER_BREW_GIT=https://mirrors.tuna.tsinghua.edu.cn/git/homebrew/brew.git - #HomeBrew Core - USER_CORE_GIT=https://mirrors.tuna.tsinghua.edu.cn/git/homebrew/homebrew-core.git - #HomeBrew Cask - USER_CASK_GIT=https://mirrors.tuna.tsinghua.edu.cn/git/homebrew/homebrew-cask.git - USER_CASK_FONTS_GIT=https://mirrors.tuna.tsinghua.edu.cn/git/homebrew/homebrew-cask-fonts.git - USER_CASK_DRIVERS_GIT=https://mirrors.tuna.tsinghua.edu.cn/git/homebrew/homebrew-cask-drivers.git -;; -"3") - echo " - 北京外国语大学brew本体下载源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.bfsu.edu.cn/homebrew-bottles - #HomeBrew基础框架 - USER_BREW_GIT=https://mirrors.bfsu.edu.cn/git/homebrew/brew.git - #HomeBrew Core - USER_CORE_GIT=https://mirrors.bfsu.edu.cn/git/homebrew/homebrew-core.git - #HomeBrew Cask - USER_CASK_GIT=https://mirrors.bfsu.edu.cn/git/homebrew/homebrew-cask.git - USER_CASK_FONTS_GIT=https://mirrors.bfsu.edu.cn/git/homebrew/homebrew-cask-fonts.git - USER_CASK_DRIVERS_GIT=https://mirrors.bfsu.edu.cn/git/homebrew/homebrew-cask-drivers.git -;; -"4") - echo " - 你选择了腾讯brew本体下载源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.cloud.tencent.com/homebrew-bottles - #HomeBrew基础框架 - USER_BREW_GIT=https://mirrors.cloud.tencent.com/homebrew/brew.git - #HomeBrew Core - USER_CORE_GIT=https://mirrors.cloud.tencent.com/homebrew/homebrew-core.git - #HomeBrew Cask - USER_CASK_GIT=https://mirrors.cloud.tencent.com/homebrew/homebrew-cask.git -;; -"5") - echo " - 你选择了阿里巴巴brew本体下载源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.aliyun.com/homebrew/homebrew-bottles - #HomeBrew基础框架 - 
USER_BREW_GIT=https://mirrors.aliyun.com/homebrew/brew.git - #HomeBrew Core - USER_CORE_GIT=https://mirrors.aliyun.com/homebrew/homebrew-core.git - #HomeBrew Cask - USER_CASK_GIT=https://mirrors.aliyun.com/homebrew/homebrew-cask.git -;; -*) - echo " - 你选择了中国科学技术大学brew本体下载源 - " - #HomeBrew 下载源 install - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.ustc.edu.cn/homebrew-bottles - #HomeBrew基础框架 - USER_BREW_GIT=https://mirrors.ustc.edu.cn/brew.git - #HomeBrew Core - USER_CORE_GIT=https://mirrors.ustc.edu.cn/homebrew-core.git - #HomeBrew Cask - USER_CASK_GIT=https://mirrors.ustc.edu.cn/homebrew-cask.git -;; -esac -echo -n "${tty_green}!!!此脚本将要删除之前的brew(包括它下载的软件),请自行备份。 -->是否现在开始执行脚本(N/Y) " -read MY_Del_Old -echo "${tty_reset}" -case $MY_Del_Old in -"y") -echo "--> 脚本开始执行" -;; -"Y") -echo "--> 脚本开始执行" -;; -*) -echo "你输入了 $MY_Del_Old ,自行备份老版brew和它下载的软件, 如果继续运行脚本应该输入Y或者y -" -exit 0 -;; -esac - - -if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then -#MAC - echo "${tty_yellow} Mac os设置开机密码方法: - (设置开机密码:在左上角苹果图标->系统偏好设置->"用户与群组"->更改密码) - (如果提示This incident will be reported. 在"用户与群组"中查看是否管理员) ${tty_reset}" -fi - -echo "==> 通过命令删除之前的brew、创建一个新的Homebrew文件夹 -${tty_cyan}请输入开机密码,输入过程不显示,输入完后回车${tty_reset}" - -sudo echo '开始执行' -#删除以前的Homebrew -RmCreate ${HOMEBREW_REPOSITORY} -RmAndCopy $HOMEBREW_CACHE -RmAndCopy $HOMEBREW_LOGS - -# 让环境暂时纯粹,脚本运行结束后恢复 -if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - export PATH=/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:${HOMEBREW_REPOSITORY}/bin -fi -git --version -if [ $? 
-ne 0 ];then - - if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - sudo rm -rf "/Library/Developer/CommandLineTools/" - echo "${tty_cyan}安装Git${tty_reset}后再运行此脚本,${tty_red}在系统弹窗中点击“安装”按钮 - 如果没有弹窗的老系统,需要自己下载安装:https://sourceforge.net/projects/git-osx-installer/ ${tty_reset}" - xcode-select --install - exit 0 - else - echo "${tty_red} 发现缺少git,开始安装,请输入Y ${tty_reset}" - sudo apt install git - fi -fi - -echo " -${tty_cyan}下载速度觉得慢可以ctrl+c或control+c重新运行脚本选择下载源${tty_reset} -==> 从 $USER_BREW_GIT 克隆Homebrew基本文件 -" -warning_if -sudo git clone ${GIT_SPEED} $USER_BREW_GIT ${HOMEBREW_REPOSITORY} -JudgeSuccess 尝试再次运行自动脚本选择其他下载源或者切换网络 out - -#依赖目录创建 授权等等 -CreateBrewLinkFolder - -echo '==> 创建brew的替身' -if [[ "${HOMEBREW_REPOSITORY}" != "${HOMEBREW_PREFIX}" ]]; then - find ${HOMEBREW_PREFIX}/bin -name brew -exec sudo rm -f {} \; - execute "ln" "-sf" "${HOMEBREW_REPOSITORY}/bin/brew" "${HOMEBREW_PREFIX}/bin/brew" -fi - -echo "==> 从 $USER_CORE_GIT 克隆Homebrew Core -${tty_cyan}此处如果显示Password表示需要再次输入开机密码,输入完后回车${tty_reset}" -sudo mkdir -p ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-core -sudo git clone ${GIT_SPEED} $USER_CORE_GIT ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-core/ -JudgeSuccess 尝试再次运行自动脚本选择其他下载源或者切换网络 out - -if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then -#MAC - echo "==> 从 $USER_CASK_GIT 克隆Homebrew Cask 图形化软件 - ${tty_cyan}此处如果显示Password表示需要再次输入开机密码,输入完后回车${tty_reset}" - sudo mkdir -p ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-cask - sudo git clone ${GIT_SPEED} $USER_CASK_GIT ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-cask/ - if [ $? 
-ne 0 ];then - sudo rm -rf ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-cask - echo "${tty_red}尝试切换下载源或者切换网络,不过Cask组件非必须模块。可以忽略${tty_reset}" - else - echo "${tty_green}此步骤成功${tty_reset}" - - fi - - echo "==> 从 $USER_SERVICES_GIT 克隆Homebrew services 管理服务的启停 - " - sudo mkdir -p ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-cask - sudo git clone ${GIT_SPEED} $USER_SERVICES_GIT ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-services/ - JudgeSuccess -else -#Linux - echo "${tty_yellow} Linux 不支持Cask图形化软件下载 此步骤跳过${tty_reset}" -fi -echo '==> 配置国内镜像源HOMEBREW BOTTLE' - -#判断下mac os终端是Bash还是zsh -case "$SHELL" in - */bash*) - if [[ -r "$HOME/.bash_profile" ]]; then - shell_profile="${HOME}/.bash_profile" - else - shell_profile="${HOME}/.profile" - fi - ;; - */zsh*) - shell_profile="${HOME}/.zprofile" - ;; - *) - shell_profile="${HOME}/.profile" - ;; -esac - -if [[ -n "${HOMEBREW_ON_LINUX-}" ]]; then - #Linux - shell_profile="/etc/profile" -fi - -if [[ -f ${shell_profile} ]]; then - AddPermission ${shell_profile} -fi -#删除之前的环境变量 -if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - #Mac - sed -i "" "/ckbrew/d" ${shell_profile} -else - #Linux - sed -i "/ckbrew/d" ${shell_profile} -fi - -#选择一个homebrew-bottles下载源 -echo -n "${tty_green} - - Brew本体已经安装成功,接下来配置国内源。 - -请选择今后brew install的时候访问那个国内镜像,例如阿里巴巴,输入5回车。 - -1、中科大国内源 -2、清华大学国内源 -3、北京外国语大学国内源 -4、腾讯国内源 -5、阿里巴巴国内源 ${tty_reset}" - -echo -n " -${tty_blue}请输入序号: " -read MY_DOWN_NUM -echo "${tty_reset}" -case $MY_DOWN_NUM in -"2") - echo " - 你选择了清华大学国内源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.tuna.tsinghua.edu.cn/homebrew-bottles/ -;; -"3") - echo " - 北京外国语大学国内源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.bfsu.edu.cn/homebrew-bottles -;; -"4") - echo " - 你选择了腾讯国内源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.cloud.tencent.com/homebrew-bottles -;; -"5") - echo " - 你选择了阿里巴巴国内源 - " - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.aliyun.com/homebrew/homebrew-bottles -;; -*) - echo " - 你选择了中国科学技术大学国内源 - " - 
#HomeBrew 下载源 install - USER_HOMEBREW_BOTTLE_DOMAIN=https://mirrors.ustc.edu.cn/homebrew-bottles -;; -esac - -#写入环境变量到文件 -echo " - - 环境变量写入->${shell_profile} - -" - -echo " - export HOMEBREW_BOTTLE_DOMAIN=${USER_HOMEBREW_BOTTLE_DOMAIN} #ckbrew - eval \$(${HOMEBREW_REPOSITORY}/bin/brew shellenv) #ckbrew -" >> ${shell_profile} -JudgeSuccess -source "${shell_profile}" -if [ $? -ne 0 ];then - echo "${tty_red}发现错误,${shell_profile} 文件中有错误,建议根据上一句提示修改; - 否则会导致提示 permission denied: brew${tty_reset}" -fi - -AddPermission ${HOMEBREW_REPOSITORY} - -if [[ -n "${HOMEBREW_ON_LINUX-}" ]]; then - #检测linux curl是否有安装 - echo "${tty_red}-检测curl是否安装 留意是否需要输入Y${tty_reset}" - curl -V - if [ $? -ne 0 ];then - sudo apt-get install curl - if [ $? -ne 0 ];then - sudo yum install curl - if [ $? -ne 0 ];then - echo '失败 请自行安装curl 可以参考https://www.howtoing.com/install-curl-in-linux' - error_game_over - fi - fi - fi -fi - -echo ' -==> 安装完成,brew版本 -' -brew -v -if [ $? -ne 0 ];then - echo '发现错误,自动修复一次!' - rm -rf $HOMEBREW_CACHE - export PATH=/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:${HOMEBREW_REPOSITORY}/bin - brew update-reset - brew -v - if [ $? 
-ne 0 ];then - error_game_over - fi -else - echo "${tty_green}Brew前期配置成功${tty_reset}" -fi - -#brew 3.1.2版本 修改了很多地址,都写死在了代码中,没有调用环境变量。。额。。 -#ruby下载需要改官方文件 -ruby_URL_file=$HOMEBREW_REPOSITORY/Library/Homebrew/cmd/vendor-install.sh - -#判断Mac系统版本 -if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - if version_gt "$macos_version" "10.14"; then - echo "电脑系统版本:$macos_version" - else - echo "${tty_red}检测到你不是最新系统,会有一些报错,请稍等Ruby下载安装;${tty_reset} - " - fi - - if [[ -f ${ruby_URL_file} ]]; then - sed -i "" "s/ruby_URL=/ruby_URL=\"https:\/\/mirrors.tuna.tsinghua.edu.cn\/homebrew-bottles\/bottles-portable-ruby\/\$ruby_FILENAME\" \#/g" $ruby_URL_file - fi -else - if [[ -f ${ruby_URL_file} ]]; then - sed -i "s/ruby_URL=/ruby_URL=\"https:\/\/mirrors.tuna.tsinghua.edu.cn\/linuxbrew-bottles\/bottles-portable-ruby\/\$ruby_FILENAME\" \#/g" $ruby_URL_file - fi -fi - -brew services cleanup - -if [[ $GIT_SPEED == "" ]];then - echo ' - ==> brew update-reset - ' - brew update-reset - if [[ $? -ne 0 ]];then - brew config - error_game_over - exit 0 - fi -else - #极速模式提示Update修复方法 - echo " -${tty_red} 极速版本安装完成,${tty_reset} install功能正常,如果需要update功能请自行运行下面三句命令 -git -C ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-core fetch --unshallow -git -C ${HOMEBREW_REPOSITORY}/Library/Taps/homebrew/homebrew-cask fetch --unshallow -brew update-reset - " -fi - -echo " - ${tty_green}Brew自动安装程序运行完成${tty_reset} - ${tty_green}国内地址已经配置完成${tty_reset} - - 桌面的Old_Homebrew文件夹,大致看看没有你需要的可以删除。 - - 初步介绍几个brew命令 -本地软件库列表:brew ls -查找软件:brew search google(其中google替换为要查找的关键字) -查看brew版本:brew -v 更新brew版本:brew update -安装cask软件:brew install --cask firefox 把firefox换成你要安装的 - ${tty_green} - 欢迎右键点击下方地址-打开URL 来给个星${tty_reset} - ${tty_underline} https://github.com/TheRamU/Fay ${tty_reset} -" - -if [[ -z "${HOMEBREW_ON_LINUX-}" ]]; then - #Mac - echo "${tty_red} 安装成功 但还需要重启终端 或者 运行${tty_bold} source ${shell_profile} ${tty_reset} ${tty_red}否则可能无法使用${tty_reset} - " -else - #Linux - echo "${tty_red} Linux需要重启电脑 或者暂时运行${tty_bold} source 
# ---- simulation_engine/example-settings.py (new file in this diff) ----
import os
from pathlib import Path

# SECURITY: never commit a real API key.  The previous revision embedded a
# live-looking secret ("sk-hAuN7OLq...") in this template file; that key must
# be rotated, and the value should come from the environment instead.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "sk-REPLACE-WITH-YOUR-KEY")
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", "https://api.zyai.online/v1")  # custom base URL
KEY_OWNER = "xszyou"


DEBUG = False

MAX_CHUNK_SIZE = 4

LLM_VERS = "gpt-4o-mini"

# Repository root (the parent of simulation_engine/).
BASE_DIR = f"{Path(__file__).resolve().parent.parent}"

## To do: Are the following needed in the new structure? Ideally Populations_Dir is for the user to define.
POPULATIONS_DIR = f"{BASE_DIR}/agent_bank/populations"
LLM_PROMPT_DIR = f"{BASE_DIR}/simulation_engine/prompt_template"


# ---- simulation_engine/global_methods.py (new file in this diff) ----
import random
import json
import string
import csv
import datetime as dt
import os
import numpy
import math
import shutil
import errno

from os import listdir


def create_folder_if_not_there(curr_path):
    """
    Ensure that the folder containing curr_path exists, creating it if needed.

    If curr_path looks like a file location (its last "/"-separated component
    contains a dot), the containing folder is created; otherwise curr_path
    itself is treated as the folder.  (The previous docstring described CSV
    writer arguments copy-pasted from another helper and did not match this
    function.)

    ARGS:
      curr_path: file or folder path using "/" separators.
    RETURNS:
      True: if a new folder is created
      False: if a new folder is not created (including for bare names with
             no "/" separator, which are left untouched)
    """
    parts = curr_path.split("/")
    if len(parts) != 1:
        # A dot in the last component means curr_path points at a file, so
        # only its parent folders need to exist.
        if "." in parts[-1]:
            parts = parts[:-1]
        folder = "/".join(parts)
        if not os.path.exists(folder):
            os.makedirs(folder)
            return True
    return False
def write_list_of_list_to_csv(curr_list_of_list, outfile):
    """
    Writes a list of list to csv in one shot.

    ARGS:
      curr_list_of_list: rows to write, e.g.
        [['key1', 'val1-1', 'val1-2'...],
         ['key2', 'val2-1', 'val2-2'...],]
      outfile: name of the csv file to write
    RETURNS:
      None
    """
    create_folder_if_not_there(outfile)
    # newline="" is required by the csv module; without it every row is
    # followed by a blank line on Windows.
    with open(outfile, "w", newline="") as f:
        csv.writer(f).writerows(curr_list_of_list)


def write_list_to_csv_line(line_list, outfile):
    """
    Appends one line to a csv file (the file is created if missing).

    ARGS:
      line_list: one row, e.g. ['key1', 'val1-1', 'val1-2'...]
        Importantly, this is NOT a list of list.
      outfile: name of the csv file to write
    RETURNS:
      None
    """
    create_folder_if_not_there(outfile)
    # Context manager replaces the previous raw open/close pair so the
    # handle is released even if writerow raises; newline="" avoids blank
    # rows on Windows.
    with open(outfile, "a", newline="") as f:
        csv.writer(f).writerow(line_list)


def read_file_to_list(curr_file, header=False, strip_trail=True):
    """
    Reads a csv file into a list of rows.

    ARGS:
      curr_file: path to the current csv file.
      header: when True, return a (header_row, data_rows) tuple instead.
      strip_trail: strip surrounding whitespace from every cell.
    RETURNS:
      List of rows, or (header, rows) when header is True.
    """
    # The previous version duplicated the whole read loop in both branches;
    # read once, then slice.
    rows = []
    with open(curr_file) as f_analysis_file:
        for row in csv.reader(f_analysis_file, delimiter=","):
            if strip_trail:
                row = [cell.strip() for cell in row]
            rows.append(row)
    if header:
        return rows[0], rows[1:]
    return rows
+ """ + if not header: + analysis_list = [] + with open(curr_file) as f_analysis_file: + data_reader = csv.reader(f_analysis_file, delimiter=",") + for count, row in enumerate(data_reader): + if strip_trail: + row = [i.strip() for i in row] + analysis_list += [row] + return analysis_list + else: + analysis_list = [] + with open(curr_file) as f_analysis_file: + data_reader = csv.reader(f_analysis_file, delimiter=",") + for count, row in enumerate(data_reader): + if strip_trail: + row = [i.strip() for i in row] + analysis_list += [row] + return analysis_list[0], analysis_list[1:] + + +def read_file_to_set(curr_file, col=0): + """ + Reads in a "single column" of a csv file to a set. + ARGS: + curr_file: path to the current csv file. + RETURNS: + Set with all items in a single column of a csv file. + """ + analysis_set = set() + with open(curr_file) as f_analysis_file: + data_reader = csv.reader(f_analysis_file, delimiter=",") + for count, row in enumerate(data_reader): + analysis_set.add(row[col]) + return analysis_set + + +def get_row_len(curr_file): + """ + Get the number of rows in a csv file + ARGS: + curr_file: path to the current csv file. + RETURNS: + The number of rows + False if the file does not exist + """ + try: + analysis_set = set() + with open(curr_file) as f_analysis_file: + data_reader = csv.reader(f_analysis_file, delimiter=",") + for count, row in enumerate(data_reader): + analysis_set.add(row[0]) + return len(analysis_set) + except: + return False + + +def check_if_file_exists(curr_file): + """ + Checks if a file exists + ARGS: + curr_file: path to the current csv file. + RETURNS: + True if the file exists + False if the file does not exist + """ + try: + with open(curr_file) as f_analysis_file: pass + return True + except: + return False + + +def find_filenames(path_to_dir, suffix=".csv"): + """ + Given a directory, find all files that ends with the provided suffix and + returns their paths. 
def average(list_of_val):
    """
    Finds the average of the numbers in a list; NaN entries are dropped.

    ARGS:
      list_of_val: a list of numeric values
    RETURNS:
      The average of the values, or float('nan') when the list is empty or
      contains non-numeric values.
    """
    try:
        vals = [float(v) for v in list_of_val if not math.isnan(v)]
        return sum(vals) / float(len(vals))
    except (TypeError, ValueError, ZeroDivisionError):
        # Empty list -> ZeroDivisionError; non-numeric -> TypeError from
        # math.isnan or ValueError from float().  Was a bare `except:`.
        return float('nan')


def std(list_of_val):
    """
    Finds the (population) std of the numbers in a list; NaN entries dropped.

    ARGS:
      list_of_val: a list of numeric values
    RETURNS:
      The std of the values (numpy.std), or float('nan') on bad input.
    """
    try:
        vals = [float(v) for v in list_of_val if not math.isnan(v)]
        return numpy.std(vals)
    except (TypeError, ValueError):
        return float('nan')


def copyanything(src, dst):
    """
    Copy over everything in the src folder to dst folder, falling back to a
    plain file copy when src turns out to be a regular file.

    ARGS:
      src: address of the source folder (or file)
      dst: address of the destination folder (must not exist for tree copies)
    RETURNS:
      None
    """
    try:
        shutil.copytree(src, dst)
    except OSError as exc:  # python >2.5
        # ENOTDIR/EINVAL mean src is a regular file, so copy it directly.
        if exc.errno in (errno.ENOTDIR, errno.EINVAL):
            shutil.copy(src, dst)
        else:
            raise
def generate_alphanumeric_string(length):
    """Return a random string of ASCII letters and digits of *length* chars.

    Uses the `random` module, so NOT suitable for secrets/tokens — use the
    `secrets` module for anything security-sensitive.
    """
    characters = string.ascii_letters + string.digits
    return ''.join(random.choice(characters) for _ in range(length))


def extract_first_json_dict(input_str):
    """Extract and parse the first {...} JSON object found in *input_str*.

    Typographic "smart" quotes are normalised to ASCII quotes first so that
    LLM-produced text still parses.  Brace matching is done by counting, so
    braces inside string values can confuse it (accepted limitation).

    ARGS:
      input_str: string that may contain a JSON object.
    RETURNS:
      The parsed dict, or None on any failure (a diagnostic is printed).
    """
    try:
        if not isinstance(input_str, str):
            print("提取JSON错误: 输入必须是字符串类型")
            return None

        # The literals below are the Unicode smart quotes “ ” ‘ ’, written
        # as \uXXXX escapes; in the previous revision they had been mangled
        # into plain quotes, which is a syntax error.
        input_str = (input_str.replace("\u201c", "\"")
                              .replace("\u201d", "\"")
                              .replace("\u2018", "'")
                              .replace("\u2019", "'"))

        # Find the first '{'.
        try:
            start_index = input_str.index('{')
        except ValueError:
            print("提取JSON错误: 未找到JSON开始标记'{'")
            return None

        # Walk forward counting brace depth until the opener is matched.
        count = 1
        end_index = start_index + 1
        while count > 0 and end_index < len(input_str):
            if input_str[end_index] == '{':
                count += 1
            elif input_str[end_index] == '}':
                count -= 1
            end_index += 1

        if count > 0:
            print("提取JSON错误: JSON格式不完整,缺少匹配的'}'")
            return None

        json_str = input_str[start_index:end_index]
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"解析JSON错误: {str(e)}")
            return None
    except Exception as e:
        print(f"提取JSON时发生错误: {str(e)}")
        return None


def read_file_to_string(file_path):
    """Return the UTF-8 text content of *file_path*.

    Returns the fixed message "The file was not found." when the file is
    missing, or the stringified exception for any other error — callers
    treat these sentinel strings as soft failures.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        return "The file was not found."
    except Exception as e:
        return str(e)
def write_string_to_file(full_path, text_content):
    """Write *text_content* to *full_path* as UTF-8, creating parent folders.

    RETURNS:
      A success message string, or the stringified exception on failure
      (soft-failure convention shared with read_file_to_string).
    """
    create_folder_if_not_there(full_path)
    # The function-local `import os` of the previous revision was unused
    # here and shadowed the module-level import; removed.
    try:
        with open(full_path, 'w', encoding='utf-8') as file:
            file.write(text_content)
        return f"File successfully written to {full_path}"
    except Exception as e:
        return str(e)


def chunk_list(lst, q_chunk_size):
    """
    Splits the given list into sublists of specified chunk size.

    Parameters:
      lst (list): The list to be split into chunks.
      q_chunk_size (int): The size of each chunk (must be > 0).

    Returns:
      list: A list of sublists where each sublist has at most q_chunk_size
      items; the final chunk may be shorter.
    """
    return [lst[i:i + q_chunk_size] for i in range(0, len(lst), q_chunk_size)]


def write_dict_to_json(data, filename):
    """
    Writes a dictionary to a JSON file (pretty-printed UTF-8).

    Parent directories are created as needed; failures are printed rather
    than raised.

    Parameters:
      data (dict): The dictionary to write to the JSON file.
      filename (str): The name of the file to write the JSON data to.
    """
    try:
        directory = os.path.dirname(filename)
        if directory and not os.path.exists(directory):
            os.makedirs(directory)
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        with open(filename, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
    except Exception as e:
        print(f"写入JSON文件时出错: {str(e)}")


def read_json_to_dict(file_path):
    """
    Reads a JSON file and converts it to a Python dictionary.

    Parameters:
      file_path (str): The path to the JSON file.

    Returns:
      dict: The content of the JSON file, or None when the file is missing,
      malformed, or unreadable (an error is printed instead of raised).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"未找到文件: {file_path}")
    except json.JSONDecodeError:
        print(f"解析JSON文件出错: {file_path}")
    except Exception as e:
        print(f"发生错误: {str(e)}")
    # Explicit None instead of the previous implicit fall-through.
    return None
+ """ + try: + # 使用UTF-8编码读取JSON文件 + with open(file_path, 'r', encoding='utf-8') as file: + data = json.load(file) + return data + except FileNotFoundError: + print(f"未找到文件: {file_path}") + except json.JSONDecodeError: + print(f"解析JSON文件出错: {file_path}") + except Exception as e: + print(f"发生错误: {str(e)}") diff --git a/simulation_engine/gpt_structure.py b/simulation_engine/gpt_structure.py new file mode 100644 index 0000000..67228ca --- /dev/null +++ b/simulation_engine/gpt_structure.py @@ -0,0 +1,281 @@ +import openai +import time +import base64 +from typing import List, Dict, Any, Union, Optional +import os +from simulation_engine.settings import * +from utils import config_util as cfg + + +# 确保配置已加载 +cfg.load_config() + +# 初始化 OpenAI 客户端 +client = openai.OpenAI( + api_key=OPENAI_API_KEY, + base_url=OPENAI_API_BASE +) + +# 设置全局API密钥(兼容性考虑) +openai.api_key = OPENAI_API_KEY + +# 如果环境变量中没有设置,则设置环境变量(某些库可能依赖环境变量) +if "OPENAI_API_KEY" not in os.environ: + os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY +if "OPENAI_API_BASE" not in os.environ and OPENAI_API_BASE: + os.environ["OPENAI_API_BASE"] = OPENAI_API_BASE + + +# ============================================================================ +# #######################[SECTION 1: HELPER FUNCTIONS] ####################### +# ============================================================================ + +def print_run_prompts(prompt_input: Union[str, List[str]], + prompt: str, + output: str) -> None: + print (f"=== START =======================================================") + print ("~~~ prompt_input ----------------------------------------------") + print (prompt_input, "\n") + print ("~~~ prompt ----------------------------------------------------") + print (prompt, "\n") + print ("~~~ output ----------------------------------------------------") + print (output, "\n") + print ("=== END ==========================================================") + print ("\n\n\n") + + +def generate_prompt(prompt_input: Union[str, 
def gpt_request(prompt: str,
                model: str = "gpt-4o",
                max_tokens: int = 1500) -> str:
    """Send a single-user-message chat completion request.

    Args:
        prompt: the user prompt (must be a str).
        model: model name; "o1-preview" gets a reduced parameter set because
            it rejects max_tokens/temperature.
        max_tokens: completion budget for non-o1 models.

    Returns:
        The model's reply text, or a "GENERATION ERROR: ..." string on
        failure — callers detect errors via that prefix.
    """
    if not isinstance(prompt, str):
        print("GPT请求错误: 提示文本必须是字符串类型")
        return "GENERATION ERROR: 提示文本必须是字符串类型"

    # The previous revision duplicated the whole request/except block for
    # the o1-preview case; only the extra kwargs actually differ, so build
    # them conditionally and make a single call.
    extra_kwargs = {}
    if model != "o1-preview":
        extra_kwargs = {"max_tokens": max_tokens, "temperature": 0.7}

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            **extra_kwargs
        )
        return response.choices[0].message.content
    except Exception as e:
        error_msg = f"GENERATION ERROR: {str(e)}"
        print(error_msg)
        return error_msg
def gpt4_vision(messages: List[dict], max_tokens: int = 1500) -> str:
    """Make a request to OpenAI's GPT-4 Vision model ("gpt-4o").

    Args:
        messages: chat messages, possibly containing image_url parts.
        max_tokens: completion budget.

    Returns:
        The reply text, or a "GENERATION ERROR: ..." string on failure.
    """
    try:
        # A fresh client keeps this helper independent of the module-level
        # client; renamed so it no longer shadows that global.
        vision_client = openai.OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_API_BASE
        )
        response = vision_client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            max_tokens=max_tokens,
            temperature=0.7
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"GENERATION ERROR: {str(e)}"


def chat_safe_generate(prompt_input: Union[str, List[str]],
                       prompt_lib_file: str,
                       gpt_version: str = "gpt-4o",
                       repeat: int = 1,
                       fail_safe: str = "error",
                       func_clean_up: callable = None,
                       verbose: bool = False,
                       max_tokens: int = 1500,
                       file_attachment: str = None,
                       file_type: str = None) -> tuple:
    """Generate a response using GPT models with error handling & retries.

    Args:
        prompt_input: input(s) substituted into the template placeholders.
        prompt_lib_file: template file path.
        gpt_version: model name for the text path.
        repeat: number of attempts for the plain-text path.
        fail_safe: value used when every attempt fails.
        func_clean_up: optional post-processor called as f(response, prompt=...).
        verbose: print the full round when True (or when DEBUG is set).
        max_tokens: completion budget.
        file_attachment / file_type: optional attachment ('image' or 'pdf').

    Returns:
        (response, prompt, prompt_input, fail_safe) tuple.
    """
    if file_attachment and file_type:
        prompt = generate_prompt(prompt_input, prompt_lib_file)
        messages = [{"role": "user", "content": prompt}]

        if file_type.lower() == 'image':
            with open(file_attachment, "rb") as image_file:
                base64_image = base64.b64encode(image_file.read()).decode('utf-8')
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": "Please refer to the attached image."},
                    {"type": "image_url", "image_url":
                        {"url": f"data:image/jpeg;base64,{base64_image}"}}
                ]
            })
            response = gpt4_vision(messages, max_tokens)

        elif file_type.lower() == 'pdf':
            # NOTE(review): extract_text_from_pdf_file is not defined in this
            # module or its imports — this branch will raise NameError at
            # runtime; confirm where the helper is meant to come from.
            pdf_text = extract_text_from_pdf_file(file_attachment)
            pdf = f"PDF attachment in text-form:\n{pdf_text}\n\n"
            instruction = generate_prompt(prompt_input, prompt_lib_file)
            prompt = f"{pdf}"
            prompt += f"\n=\nTask description:\n{instruction}"
            response = gpt_request(prompt, gpt_version, max_tokens)

        else:
            # Unknown file_type previously left `response` unbound and the
            # function crashed with UnboundLocalError; fall back to a plain
            # text request instead.
            response = gpt_request(prompt, gpt_version, max_tokens)

    else:
        prompt = generate_prompt(prompt_input, prompt_lib_file)
        for attempt in range(repeat):
            response = gpt_request(prompt, model=gpt_version)
            # BUG FIX: gpt_request returns "GENERATION ERROR: <detail>", so
            # the previous exact comparison `!= "GENERATION ERROR"` always
            # broke out of the loop and retries never happened.
            if not response.startswith("GENERATION ERROR"):
                break
            time.sleep(2 ** attempt)  # exponential backoff before retrying
        else:
            # Every attempt failed: hand back the caller's fail-safe value.
            response = fail_safe

    if func_clean_up:
        response = func_clean_up(response, prompt=prompt)

    if verbose or DEBUG:
        print_run_prompts(prompt_input, prompt, response)

    return response, prompt, prompt_input, fail_safe
# ============================================================================
# #################### [SECTION 3: OTHER API FUNCTIONS] ######################
# ============================================================================

def _create_mock_embedding(dimension=1536):
    """Build a deterministic mock embedding function (no API calls).

    The returned callable maps text -> unit-length pseudo-random vector of
    *dimension* floats; identical text always yields an identical vector.

    Args:
        dimension: length of the produced vectors.

    Returns:
        A function text -> list[float].
    """
    import hashlib
    import math
    import random

    def _get_mock_vector(text):
        """Return a unit-norm vector derived deterministically from text."""
        # Seed from a SHA-256 digest (not hash()) so equal text gives equal
        # vectors regardless of interpreter hash randomisation.
        try:
            if isinstance(text, str):
                text_bytes = text.encode('utf-8')
            else:
                text_bytes = str(text).encode('utf-8')
            seed = int(hashlib.sha256(text_bytes).hexdigest(), 16) % (10 ** 8)
        except Exception as e:
            # On any encoding problem, fall back to a fixed seed.
            print(f"处理文本哈希时出错: {str(e)}")
            seed = 42

        # BUG FIX: the previous revision called random.seed(), silently
        # resetting the process-wide global RNG state on every embedding
        # call; use a private Random instance instead.
        rng = random.Random(seed)
        vector = [rng.uniform(-1, 1) for _ in range(dimension)]

        # Normalise to unit length.
        magnitude = math.sqrt(sum(x * x for x in vector))
        return [x / magnitude for x in vector]

    return _get_mock_vector


# Shared mock instance used by get_text_embedding.
_mock_embedding_function = _create_mock_embedding(1536)


def get_text_embedding(text: str,
                       model: str = "text-embedding-3-small") -> List[float]:
    """Return a 1536-dim mock embedding for *text* (deterministic, offline).

    Args:
        text: input text; non-strings and blank strings yield a zero vector.
        model: accepted for API compatibility; unused by the mock.

    Returns:
        A list of 1536 floats.
    """
    try:
        if not isinstance(text, str):
            print("Embedding错误: 输入必须是字符串类型")
            return [0.0] * 1536
        if not text.strip():
            print("Embedding警告: 输入字符串为空")
            return [0.0] * 1536

        # Normalise: drop newlines, trim surrounding whitespace.
        text = text.replace("\n", " ").strip()
        return _mock_embedding_function(text)
    except Exception as e:
        # Never crash the caller over an embedding; hand back a zero vector.
        print(f"生成embedding时出错: {str(e)}")
        return [0.0] * 1536
# ---- simulation_engine/llm_json_parser.py (new file in this diff) ----
import json
import re


def extract_first_json_dict(input_str):
    """Parse and return the first {...} JSON object inside input_str.

    Typographic quotes are normalised to ASCII quotes first; returns None
    when no object is found or parsing fails.
    """
    try:
        # The escapes below are the smart quotes: U+201C “, U+201D ”,
        # U+2018 ‘, U+2019 ’.
        normalized = (input_str.replace("\u201c", "\"")
                               .replace("\u201d", "\"")
                               .replace("\u2018", "'")
                               .replace("\u2019", "'"))

        # Locate the opening brace of the first candidate object.
        start = normalized.index('{')

        # Scan forward tracking brace depth until the opener is matched.
        depth, pos = 1, start + 1
        while depth and pos < len(normalized):
            ch = normalized[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
            pos += 1

        # json.loads raises on truncated/invalid text, which lands in the
        # except below.
        return json.loads(normalized[start:pos])
    except ValueError:
        # Raised by .index() when there is no '{' and by json.loads on bad
        # JSON (JSONDecodeError subclasses ValueError).
        return None


def extract_first_json_dict_categorical(input_str):
    """Collect every "Reasoning"/"Response" string pair from input_str.

    Returns:
        (responses, reasonings) — two lists of matched strings.
    """
    reasonings = re.findall(r'"Reasoning":\s*"([^"]+)"', input_str)
    responses = re.findall(r'"Response":\s*"([^"]+)"', input_str)
    return responses, reasonings


def extract_first_json_dict_numerical(input_str):
    """Collect every "Reasoning" string and numeric "Response" from input_str.

    Returns:
        (responses, reasonings) — responses are the matched number strings
        (not converted to float).
    """
    reasoning_rx = re.compile(r'"Reasoning":\s*"([^"]+)"')
    response_rx = re.compile(r'"Response":\s*(\d+\.?\d*)')
    return response_rx.findall(input_str), reasoning_rx.findall(input_str)
a/simulation_engine/prompt_template/generative_agent/interaction/categorical_resp/batch_v1.txt b/simulation_engine/prompt_template/generative_agent/interaction/categorical_resp/batch_v1.txt new file mode 100644 index 0000000..ef58931 --- /dev/null +++ b/simulation_engine/prompt_template/generative_agent/interaction/categorical_resp/batch_v1.txt @@ -0,0 +1,44 @@ +Variables: + +Note: basically main version (ver 3) but with "reasoning" step + +### +!! + +===== + +Task: What you see above is an interview transcript. Based on the interview transcript, I want you to predict the participant's survey responses. All questions are multiple choice where you must guess from one of the options presented. + +As you answer, I want you to take the following steps: +Step 1) Describe in a few sentences the kind of person that would choose each of the response options. ("Option Interpretation") +Step 2) For each response options, reason about why the Participant might answer with the particular option. ("Option Choice") +Step 3) Write a few sentences reasoning on which of the option best predicts the participant's response ("Reasoning") +Step 4) Predict how the participant will actually respond in the survey. Predict based on the interview and your thoughts, but ultimately, DON'T over think it. Use your system 1 (fast, intuitive) thinking. ("Response") + +Here are the questions: + +!! + +----- + +Output format -- output your response in json, where you provide the following: + +{"1": {"Q": "", + "Option Interpretation": { + "