From 0df7a26e8e18bf5a406234d007fba087b5c39e2a Mon Sep 17 00:00:00 2001 From: xszyou Date: Fri, 27 Jun 2025 23:33:24 +0800 Subject: [PATCH] =?UTF-8?q?fay=E8=BF=9B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 内置RAG知识库(请把docx、pptx、txt文件存放到llm/data目录); 2. 流式回复逻辑优化; 3. 语音交互逻辑优化; 4. 线程安全增强; 5. 数字人驱动接口增加流式输出开始结束标记; 6. 修复因记忆反思而导致的记忆混乱,无法多轮对话问题; 7. 修复mcp工具获取于调用的线程同步问题; 8. 修复funasr依赖版本问题。 --- asr/funasr/requirments.txt | 6 + core/fay_core.py | 104 +++++-- core/recorder.py | 7 +- core/stream_manager.py | 60 ++-- fay_booter.py | 5 + llm/data/测试.docx | Bin 0 -> 10181 bytes llm/nlp_cognitive_stream.py | 542 +++++++++++++++++++++++++++++++-- tts/ms_tts_sdk.py | 2 - utils/stream_state_manager.py | 249 +++++++++++++++ utils/stream_text_processor.py | 183 +++++++++++ 10 files changed, 1084 insertions(+), 74 deletions(-) create mode 100644 asr/funasr/requirments.txt create mode 100644 llm/data/测试.docx create mode 100644 utils/stream_state_manager.py create mode 100644 utils/stream_text_processor.py diff --git a/asr/funasr/requirments.txt b/asr/funasr/requirments.txt new file mode 100644 index 0000000..5b693d4 --- /dev/null +++ b/asr/funasr/requirments.txt @@ -0,0 +1,6 @@ +torch +modelscope +testresources +torchaudio +FunASR +websockets~=10.4 \ No newline at end of file diff --git a/core/fay_core.py b/core/fay_core.py index a575be0..f759a41 100644 --- a/core/fay_core.py +++ b/core/fay_core.py @@ -77,20 +77,70 @@ class FeiFei: self.think_mode_users = {} # 使用字典存储每个用户的think模式状态 def __remove_emojis(self, text): + """ + 改进的表情包过滤,避免误删除正常Unicode字符 + """ + # 更精确的emoji范围,避免误删除正常字符 emoji_pattern = re.compile( "[" - "\U0001F600-\U0001F64F" # 表情符号 - "\U0001F300-\U0001F5FF" # 图标符号 - "\U0001F680-\U0001F6FF" # 交通工具和符号 - "\U0001F1E0-\U0001F1FF" # 国旗 - "\U00002700-\U000027BF" # 杂项符号 - "\U0001F900-\U0001F9FF" # 补充表情符号 - "\U00002600-\U000026FF" # 杂项符号 - "\U0001FA70-\U0001FAFF" # 更多表情 + "\U0001F600-\U0001F64F" # 表情符号 (Emoticons) + 
"\U0001F300-\U0001F5FF" # 杂项符号和象形文字 (Miscellaneous Symbols and Pictographs) + "\U0001F680-\U0001F6FF" # 交通和地图符号 (Transport and Map Symbols) + "\U0001F1E0-\U0001F1FF" # 区域指示符号 (Regional Indicator Symbols) + "\U0001F900-\U0001F9FF" # 补充符号和象形文字 (Supplemental Symbols and Pictographs) + "\U0001FA70-\U0001FAFF" # 扩展A符号和象形文字 (Symbols and Pictographs Extended-A) + "\U00002600-\U000026FF" # 杂项符号 (Miscellaneous Symbols) + "\U00002700-\U000027BF" # 装饰符号 (Dingbats) + "\U0000FE00-\U0000FE0F" # 变体选择器 (Variation Selectors) + "\U0001F000-\U0001F02F" # 麻将牌 (Mahjong Tiles) + "\U0001F0A0-\U0001F0FF" # 扑克牌 (Playing Cards) "]+", flags=re.UNICODE, ) - return emoji_pattern.sub(r'', text) + + # 保护常用的中文标点符号和特殊字符 + protected_chars = ["。", ",", "!", "?", ":", ";", "、", """, """, "'", "'", "(", ")", "【", "】", "《", "》"] + + # 先保存保护字符的位置 + protected_positions = {} + for i, char in enumerate(text): + if char in protected_chars: + protected_positions[i] = char + + # 执行emoji过滤 + filtered_text = emoji_pattern.sub('', text) + + # 如果过滤后文本长度变化太大,可能误删了正常字符,返回原文本 + if len(filtered_text) < len(text) * 0.5: # 如果删除了超过50%的内容 + return text + + return filtered_text + + def __process_qa_stream(self, text, username): + """ + 按流式方式分割和发送Q&A答案 + 使用安全的流式文本处理器和状态管理器 + """ + if not text or text.strip() == "": + return + + # 使用安全的流式文本处理器 + from utils.stream_text_processor import get_processor + from utils.stream_state_manager import get_state_manager + + processor = get_processor() + state_manager = get_state_manager() + + # 处理Q&A流式文本,is_qa=True表示Q&A模式 + success = processor.process_stream_text(text, username, is_qa=True, session_type="qa") + + if success: + # Q&A模式结束会话(不再需要发送额外的结束标记) + state_manager.end_session(username) + else: + util.log(1, f"Q&A流式处理失败,文本长度: {len(text)}") + # 失败时也要确保结束会话 + state_manager.force_reset_user_state(username) #语音消息处理检查是否命中q&a def __get_answer(self, interleaver, text): @@ -139,7 +189,8 @@ class FeiFei: else: text = answer - stream_manager.new_instance().write_sentence(username, "_" + text + 
"_") + # 使用流式分割处理Q&A答案 + self.__process_qa_stream(text, username) #完整文本记录回复并输出到各个终端 self.__process_text_output(text, username, uid ) @@ -259,7 +310,7 @@ class FeiFei: is_end = interact.data.get("isend", False) is_first = interact.data.get("isfirst", False) - if is_first and (text is None or text.strip() == ""): + if not is_first and not is_end and (text is None or text.strip() == ""): return None self.__send_panel_message(text, interact.data.get('user'), uid, 0, type) @@ -366,20 +417,24 @@ class FeiFei: except Exception as e: util.printInfo(1, "System", "音频播放初始化失败,本机无法播放音频") return - + while self.__running: time.sleep(0.01) if not self.sound_query.empty(): # 如果队列不为空则播放音频 file_url, audio_length, interact = self.sound_query.get() - is_first = False - is_end = False - if interact.data.get('isfirst'): - is_first = True - if interact.data.get('isend'): - is_end = True + + is_first = interact.data.get('isfirst') is True + is_end = interact.data.get('isend') is True + + + if file_url is not None: util.printInfo(1, interact.data.get('user'), '播放音频...') - self.speaking = True + + if is_first: + self.speaking = True + elif not is_end: + self.speaking = True #自动播报关闭 global auto_play_lock @@ -392,7 +447,7 @@ class FeiFei: if wsa_server.get_web_instance().is_connected(interact.data.get('user')): wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}) - + if file_url is not None: pygame.mixer.music.load(file_url) pygame.mixer.music.play() @@ -402,10 +457,10 @@ class FeiFei: while length < audio_length: length += 0.01 time.sleep(0.01) - + if is_end: self.play_end(interact) - util.printInfo(1, interact.data.get('user'), '结束播放!') + if wsa_server.get_web_instance().is_connected(interact.data.get('user')): wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'}) # 播放完毕后通知 @@ -467,7 +522,7 @@ class FeiFei: 
#发送音频给数字人接口 if file_url is not None and wsa_server.get_instance().is_connected(interact.data.get("user")): - content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'} + content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'} #计算lips if platform.system() == "Windows": try: @@ -580,12 +635,13 @@ class FeiFei: :param text: 消息文本 :param username: 用户名 """ + full_text = self.__remove_emojis(text.replace("*", "")) if wsa_server.get_instance().is_connected(username): content = { 'Topic': 'human', 'Data': { 'Key': 'text', - 'Value': text + 'Value': full_text }, 'Username': username } diff --git a/core/recorder.py b/core/recorder.py index e2d3c18..e04eef0 100644 --- a/core/recorder.py +++ b/core/recorder.py @@ -128,7 +128,12 @@ class Recorder: with fay_core.auto_play_lock: fay_core.can_auto_play = False #self.on_speaking(text) - intt = interact.Interact("auto_play", 2, {'user': self.username, 'text': "在呢,你说?"}) + # 使用状态管理器处理唤醒回复 + from utils.stream_state_manager import get_state_manager + state_manager = get_state_manager() + state_manager.start_new_session(self.username, "auto_play") + + intt = interact.Interact("auto_play", 2, {'user': self.username, 'text': "在呢,你说?" 
, "isfirst" : True, "isend" : True}) self.__fay.on_interact(intt) self.processing = False self.timer.cancel() # 取消之前的计时器任务 diff --git a/core/stream_manager.py b/core/stream_manager.py index e725fbb..e29ea2f 100644 --- a/core/stream_manager.py +++ b/core/stream_manager.py @@ -45,45 +45,52 @@ class StreamManager: def get_Stream(self, username): """ - 获取指定用户ID的文本流,如果不存在则创建新的 + 获取指定用户ID的文本流,如果不存在则创建新的(线程安全) :param username: 用户名 :return: 对应的句子缓存对象 """ - need_start_thread = False - stream = None - nlp_stream = None - - with self.lock: - if username not in self.streams or username not in self.nlp_streams: - self.streams[username] = stream_sentence.SentenceCache(self.max_sentences) - self.nlp_streams[username] = stream_sentence.SentenceCache(self.max_sentences) - need_start_thread = True + # 注意:这个方法应该在已经获得锁的情况下调用 + # 如果从外部调用,需要先获得锁 + + if username not in self.streams or username not in self.nlp_streams: + # 创建新的流缓存 + self.streams[username] = stream_sentence.SentenceCache(self.max_sentences) + self.nlp_streams[username] = stream_sentence.SentenceCache(self.max_sentences) + + # 启动监听线程(如果还没有) + if username not in self.listener_threads: stream = self.streams[username] nlp_stream = self.nlp_streams[username] - - if need_start_thread: - thread = MyThread(target=self.listen, args=(username, stream, nlp_stream), daemon=True) - self.listener_threads[username] = thread - thread.start() - else: - stream = self.streams[username] - nlp_stream = self.nlp_streams[username] - - return stream, nlp_stream + thread = MyThread(target=self.listen, args=(username, stream, nlp_stream), daemon=True) + self.listener_threads[username] = thread + thread.start() + + return self.streams[username], self.nlp_streams[username] def write_sentence(self, username, sentence): """ - 写入句子到指定用户的文本流 + 写入句子到指定用户的文本流(线程安全) :param username: 用户名 :param sentence: 要写入的句子 :return: 写入是否成功 """ + # 检查句子长度,防止过大的句子导致内存问题 + if len(sentence) > 10240: # 10KB限制 + sentence = sentence[:10240] + if sentence.endswith('_'): 
self.clear_Stream(username) - Stream, nlp_Stream = self.get_Stream(username) - success = Stream.write(sentence) - nlp_success = nlp_Stream.write(sentence) - return success and nlp_success + + # 使用锁保护获取和写入操作 + with self.lock: + try: + Stream, nlp_Stream = self.get_Stream(username) + success = Stream.write(sentence) + nlp_success = nlp_Stream.write(sentence) + return success and nlp_success + except Exception as e: + print(f"写入句子时出错: {e}") + return False def clear_Stream(self, username): """ @@ -111,12 +118,11 @@ class StreamManager: :param sentence: 要处理的句子 """ fay_core = fay_booter.feiFei - # 处理普通消息,区分是否是会话的第一句 is_first = "_" in sentence is_end = "_" in sentence sentence = sentence.replace("_", "").replace("_", "") - + if sentence or is_first or is_end : interact = Interact("stream", 1, {"user": username, "msg": sentence, "isfirst" : is_first, "isend" : is_end}) fay_core.say(interact, sentence) # 调用核心处理模块进行响应 diff --git a/fay_booter.py b/fay_booter.py index 75191a3..d22cd57 100644 --- a/fay_booter.py +++ b/fay_booter.py @@ -345,6 +345,11 @@ def start(): from llm.nlp_cognitive_stream import init_memory_scheduler init_memory_scheduler() + #初始化知识库 + util.log(1, '初始化本地知识库...') + from llm.nlp_cognitive_stream import init_knowledge_base + init_knowledge_base() + #开启录音服务 record = config_util.config['source']['record'] if record['enabled']: diff --git a/llm/data/测试.docx b/llm/data/测试.docx new file mode 100644 index 0000000000000000000000000000000000000000..616b7a193c0dcdcf649624287f2b70f288576c18 GIT binary patch literal 10181 zcma)i1yo$ivi9KaZoz}QySuvtcZcBa4#6e327m zMCyP50B=D70OWs@89F%7yW84i#VgqdGNOm9rG6KPYZn3O<5FQn{5Dma7Jl)Cbq91T z*LKnBa;+*_Y~gVAn5&8Jc2OiHweo5_7a^5IObM9;2|SS;%O&zhUw;nigeJB0qBIEA zr$C{R?`Huv4XU|xP+Kj~hRV~LWtiA_vw{+G>$Baq)FnJtS<`GOYJE86AWmjm3XsHe!aEebgV!jwO^G__M?YPp=g5_e#=|zlry4P*{xGPeaZT{wZ=Ke zvyA;7@^?${CS)#211(_!v;@k(En#f$X!>G`! 
zOG}~$yx00+ZH?k*g7@_}fC!Kk)f`Iau8%8FSkrt7nm2>oxX1308IRCMzjoVW+(zp_Ca zkt#?M(kG+z-hR+Bl-py`0+HWx4}{?C4kvBK2=!#l*Bjc)on|OH29GxSDCF;qNud#AFTmkEghJK4WaU@Xeq59B#fNheZv?8^GW{ct3!C>dxWKsm@uz}Noklk?5H*N` zmF7%bB1Fw~h+E%fVh?OxT?Vs{(p#)>TT<`MqHcH=slFq(|$It5*onFk6(!oW@S!N7S8LJO_;fOJPd2V52<62( z&i>F3!QX}EmgM*%I_3~4JHQUMWsi-mfLK*u^Hf6xeYhl@4J92a=R&3ry({uK8X2w8 z+G=w>0wu3e>`j~aSRyN}eXQ>4k;YRtwc2i|b$|GfyXsov(&Ko?)XvPb7?YtZKj+L< z2^&2-)NoZIX_*+dk&#rFGfcO2REMcmJ4j-^A(ywclpU2nD5X2cl34Ohh(pF+?oxnw zs}(jjiu)Pn#f1Oe*+BqJ=w|O|@`t0l3|Q372Rg!1Qvd+zZxSbG4;xdb7Yn!PzHutx z#PnS&eMCZ6_>2_clwra>b6j$)!>*g?Y&tSDS_fZSVJTSd&$+MPA+c#>6tu|guuH|n zzUIc6iX0;`rtExS9hZaVQ42}P4&3*Uz z$9~T9O>_I`c~iLgD*)MoDZcEZT6S@2q7-Ji`)I95+BtZ;jwD7MK7Z*Bnc0C?{3~!~ zc`f=~64cwJ?d`q2T0u>ERY-G|uGnPB3l$#;a^%;`#-!S@pi=cyBUEi(2{~5APCO{YrRBj_-UebQ zW$(uxD5(*aDeO*E$exeB(XXXN-h@%c1Q>y6y5%Iw-kgp}KTq8vm}@<8p_M=G$hpYo z;Fx&8okl^22;5%2HIglPZj_bA0~^t|T|n}H)#9NULFhD%-&T3v#<#+#OlElUvX-_v zn1jkrOoqPAp~>bo=MAo{^yS>DX+|>~bScpo$nM~$CdNSDoPIvd<925VI^oUS5iD0~ z=*XvS^+0ehm4UQ))^NhjG>V$URt`s~<%(%6aC5R6{6n5046H~GYEnojIVCKDj3O#- zN-hyAtq@ld0`UpWGJAse!_pVwVcvMIYsV*PzyvCQgO8d7bmd_)vY!FHdFQCZiq$o} zZ6^HK&INL`zJ&(_a>%oiM=o2(9j=BJ{qpd$_IePu`*}C%9=0bfe1Ly~&>21?Vt+?) 
zkp;TG1B3{vw|9*C4sxw1$%dYCuT!Ls!a@RKCPm&CI^l2A=&WDoRdH#~laH<8BCVou z;8Wh^WVM)N%k@H1^|lQ{+lk@M+$)Qs6oWpCt4j1c1~!&;kz2IlF4vtUj<2J#>Lb{t z;h^gVv-2i5(Hxd!>UF8yw$tJ>uW6jBwpXo}cm1l+zjNYaWrqS^{_1F?hde143ns=+ZqE%H zd@MOkt)EJJldup=vQ>R?Mp*uG`#uwSmL%C?M+;)0*%@SJZeKMBgR$(!Jpp3mVfeZw z= zR7CKYWLTC>Ly(bQC8gUh33{KzW^$<8dUD7J_yM(^v@DQj3@c4$0aq!}-UDYTm3C&S zT;vn-nzS&7%c5O-cVu#FH#u~^lAE~mS%_l9El57)1QK{dp^C3vQ_w5tlz(-TXyg$S zN(*AIn3z}Iuq4(KZ!+g_KQI~gE!r7*fjhotw*Ym^@c5}%ZAmY8a(*r9 zE&!fxltfUXwtU>9J~q5@PXbMP*pR8w!4E_=xioS|&Qr&LHwt~LQ9IPvW@R9;S>t1i zY&7k&euZZgo^)9l_$NOuRN!vFki<~tD!bIDZWo@=2h+KqcfO(TLPp!5X7|)?t|F`I zG2dw3Gx)~Xe?rx~H$1wU_|5^G7;~J+Mz44$+HGFgBfgsbx=k1|8Z`Fjejzflo z)d*`x38@oU5svo&LSpFj%yS!Tt^h*IB_tCrD2|jtt=1aR z@y&N+j%fC@5po!~Sxcz)w5f7lgJHycvAJxonk&~B0s8s;CEpU8^KLh zx*J#hpVovCZ>`Z`x3yqxpaX83!OJRYdnS(B62)=eQeNP0tYrhTh+$6~MaS*4`EI8> zoFirAFOJ9^G3{yN>wruErIFfxy{d^C?QYy+vp8vC^#;f^yN5;i;tS(g_S^MX5@-Fj z^Pu%KtD2Z^obP~AhM)k&_zNww{V$Od2IX37$gqLRSg6nYGW~BUx9L|0wL)Sx-q|6_1OaDzhaIpKLuLm<|ua({cYF>L)Yl?scrL zb!y&^BL$AL1mB~pm}W_mtogt+hb#`U{2;CZT7oqg0r>NJS&(?ohjlY+9W&2&=R6vo zIo?hSGseN+C7g28QBoe8zg7`Z`l-5kV1HNEC)gV5x{&p1_TA{5ey-tezAHi++o%9l z$tek1_;ye_w$nY60r(zDt0^=J+O~)}ehjlvvT)czz`PVAl~H2cNcrPUQpQ#N<80=4 zO_hosjw$sK_=O2b$t4d}Gq4AzzJk4gwBUiIfxpF4rygZ3IurA#nz0H6sN(!FF$KO;UT zQ)g#OJM*6bU4!}?`z3aCU&88VgT0lMWX>+?UaAC6JDqu3Me-#UF@iTb?kJMOr5A^8 zl9B4wg+gBpeVJV7*IEeAH@DonB)ld=skjtshgfj}3QLU~ITT(Irfdl$#Mg+b2h+e@ zYy)23y5HX0W9iy;YNNA|2FJa|a?YNfW%o$O)ea&u%g2oyT&AWA?^9@#f%u|e%TgrG z+FP6O)(mNX$k{XdlF=%=UVE3XSCg?cS?SH8s$;_nGwh@O^rV0>hXp;lErjGDT@fzl zU2b8CmN7~CjT$R_O$^cfb>KYQtDR3aciAR{cE!#v(>FNRnZgAa5XRMlfng31hUd_U zV|pM2Dtuk4#a*rkxr%%lcT?*v8&#KUA6-9grYEQGDVo{%W9hN(7O|;EK-6r9&$+A= zT8{<~uSQ_0jmWE*AV?2MPI@;aN+%{5d$}&qw)C!gt4l|JY6X$2S-#lrVU;2v5QR~D z%ZS;lEM?7G(~JEU6!RdeR;;YS%G3o4+!DHm0^1U^ehpQJ+EgD#-6v@t#GJjlcb!o(|8AW`BR2arXq1F|J#R%nwmk1% z!yH@Mdon7R&rT7-*isA^pM~i&HMcVvytwj|%99>B=WeWEZ<~P~(S99=T=Xd1 zIQ%eTjgK{!=G~oz92?WFmRW{a`dKoQqDqn&#uryT;k*&l2Dw#iR2ma=O`g7|C+hH= z@c*A9!u{C 
z3>!M)99t=6OP#Zaq_y?&Od@qMMfPndSlmP(rd3EuQewqB1g1bkvhg~Y@r|YTXYf+) zJ4525Utspe_QcsDaq)0SJg~PZ!5Z8U?>Bzvn8yZg>cFxw=C@X{Zju{#t19L6$vRUa z#h}=Ro?cQ<@D6Ho2A8}1i-(h%H=^ajcA?gJwOP?2DIHRlK zwRr~66B~&C*0N+V+RNido-ZBX6aB`_BmQ(K?!`E2H*xNASiyco%s%jf3w zvTH&-p0pwCyf?-3_7TfZ;EDC-aNmuA&}U}=4)6J1A?NYOVX7~Y>ST8&mcGF4)fA#9}R?Fyl4w(;TkBH3OeN2XnsVBO)CK#~WFvhgl4F z7hxt636q1%<4X|IGvfy%UX!;$w1cu0_tA4ZtBDFO6`VGH0H{zzYfP%9 z8r`t>6CZ#+R0s1}*`LQ5ny^_j3vHTf&Qoi@`B5`)?ABm<;;>nf*7iZX4Y6rCncyD2g5)LE9E`Oz+J}E6a1gmn`Mc2u|O< z@JSm4E$j*3K}F}zxl%jK>4s-BCBIL(+tRK$XSWMRNAV+cNEfu6_6D}R@#%XxB6;B>6=b<} zL}Rj+p~H+*)9#S(B=PCx*3@Y;#H6Sf{t4w6cIo(Oy^}gBJ8jIz9q8{ws?&GpGJ8zN zo{(Hr&Zfukb}-%d&yiyX+{tXsv`Hi3b{j0%Y+&p6i_fGO3srp4K-6fIKIC!Yqk)(| zdxASFGdO6dRbk9@5IJ_HwThT5DtrMy#v9+$43}3n`3R>{Uh@4bSr+7ow!_R}{J$2Z81Sd?mE*9Yc5qtjGb88M6irC%>}cAydD;a*S*mZF8kQ5zaumdpqjm zcXadw`rG-d<rU+=f?_S7ek#gpc+`v(HLi4}vU_f)1Kc;T_%AbC z8(Q$LJy|yHJ)?iW_hWB`(k+6!u27*4*$3d(KO+4eml`0^4)g#Il3k$w6_)~QFD|yG zcFr#^jt^>5c1w)tD_DC3;K#O7+3V2pL9_CKwA$bfwe%K|3D7cRLds@YXwFNz_FvxY z;6F~AG|30S)5z!8-;fDRFtn$JG!;m*bQz-~f^93wi*~zX7TkX>>%#X~G{{~TcEBy* zo7pouJL_qAhbmN>;J5_>J{O-%Cf>rteC{LXsBOeYi=!kQH4Ps@=@VZRt+)o8d{aOK ze~_&js8|tTA90IlJ>8vLY$2?Ss8gay&hAO6x|^zLOTz_ljzCzd4ou@SfT}|j@BOIE zYXyPeDP$)z4Q-WS7smbe4BWoXX~rzq>=J`cFle3I{M=xx!oq;*gzPArrG_Soetg)$ zv3u@~?+{KodZPn6R3}<9MM6Al#&G^n5YJjq$y3w~YG(%u_cAKWDx6owmjcz0E+;8l z`%pB;+{|nXKlY}AJ6}R2X?Nul5286~)!J0jh5Jdno5lf0hI@?Rz=a65M{57XrjGHN z!Z!Sp)MWO29m60m9qXe>X+E-6Ix>kW2K@VK=lS5)1W^tgELeNkQ?s53^PZA#ub{OB zj~2sCPPAdB5+#j7>=UrxX{2_|ST&Z5Y*#p=jWA(1c=icyO?2U@FKFG}_r@KpKDXMK-ig1B59#LAw zDE(o`<`IB``SkvDG}(NR?9%+=tW_vs0ju0#78>RvQEKd?W{OB-kB1@>M|WohCoL8$ zpY|A)Jf-T%Evxfs335Ne#JITKd$pohkO%}QAv|()Odtx0U#9y`hXXi3PCx6??CPYo z1;+)xy}2{k<8B&7juSw%wn#8;u=BTamNYlV|1_o+B!8_>{{;B`iRbs486^Y*0M>y) z73$yNl9|1ov#Ozy&CkMBTgsZl6MywBAn zjQN#;#^-d`7Pq}u1_#+@$+J!KHDhZDL~)C+LAxS+OrFoPPHrB~;U?PfN?8;0x3*Na zy;Myc2YZ{h{95G|P}io&QA84Cgm|7Zt)9;}ugWg8yvW2#_}()(XP6WCDW{0>Ks&SN zW$vkHDS{wLD8E(i)@MWHRLUvtlX~qTB+YKXMe|K4TyaFdP$|qLz5r!p4=+uQ3*$o; 
zA!WLVIMrn^hG&|st7K!EYLDFX$StSg4k;6s&0%?W6Yc}fY@be)h`HsPHVQ{0x2)dBYJG# z@z{2+1O^HQ4GuUxuLO^;FBOclq|)_{>WPaPWMyz;n(qTWVhd%IQbt?p0>G|D5eG`$ zaKh&Z#uUN3kVh2hF3)O^!=UryL|nCV=ocv5h<4S1UIiZN{qiiw^gZRiQ=Mb-0=n%hDXBP zFm^9769K(UNz`ho=n865yh28`abimsbl7W-yql8$E4eJzsw)f{;~al#^Rv_a$?;sz zM3n&+8>^>ak0N7((n*@alt`9_%`5V5I|19v44JI*(kAQ=7+-6Sl!mFdC=cn0A+sXnK)wdZnh!YsH*vZ0lhNp+ZJ!1)m)O6PdFQ>rL zdt*EqypbgZ@WlqjPcrnO3f*hhB41svSv;=ZrS)UO ztBS$E21OlVdxD*e9j>5J8sTXl8ViE4pngAC$~eOF9665&Wr{~l6@-M`y)0*I0ilo) zlf^Ea>f_bIO!LYrzpv*zkkjHn@IC;!N-JJ$f;?~c0nCg_B|Klp=yGs-M;u#scVdw^ zV}Wt}=qY%=xJ-)2nF31;;e)g3D+isjkqMwxX0Ug8XIS?djJU4+{nM_)M;(!m>SF5) zBTNKXLQ(v^hjsU*ejvZQc<&)8Un5|J935C&eW{SY5dpe)U_7I*>fvDOr2Df%o-|_D z&4?^^oqUHB-=t==mtRT^8pWSN2X_Zj$pTX5JH(=Rb1x5S2!bxH#1W6WbdA5J`6@bP z+zd|AB9K9>V=n@}6*2j2O1V`*@7yb%m`a}tk2lE(f>ONau&m1j1B6BIs0N)S|AvVC z0EaU~EPUOg5OIiVbB;AxXSn9Mgf~PO2%+Vl`G|KsN(6cE8@6pFWxNjIZ+RiV(YcgkG{Sf))gy zJaN8B-lqF7$e7}sIGFh;frg@B-41LiuOk^5bF~Zq@GNoH?P+C!+0{ZXm$Jh*7B}|t zrT6;X9m1vQ?Q^=ii!O=%Y}SI?TSqoEa<0QS#~<-a+pa#|GRjYfJ&K(05PXlYDNv9m zLayP}dm#LQ|7m*)w!Ze15Y+Zqiua5Ftk&lTgI(Uc{kYn;sn9h=84yr(z(1Wke;Ng7 z7XS!=2zZqx`L8DWwZMNbn!k9OU*n|-sG0C@pYrFXe|nx5%&$$C|AYDYKl$G_dud1d zHC{G*8UJcd`d!FB+ml|1f9*%{Z`fbWN`G?uXG_gL=$e;Co4-4v{%)}O6aHso$Uksu zV7&7W{6B3X|C`V+Z3(}|i_k~FqwBw#6kY`W??!oX?Z3tg1eEwU&-5p=e>$-j;4e+l z{0{toT-u*2{h7l2vyvOmUo8B4jX&Xk#)JRBErHJUr_jGehkxS#JlX#TzlQi1{@>^P zf5QLl^Y{mj3^c*7GQoekK7N7!FWN@|`n{jvzfTN*e@_2AGyG>)?%#p`oFFR7K>YmI RAp`sFCV|h>8tDt@e*nM*$Y1~f literal 0 HcmV?d00001 diff --git a/llm/nlp_cognitive_stream.py b/llm/nlp_cognitive_stream.py index 7e2de15..d49604f 100644 --- a/llm/nlp_cognitive_stream.py +++ b/llm/nlp_cognitive_stream.py @@ -11,6 +11,28 @@ from pydantic import create_model from langchain.tools import StructuredTool from langgraph.prebuilt import create_react_agent +# 新增:本地知识库相关导入 +import re +from pathlib import Path +import docx +from docx.document import Document +from docx.oxml.table import CT_Tbl +from docx.oxml.text.paragraph import CT_P +from docx.table 
import _Cell, Table +from docx.text.paragraph import Paragraph +try: + from pptx import Presentation + PPTX_AVAILABLE = True +except ImportError: + PPTX_AVAILABLE = False + +# 用于处理 .doc 文件的库 +try: + import win32com.client + WIN32COM_AVAILABLE = True +except ImportError: + WIN32COM_AVAILABLE = False + from utils import util import utils.config_util as cfg from genagents.genagents import GenerativeAgent @@ -21,7 +43,7 @@ from core import stream_manager os.environ["LANGCHAIN_TRACING_V2"] = "true" os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com" os.environ["LANGCHAIN_API_KEY"] = "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" -os.environ["LANGCHAIN_PROJECT"] = "fay3.0.0_github" +os.environ["LANGCHAIN_PROJECT"] = "fay3.1.2_github" # 加载配置 cfg.load_config() @@ -83,6 +105,395 @@ def get_current_time_step(username=None): util.log(1, f"获取time_step时出错: {str(e)},使用0代替") return 0 +# 新增:本地知识库相关函数 +def read_doc_file(file_path): + """ + 读取doc文件内容 + + 参数: + file_path: doc文件路径 + + 返回: + str: 文档内容 + """ + try: + # 方法1: 使用 win32com.client(Windows系统,推荐用于.doc文件) + if WIN32COM_AVAILABLE: + word = None + doc = None + try: + import pythoncom + pythoncom.CoInitialize() # 初始化COM组件 + + word = win32com.client.Dispatch("Word.Application") + word.Visible = False + doc = word.Documents.Open(file_path) + content = doc.Content.Text + + # 先保存内容,再尝试关闭 + if content and content.strip(): + try: + doc.Close() + word.Quit() + except Exception as close_e: + util.log(1, f"关闭Word应用程序时出错: {str(close_e)},但内容已成功提取") + + try: + pythoncom.CoUninitialize() # 清理COM组件 + except: + pass + + return content.strip() + + except Exception as e: + util.log(1, f"使用 win32com 读取 .doc 文件失败: {str(e)}") + finally: + # 确保资源被释放 + try: + if doc: + doc.Close() + except: + pass + try: + if word: + word.Quit() + except: + pass + try: + pythoncom.CoUninitialize() + except: + pass + + # 方法2: 简单的二进制文本提取(备选方案) + try: + with open(file_path, 'rb') as f: + raw_data = f.read() + # 尝试提取可打印的文本 + text_parts = [] + 
current_text = "" + + for byte in raw_data: + char = chr(byte) if 32 <= byte <= 126 or byte in [9, 10, 13] else None + if char: + current_text += char + else: + if len(current_text) > 3: # 只保留长度大于3的文本片段 + text_parts.append(current_text.strip()) + current_text = "" + + if len(current_text) > 3: + text_parts.append(current_text.strip()) + + # 过滤和清理文本 + filtered_parts = [] + for part in text_parts: + # 移除过多的重复字符和无意义的片段 + if (len(part) > 5 and + not part.startswith('Microsoft') and + not all(c in '0123456789-_.' for c in part) and + len(set(part)) > 3): # 字符种类要多样 + filtered_parts.append(part) + + if filtered_parts: + return '\n'.join(filtered_parts) + + except Exception as e: + util.log(1, f"使用二进制方法读取 .doc 文件失败: {str(e)}") + + util.log(1, f"无法读取 .doc 文件 {file_path},建议转换为 .docx 格式") + return "" + + except Exception as e: + util.log(1, f"读取doc文件 {file_path} 时出错: {str(e)}") + return "" + +def read_docx_file(file_path): + """ + 读取docx文件内容 + + 参数: + file_path: docx文件路径 + + 返回: + str: 文档内容 + """ + try: + doc = docx.Document(file_path) + content = [] + + for element in doc.element.body: + if isinstance(element, CT_P): + paragraph = Paragraph(element, doc) + if paragraph.text.strip(): + content.append(paragraph.text.strip()) + elif isinstance(element, CT_Tbl): + table = Table(element, doc) + for row in table.rows: + row_text = [] + for cell in row.cells: + if cell.text.strip(): + row_text.append(cell.text.strip()) + if row_text: + content.append(" | ".join(row_text)) + + return "\n".join(content) + except Exception as e: + util.log(1, f"读取docx文件 {file_path} 时出错: {str(e)}") + return "" + +def read_pptx_file(file_path): + """ + 读取pptx文件内容 + + 参数: + file_path: pptx文件路径 + + 返回: + str: 演示文稿内容 + """ + if not PPTX_AVAILABLE: + util.log(1, "python-pptx 库未安装,无法读取 PowerPoint 文件") + return "" + + try: + prs = Presentation(file_path) + content = [] + + for i, slide in enumerate(prs.slides): + slide_content = [f"第{i+1}页:"] + + for shape in slide.shapes: + if hasattr(shape, "text") and 
shape.text.strip(): + slide_content.append(shape.text.strip()) + + if len(slide_content) > 1: # 有内容才添加 + content.append("\n".join(slide_content)) + + return "\n\n".join(content) + except Exception as e: + util.log(1, f"读取pptx文件 {file_path} 时出错: {str(e)}") + return "" + +def load_local_knowledge_base(): + """ + 加载本地知识库内容 + + 返回: + dict: 文件名到内容的映射 + """ + knowledge_base = {} + + # 获取llm/data目录路径 + current_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.join(current_dir, "data") + + if not os.path.exists(data_dir): + util.log(1, f"知识库目录不存在: {data_dir}") + return knowledge_base + + # 遍历data目录中的文件 + for file_path in Path(data_dir).iterdir(): + if not file_path.is_file(): + continue + + file_name = file_path.name + file_extension = file_path.suffix.lower() + + try: + if file_extension == '.docx': + content = read_docx_file(str(file_path)) + elif file_extension == '.doc': + content = read_doc_file(str(file_path)) + elif file_extension == '.pptx': + content = read_pptx_file(str(file_path)) + else: + # 尝试作为文本文件读取 + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + except UnicodeDecodeError: + try: + with open(file_path, 'r', encoding='gbk') as f: + content = f.read() + except UnicodeDecodeError: + util.log(1, f"无法解码文件: {file_name}") + continue + + if content.strip(): + knowledge_base[file_name] = content + util.log(1, f"成功加载知识库文件: {file_name} ({len(content)} 字符)") + + except Exception as e: + util.log(1, f"加载知识库文件 {file_name} 时出错: {str(e)}") + + return knowledge_base + +def search_knowledge_base(query, knowledge_base, max_results=3): + """ + 在知识库中搜索相关内容 + + 参数: + query: 查询内容 + knowledge_base: 知识库字典 + max_results: 最大返回结果数 + + 返回: + list: 相关内容列表 + """ + if not knowledge_base: + return [] + + results = [] + query_lower = query.lower() + + # 搜索关键词 + query_keywords = re.findall(r'\w+', query_lower) + + for file_name, content in knowledge_base.items(): + content_lower = content.lower() + + # 计算匹配度 + score = 0 + matched_sentences 
= [] + + # 按句子分割内容 + sentences = re.split(r'[。!?\n]', content) + + for sentence in sentences: + if not sentence.strip(): + continue + + sentence_lower = sentence.lower() + sentence_score = 0 + + # 计算关键词匹配度 + for keyword in query_keywords: + if keyword in sentence_lower: + sentence_score += 1 + + # 如果句子有匹配,记录 + if sentence_score > 0: + matched_sentences.append((sentence.strip(), sentence_score)) + score += sentence_score + + # 如果有匹配的内容 + if score > 0: + # 按匹配度排序句子 + matched_sentences.sort(key=lambda x: x[1], reverse=True) + + # 取前几个最相关的句子 + relevant_sentences = [sent[0] for sent in matched_sentences[:5] if sent[0]] + + if relevant_sentences: + results.append({ + 'file_name': file_name, + 'score': score, + 'content': '\n'.join(relevant_sentences) + }) + + # 按匹配度排序 + results.sort(key=lambda x: x['score'], reverse=True) + + return results[:max_results] + +# 全局知识库缓存 +_knowledge_base_cache = None +_knowledge_base_load_time = None +_knowledge_base_file_times = {} # 存储文件的最后修改时间 + +def check_knowledge_base_changes(): + """ + 检查知识库文件是否有变化 + + 返回: + bool: 如果有文件变化返回True,否则返回False + """ + global _knowledge_base_file_times + + # 获取llm/data目录路径 + current_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.join(current_dir, "data") + + if not os.path.exists(data_dir): + return False + + current_file_times = {} + + # 遍历data目录中的文件 + for file_path in Path(data_dir).iterdir(): + if not file_path.is_file(): + continue + + file_name = file_path.name + file_extension = file_path.suffix.lower() + + # 只检查支持的文件格式 + if file_extension in ['.docx', '.doc', '.pptx', '.txt'] or file_extension == '': + try: + mtime = os.path.getmtime(str(file_path)) + current_file_times[file_name] = mtime + except OSError: + continue + + # 检查是否有变化 + if not _knowledge_base_file_times: + # 第一次检查,保存文件时间 + _knowledge_base_file_times = current_file_times + return True + + # 比较文件时间 + if set(current_file_times.keys()) != set(_knowledge_base_file_times.keys()): + # 文件数量发生变化 + _knowledge_base_file_times 
= current_file_times + return True + + for file_name, mtime in current_file_times.items(): + if file_name not in _knowledge_base_file_times or _knowledge_base_file_times[file_name] != mtime: + # 文件被修改 + _knowledge_base_file_times = current_file_times + return True + + return False + +def init_knowledge_base(): + """ + 初始化知识库,在系统启动时调用 + """ + global _knowledge_base_cache, _knowledge_base_load_time + + util.log(1, "初始化本地知识库...") + _knowledge_base_cache = load_local_knowledge_base() + _knowledge_base_load_time = time.time() + + # 初始化文件修改时间跟踪 + check_knowledge_base_changes() + + util.log(1, f"知识库初始化完成,共 {len(_knowledge_base_cache)} 个文件") + +def get_knowledge_base(): + """ + 获取知识库,使用缓存机制 + + 返回: + dict: 知识库内容 + """ + global _knowledge_base_cache, _knowledge_base_load_time + + # 如果缓存为空,先初始化 + if _knowledge_base_cache is None: + init_knowledge_base() + return _knowledge_base_cache + + # 检查文件是否有变化 + if check_knowledge_base_changes(): + util.log(1, "检测到知识库文件变化,正在重新加载...") + _knowledge_base_cache = load_local_knowledge_base() + _knowledge_base_load_time = time.time() + util.log(1, f"知识库重新加载完成,共 {len(_knowledge_base_cache)} 个文件") + + return _knowledge_base_cache + + # 定时保存记忆的线程 def memory_scheduler_thread(): """ @@ -361,6 +772,21 @@ def question(content, username, observation=None): except Exception as e: util.log(1, f"获取相关记忆时出错: {str(e)}") + # 新增:搜索本地知识库 + knowledge_context = "" + try: + knowledge_base = get_knowledge_base() + if knowledge_base: + knowledge_results = search_knowledge_base(content, knowledge_base, max_results=3) + if knowledge_results: + knowledge_context = "**本地知识库相关信息**:\n" + for result in knowledge_results: + knowledge_context += f"来源文件:{result['file_name']}\n" + knowledge_context += f"{result['content']}\n\n" + util.log(1, f"找到 {len(knowledge_results)} 条相关知识库信息") + except Exception as e: + util.log(1, f"搜索知识库时出错: {str(e)}") + # 使用文件开头定义的llm对象进行流式请求 observation = "**还观察的情况**:" + observation + "\n" if observation else "" @@ -437,12 +863,60 @@ def 
question(content, username, observation=None): if react_response_text and react_response_text.strip(): if is_agent_think_start: react_response_text = "" + react_response_text - stream_manager.new_instance().write_sentence(username, react_response_text) + # 对React Agent的最终回复也进行分句处理 + accumulated_text += react_response_text + # 使用安全的流式文本处理器和状态管理器 + from utils.stream_text_processor import get_processor + from utils.stream_state_manager import get_state_manager + + processor = get_processor() + state_manager = get_state_manager() + + # 确保有活跃会话 + if not state_manager.is_session_active(username): + state_manager.start_new_session(username, "react_agent") + + # 如果累积文本达到一定长度,进行处理 + if len(accumulated_text) >= 20: # 设置一个合理的阈值 + # 找到最后一个标点符号的位置 + last_punct_pos = -1 + for punct in processor.punctuation_marks: + pos = accumulated_text.rfind(punct) + if pos > last_punct_pos: + last_punct_pos = pos + + if last_punct_pos > 10: # 确保有足够的内容发送 + sentence_text = accumulated_text[:last_punct_pos + 1] + # 使用状态管理器准备句子 + marked_text, _, _ = state_manager.prepare_sentence(username, sentence_text) + stream_manager.new_instance().write_sentence(username, marked_text) + accumulated_text = accumulated_text[last_punct_pos + 1:].lstrip() + except (KeyError, IndexError, AttributeError): react_response_text = f"抱歉,我现在太忙了,休息一会,请稍后再试。" + if is_first_sentence: + react_response_text = "_" + react_response_text + is_first_sentence = False stream_manager.new_instance().write_sentence(username, react_response_text) full_response_text += react_response_text + + # 确保React Agent最后一段文本也被发送,并标记为结束 + from utils.stream_state_manager import get_state_manager + state_manager = get_state_manager() + + if accumulated_text: + # 使用状态管理器准备最后的文本,强制标记为结束 + marked_text, _, _ = state_manager.prepare_sentence(username, accumulated_text, force_end=True) + stream_manager.new_instance().write_sentence(username, marked_text) + else: + # 如果没有剩余文本,检查是否需要发送结束标记 + session_info = state_manager.get_session_info(username) + if 
session_info and not session_info.get('is_end_sent', False): + # 发送一个空的结束标记 + marked_text, _, _ = state_manager.prepare_sentence(username, "", force_end=True) + stream_manager.new_instance().write_sentence(username, marked_text) + else: try: @@ -452,24 +926,50 @@ def question(content, username, observation=None): if not flush_text: continue accumulated_text += flush_text - for mark in punctuation_marks: - if mark in accumulated_text: - last_punct_pos = max(accumulated_text.rfind(p) for p in punctuation_marks if p in accumulated_text) - if last_punct_pos != -1: - to_write = accumulated_text[:last_punct_pos + 1] - accumulated_text = accumulated_text[last_punct_pos + 1:] - if is_first_sentence: - to_write += "_" - is_first_sentence = False - stream_manager.new_instance().write_sentence(username, to_write) - break + # 使用安全的流式处理逻辑和状态管理器 + from utils.stream_text_processor import get_processor + from utils.stream_state_manager import get_state_manager + + processor = get_processor() + state_manager = get_state_manager() + + # 确保有活跃会话 + if not state_manager.is_session_active(username): + state_manager.start_new_session(username, "llm_stream") + + # 如果累积文本达到一定长度,进行处理 + if len(accumulated_text) >= 20: # 设置一个合理的阈值 + # 找到最后一个标点符号的位置 + last_punct_pos = -1 + for punct in processor.punctuation_marks: + pos = accumulated_text.rfind(punct) + if pos > last_punct_pos: + last_punct_pos = pos + + if last_punct_pos > 10: # 确保有足够的内容发送 + sentence_text = accumulated_text[:last_punct_pos + 1] + # 使用状态管理器准备句子 + marked_text, _, _ = state_manager.prepare_sentence(username, sentence_text) + stream_manager.new_instance().write_sentence(username, marked_text) + accumulated_text = accumulated_text[last_punct_pos + 1:].lstrip() + full_response_text += flush_text - # 确保最后一段文本也被发送 + # 确保最后一段文本也被发送,并标记为结束 + from utils.stream_state_manager import get_state_manager + state_manager = get_state_manager() + if accumulated_text: - if is_first_sentence: #相当于整个回复没有标点 - accumulated_text += "_" - 
is_first_sentence = False - stream_manager.new_instance().write_sentence(username, accumulated_text) + # 使用状态管理器准备最后的文本,强制标记为结束 + marked_text, _, _ = state_manager.prepare_sentence(username, accumulated_text, force_end=True) + stream_manager.new_instance().write_sentence(username, marked_text) + else: + # 如果没有剩余文本,检查是否需要发送结束标记 + session_info = state_manager.get_session_info(username) + if session_info and not session_info.get('is_end_sent', False): + # 发送一个空的结束标记 + marked_text, _, _ = state_manager.prepare_sentence(username, "", force_end=True) + stream_manager.new_instance().write_sentence(username, marked_text) + except requests.exceptions.RequestException as e: util.log(1, f"请求失败: {e}") @@ -477,8 +977,10 @@ def question(content, username, observation=None): stream_manager.new_instance().write_sentence(username, "_" + error_message + "_") full_response_text = error_message - # 发送结束标记 - stream_manager.new_instance().write_sentence(username, "_") + # 结束会话(不再需要发送额外的结束标记) + from utils.stream_state_manager import get_state_manager + state_manager = get_state_manager() + state_manager.end_session(username) # 在单独线程中记忆对话内容 MyThread(target=remember_conversation_thread, args=(username, content, full_response_text.split("")[-1])).start() diff --git a/tts/ms_tts_sdk.py b/tts/ms_tts_sdk.py index 37c803c..5965f0e 100644 --- a/tts/ms_tts_sdk.py +++ b/tts/ms_tts_sdk.py @@ -6,10 +6,8 @@ from tts import tts_voice from tts.tts_voice import EnumVoice from utils import util, config_util from utils import config_util as cfg -import pygame import edge_tts from pydub import AudioSegment -from scheduler.thread_manager import MyThread class Speech: def __init__(self): diff --git a/utils/stream_state_manager.py b/utils/stream_state_manager.py new file mode 100644 index 0000000..380d4ac --- /dev/null +++ b/utils/stream_state_manager.py @@ -0,0 +1,249 @@ +import threading +import time +from utils import util +from enum import Enum + +class StreamState(Enum): + """流式状态枚举""" + IDLE = "idle" # 
class StreamState(Enum):
    """Lifecycle states of one streaming reply session."""
    IDLE = "idle"                    # session opened, nothing sent yet
    FIRST_SENTENCE = "first"         # first sentence has been prepared
    MIDDLE_SENTENCE = "middle"       # in the middle of the reply
    LAST_SENTENCE = "last"           # end marker has been prepared
    COMPLETED = "completed"          # session closed


class StreamStateManager:
    """Single authority for the stream start/end ("_") markers.

    Several code paths used to attach the first/last markers themselves,
    which produced inconsistent streams.  All senders now go through
    prepare_sentence(), which tracks per-user session state under one
    re-entrant lock and stamps each marker exactly once.

    Marker protocol (must match the consumers of write_sentence):
    a LEADING "_" marks the first chunk of a stream, a TRAILING "_"
    marks the last one, so a one-sentence reply becomes "_text_".
    """

    def __init__(self):
        self.lock = threading.RLock()   # guards both dicts below
        self.user_states = {}           # username -> session state dict
        self.session_counters = {}      # username -> running session number

    def start_new_session(self, username, session_type="stream"):
        """Open a fresh streaming session for *username*.

        Args:
            username: user the stream belongs to.
            session_type: origin tag (stream, qa, auto_play, ...), only
                used to build a readable session id.

        Returns:
            str: the generated session id.
        """
        with self.lock:
            if username not in self.session_counters:
                self.session_counters[username] = 0

            self.session_counters[username] += 1
            session_id = f"{username}_{session_type}_{self.session_counters[username]}_{int(time.time())}"

            self.user_states[username] = {
                'session_id': session_id,
                'session_type': session_type,
                'state': StreamState.IDLE,
                'sentence_count': 0,
                'start_time': time.time(),
                'last_update': time.time(),
                'is_first_sent': False,
                'is_end_sent': False
            }

            util.log(1, f"开始新会话: {session_id}")
            return session_id

    def prepare_sentence(self, username, text, force_first=False, force_end=False):
        """Stamp *text* with stream markers and update session state.

        Args:
            username: user the stream belongs to.
            text: sentence text to send.
            force_first: force the first-sentence marker.
            force_end: force the end-of-stream marker.

        Returns:
            tuple: (marked_text, is_first, is_end).
        """
        with self.lock:
            if username not in self.user_states:
                # No active session: open one implicitly so callers
                # never have to pre-register.
                self.start_new_session(username)

            state_info = self.user_states[username]
            state_info['last_update'] = time.time()

            # First sentence: either forced, or nothing sent yet.
            is_first = False
            if force_first or (not state_info['is_first_sent'] and state_info['sentence_count'] == 0):
                is_first = True
                state_info['is_first_sent'] = True
                state_info['state'] = StreamState.FIRST_SENTENCE
            elif state_info['sentence_count'] > 0:
                state_info['state'] = StreamState.MIDDLE_SENTENCE

            is_end = force_end
            if is_end:
                state_info['is_end_sent'] = True
                state_info['state'] = StreamState.LAST_SENTENCE

            state_info['sentence_count'] += 1

            # BUGFIX: the first-sentence marker is a LEADING underscore
            # ("_" + text), as in every other sender in this codebase.
            # The previous code APPENDED it, which made the first chunk
            # indistinguishable from an end-of-stream marker, and the
            # endswith() guard meant a single-chunk reply could never
            # carry both markers ("text_" instead of "_text_").
            marked_text = text
            if is_first and not marked_text.startswith('_'):
                marked_text = "_" + marked_text
            if is_end and not marked_text.endswith('_'):
                marked_text += "_"
            return marked_text, is_first, is_end

    def end_session(self, username):
        """Close the current session for *username*.

        Returns:
            str: always "" — the end marker travels on the last
            sentence, never as a separate message.
        """
        with self.lock:
            if username not in self.user_states:
                util.log(1, f"警告: 尝试结束不存在的会话 [{username}]")
                return ""

            state_info = self.user_states[username]

            if state_info['state'] != StreamState.COMPLETED:
                state_info['state'] = StreamState.COMPLETED

            # A session should only be closed after its end marker went
            # out; anything else points to a caller bug.
            if not state_info['is_end_sent']:
                util.log(1, f"警告: 会话结束但未发送过结束标记,可能存在逻辑问题")

            return ""

    def get_session_info(self, username):
        """Return a snapshot (copy) of the user's session dict, or None."""
        with self.lock:
            if username in self.user_states:
                return self.user_states[username].copy()
            return None

    def is_session_active(self, username):
        """Return True if *username* has a not-yet-completed session."""
        with self.lock:
            if username not in self.user_states:
                return False
            return self.user_states[username]['state'] != StreamState.COMPLETED

    def cleanup_expired_sessions(self, timeout_seconds=300):
        """Drop sessions idle for more than *timeout_seconds* seconds."""
        with self.lock:
            current_time = time.time()
            expired_users = [
                username
                for username, state_info in self.user_states.items()
                if current_time - state_info['last_update'] > timeout_seconds
            ]

            for username in expired_users:
                util.log(1, f"清理过期会话: {self.user_states[username]['session_id']}")
                del self.user_states[username]

    def force_reset_user_state(self, username):
        """Drop the user's session unconditionally (error recovery)."""
        with self.lock:
            if username in self.user_states:
                old_session = self.user_states[username]['session_id']
                del self.user_states[username]
                util.log(1, f"强制重置用户状态: {username}, 旧会话: {old_session}")

    def get_all_active_sessions(self):
        """Return {username: session snapshot} for all live sessions."""
        with self.lock:
            return {
                username: state_info.copy()
                for username, state_info in self.user_states.items()
                if state_info['state'] != StreamState.COMPLETED
            }


# Process-wide singleton, created lazily under a dedicated lock.
_state_manager_instance = None
_state_manager_lock = threading.Lock()


def get_state_manager():
    """Return the StreamStateManager singleton (double-checked locking)."""
    global _state_manager_instance
    if _state_manager_instance is None:
        with _state_manager_lock:
            if _state_manager_instance is None:
                _state_manager_instance = StreamStateManager()
    return _state_manager_instance


def start_cleanup_thread():
    """Start the daemon thread that expires stale sessions every minute."""
    def cleanup_worker():
        while True:
            try:
                time.sleep(60)  # sweep once per minute
                get_state_manager().cleanup_expired_sessions()
            except Exception as e:
                util.log(1, f"清理过期会话时出错: {str(e)}")

    cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
    cleanup_thread.start()
    util.log(1, "流式状态管理器清理线程已启动")


# Start sweeping as soon as the module is imported.
start_cleanup_thread()
class StreamTextProcessor:
    """Safe streaming text splitter/sender.

    Breaks LLM output into sentence-sized chunks at punctuation marks and
    pushes them to the per-user stream, with hard limits (iteration count,
    wall-clock timeout, buffer size) guarding against infinite loops and
    runaway memory use.
    """

    def __init__(self, min_length=10, max_iterations=100, timeout_seconds=30, max_cache_size=10240):
        """Configure the processing limits.

        Args:
            min_length: smallest chunk (characters) worth sending.
            max_iterations: hard cap on split/send rounds per call.
            timeout_seconds: wall-clock budget per call.
            max_cache_size: maximum buffered characters before truncation.
        """
        self.min_length = min_length
        self.max_iterations = max_iterations
        self.timeout_seconds = timeout_seconds
        self.max_cache_size = max_cache_size
        # Full-width (CJK) and ASCII sentence delimiters, plus newline.
        self.punctuation_marks = [",", ",", "。", "、", "!", "?", ".", "!", "?", "\n"]

    def process_stream_text(self, text, username, is_qa=False, session_type="stream"):
        """Split *text* into sentences and stream them to *username*.

        Args:
            text: text to deliver; blank input is a no-op success.
            username: target user of the stream.
            is_qa: Q&A mode flag — kept for API compatibility; splitting
                is currently identical in both modes.
            session_type: session tag passed to the state manager.

        Returns:
            bool: True on success; False if the fallback path was used.
        """
        if not text or not text.strip():
            return True

        manager = get_state_manager()
        if not manager.is_session_active(username):
            manager.start_new_session(username, session_type)

        try:
            return self._safe_process_text(text, username, is_qa, manager)
        except Exception as e:
            util.log(1, f"流式文本处理出错: {str(e)}")
            # Degrade gracefully: push the untouched text as one chunk.
            self._send_fallback_text(text, username, manager)
            return False

    def _safe_process_text(self, text, username, is_qa, state_manager):
        """Core split-and-send loop with overflow/timeout/round guards."""
        pending = text
        rounds = 0
        deadline = time.time() + self.timeout_seconds

        # Truncate oversized input up front to bound memory.
        if len(pending) > self.max_cache_size:
            util.log(1, f"文本缓存溢出,长度: {len(pending)}, 限制: {self.max_cache_size}")
            pending = pending[:self.max_cache_size]
            util.log(1, f"文本已截断到: {len(pending)} 字符")

        while pending and rounds < self.max_iterations:
            if time.time() > deadline:
                util.log(1, f"流式处理超时,剩余文本长度: {len(pending)}")
                break
            if len(pending) > self.max_cache_size:
                util.log(1, f"处理过程中缓存溢出,强制发送剩余文本")
                break
            rounds += 1

            split_points = self._find_punctuation_indices(pending)
            if not split_points:
                break  # no delimiter left; flush the remainder below

            pending, delivered = self._emit_one_sentence(pending, split_points, username, state_manager)
            if not delivered:
                break  # nothing sendable this round -> avoid spinning

        if pending:
            # Flush the remainder and stamp it as the end of the stream.
            marked_text, _, _ = state_manager.prepare_sentence(
                username, pending, force_first=False, force_end=True
            )
            stream_manager.new_instance().write_sentence(username, marked_text)
        else:
            # Everything went out sentence-by-sentence; make sure an end
            # marker was actually emitted.
            session_info = state_manager.get_session_info(username)
            if session_info and not session_info.get('is_end_sent', False):
                marked_text, _, _ = state_manager.prepare_sentence(
                    username, "", force_first=False, force_end=True
                )
                stream_manager.new_instance().write_sentence(username, marked_text)

        state_manager.end_session(username)

        if rounds >= self.max_iterations:
            util.log(1, f"流式处理达到最大迭代次数限制: {self.max_iterations}")

        return True

    def _emit_one_sentence(self, pending, split_points, username, state_manager):
        """Try split points in ascending order; send the first viable sentence.

        Returns:
            tuple: (remaining_text, delivered_flag).
        """
        for cut in split_points:
            candidate = pending[:cut + 1]
            if len(candidate) < self.min_length:
                continue  # too short — try a later delimiter
            marked_text, is_first, is_end = state_manager.prepare_sentence(
                username, candidate, force_first=False, force_end=False
            )
            if stream_manager.new_instance().write_sentence(username, marked_text):
                return pending[cut + 1:].lstrip(), True
            util.log(1, f"发送句子失败: {marked_text[:50]}...")
        return pending, False

    def _find_punctuation_indices(self, text):
        """Return the sorted first-occurrence index of each known delimiter."""
        try:
            found = set()
            for punct in self.punctuation_marks:
                try:
                    pos = text.find(punct)
                except Exception as e:
                    util.log(1, f"查找标点符号 '{punct}' 时出错: {str(e)}")
                    continue
                if pos != -1:
                    found.add(pos)
            return sorted(found)
        except Exception as e:
            util.log(1, f"查找标点符号时出错: {str(e)}")
            return []

    def _send_fallback_text(self, text, username, state_manager):
        """Last-resort path: deliver *text* as a single first+last chunk."""
        try:
            marked_text, _, _ = state_manager.prepare_sentence(
                username, text, force_first=True, force_end=True
            )
            stream_manager.new_instance().write_sentence(username, marked_text)
            util.log(1, "使用备用方案发送完整文本")
        except Exception as e:
            util.log(1, f"备用发送方案也失败: {str(e)}")


# Module-level lazy singleton.
_processor_instance = None


def get_processor():
    """Return the process-wide StreamTextProcessor singleton."""
    global _processor_instance
    if _processor_instance is None:
        _processor_instance = StreamTextProcessor()
    return _processor_instance