mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
fay进化
1. 内置RAG知识库(请把docx、pptx、txt文件存放到llm/data目录); 2. 流式回复逻辑优化; 3. 语音交互逻辑优化; 4. 线程安全增强; 5. 数字人驱动接口增加流式输出开始结束标记; 6. 修复因记忆反思而导致的记忆混乱,无法多轮对话问题; 7. 修复mcp工具获取与调用的线程同步问题; 8. 修复funasr依赖版本问题。
This commit is contained in:
6
asr/funasr/requirments.txt
Normal file
6
asr/funasr/requirments.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
torch
|
||||
modelscope
|
||||
testresources
|
||||
torchaudio
|
||||
FunASR
|
||||
websockets~=10.4
|
||||
104
core/fay_core.py
104
core/fay_core.py
@@ -77,20 +77,70 @@ class FeiFei:
|
||||
self.think_mode_users = {} # 使用字典存储每个用户的think模式状态
|
||||
|
||||
def __remove_emojis(self, text):
|
||||
"""
|
||||
改进的表情包过滤,避免误删除正常Unicode字符
|
||||
"""
|
||||
# 更精确的emoji范围,避免误删除正常字符
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 图标符号
|
||||
"\U0001F680-\U0001F6FF" # 交通工具和符号
|
||||
"\U0001F1E0-\U0001F1FF" # 国旗
|
||||
"\U00002700-\U000027BF" # 杂项符号
|
||||
"\U0001F900-\U0001F9FF" # 补充表情符号
|
||||
"\U00002600-\U000026FF" # 杂项符号
|
||||
"\U0001FA70-\U0001FAFF" # 更多表情
|
||||
"\U0001F600-\U0001F64F" # 表情符号 (Emoticons)
|
||||
"\U0001F300-\U0001F5FF" # 杂项符号和象形文字 (Miscellaneous Symbols and Pictographs)
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号 (Transport and Map Symbols)
|
||||
"\U0001F1E0-\U0001F1FF" # 区域指示符号 (Regional Indicator Symbols)
|
||||
"\U0001F900-\U0001F9FF" # 补充符号和象形文字 (Supplemental Symbols and Pictographs)
|
||||
"\U0001FA70-\U0001FAFF" # 扩展A符号和象形文字 (Symbols and Pictographs Extended-A)
|
||||
"\U00002600-\U000026FF" # 杂项符号 (Miscellaneous Symbols)
|
||||
"\U00002700-\U000027BF" # 装饰符号 (Dingbats)
|
||||
"\U0000FE00-\U0000FE0F" # 变体选择器 (Variation Selectors)
|
||||
"\U0001F000-\U0001F02F" # 麻将牌 (Mahjong Tiles)
|
||||
"\U0001F0A0-\U0001F0FF" # 扑克牌 (Playing Cards)
|
||||
"]+",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
return emoji_pattern.sub(r'', text)
|
||||
|
||||
# 保护常用的中文标点符号和特殊字符
|
||||
protected_chars = ["。", ",", "!", "?", ":", ";", "、", """, """, "'", "'", "(", ")", "【", "】", "《", "》"]
|
||||
|
||||
# 先保存保护字符的位置
|
||||
protected_positions = {}
|
||||
for i, char in enumerate(text):
|
||||
if char in protected_chars:
|
||||
protected_positions[i] = char
|
||||
|
||||
# 执行emoji过滤
|
||||
filtered_text = emoji_pattern.sub('', text)
|
||||
|
||||
# 如果过滤后文本长度变化太大,可能误删了正常字符,返回原文本
|
||||
if len(filtered_text) < len(text) * 0.5: # 如果删除了超过50%的内容
|
||||
return text
|
||||
|
||||
return filtered_text
|
||||
|
||||
def __process_qa_stream(self, text, username):
    """
    Stream a Q&A answer to the user.

    The text is handed to the shared stream text processor in Q&A mode.
    When processing succeeds the user's session is ended through the
    stream state manager; when it fails the failure is logged and the
    user's stream state is force-reset instead.

    :param text: answer text; empty or whitespace-only text is ignored
    :param username: user the answer is streamed to
    """
    if not (text and text.strip()):
        return

    # Imported lazily, as in the rest of this module's stream handling.
    from utils.stream_text_processor import get_processor
    from utils.stream_state_manager import get_state_manager

    text_processor = get_processor()
    sessions = get_state_manager()

    # is_qa=True selects Q&A mode in the processor.
    if text_processor.process_stream_text(text, username, is_qa=True, session_type="qa"):
        # Ending the session is enough; no extra end marker is written here.
        sessions.end_session(username)
        return

    util.log(1, f"Q&A流式处理失败,文本长度: {len(text)}")
    # Even on failure the session must not be left open.
    sessions.force_reset_user_state(username)
|
||||
|
||||
#语音消息处理检查是否命中q&a
|
||||
def __get_answer(self, interleaver, text):
|
||||
@@ -139,7 +189,8 @@ class FeiFei:
|
||||
|
||||
else:
|
||||
text = answer
|
||||
stream_manager.new_instance().write_sentence(username, "_<isfirst>" + text + "_<isend>")
|
||||
# 使用流式分割处理Q&A答案
|
||||
self.__process_qa_stream(text, username)
|
||||
|
||||
#完整文本记录回复并输出到各个终端
|
||||
self.__process_text_output(text, username, uid )
|
||||
@@ -259,7 +310,7 @@ class FeiFei:
|
||||
is_end = interact.data.get("isend", False)
|
||||
is_first = interact.data.get("isfirst", False)
|
||||
|
||||
if is_first and (text is None or text.strip() == ""):
|
||||
if not is_first and not is_end and (text is None or text.strip() == ""):
|
||||
return None
|
||||
|
||||
self.__send_panel_message(text, interact.data.get('user'), uid, 0, type)
|
||||
@@ -366,20 +417,24 @@ class FeiFei:
|
||||
except Exception as e:
|
||||
util.printInfo(1, "System", "音频播放初始化失败,本机无法播放音频")
|
||||
return
|
||||
|
||||
|
||||
while self.__running:
|
||||
time.sleep(0.01)
|
||||
if not self.sound_query.empty(): # 如果队列不为空则播放音频
|
||||
file_url, audio_length, interact = self.sound_query.get()
|
||||
is_first = False
|
||||
is_end = False
|
||||
if interact.data.get('isfirst'):
|
||||
is_first = True
|
||||
if interact.data.get('isend'):
|
||||
is_end = True
|
||||
|
||||
is_first = interact.data.get('isfirst') is True
|
||||
is_end = interact.data.get('isend') is True
|
||||
|
||||
|
||||
|
||||
if file_url is not None:
|
||||
util.printInfo(1, interact.data.get('user'), '播放音频...')
|
||||
self.speaking = True
|
||||
|
||||
if is_first:
|
||||
self.speaking = True
|
||||
elif not is_end:
|
||||
self.speaking = True
|
||||
|
||||
#自动播报关闭
|
||||
global auto_play_lock
|
||||
@@ -392,7 +447,7 @@ class FeiFei:
|
||||
|
||||
if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
|
||||
wsa_server.get_web_instance().add_cmd({"panelMsg": "播放中 ...", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'})
|
||||
|
||||
|
||||
if file_url is not None:
|
||||
pygame.mixer.music.load(file_url)
|
||||
pygame.mixer.music.play()
|
||||
@@ -402,10 +457,10 @@ class FeiFei:
|
||||
while length < audio_length:
|
||||
length += 0.01
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
if is_end:
|
||||
self.play_end(interact)
|
||||
util.printInfo(1, interact.data.get('user'), '结束播放!')
|
||||
|
||||
if wsa_server.get_web_instance().is_connected(interact.data.get('user')):
|
||||
wsa_server.get_web_instance().add_cmd({"panelMsg": "", "Username" : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Normal.jpg'})
|
||||
# 播放完毕后通知
|
||||
@@ -467,7 +522,7 @@ class FeiFei:
|
||||
|
||||
#发送音频给数字人接口
|
||||
if file_url is not None and wsa_server.get_instance().is_connected(interact.data.get("user")):
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0}, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
|
||||
#计算lips
|
||||
if platform.system() == "Windows":
|
||||
try:
|
||||
@@ -580,12 +635,13 @@ class FeiFei:
|
||||
:param text: 消息文本
|
||||
:param username: 用户名
|
||||
"""
|
||||
full_text = self.__remove_emojis(text.replace("*", ""))
|
||||
if wsa_server.get_instance().is_connected(username):
|
||||
content = {
|
||||
'Topic': 'human',
|
||||
'Data': {
|
||||
'Key': 'text',
|
||||
'Value': text
|
||||
'Value': full_text
|
||||
},
|
||||
'Username': username
|
||||
}
|
||||
|
||||
@@ -128,7 +128,12 @@ class Recorder:
|
||||
with fay_core.auto_play_lock:
|
||||
fay_core.can_auto_play = False
|
||||
#self.on_speaking(text)
|
||||
intt = interact.Interact("auto_play", 2, {'user': self.username, 'text': "在呢,你说?"})
|
||||
# 使用状态管理器处理唤醒回复
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
state_manager = get_state_manager()
|
||||
state_manager.start_new_session(self.username, "auto_play")
|
||||
|
||||
intt = interact.Interact("auto_play", 2, {'user': self.username, 'text': "在呢,你说?" , "isfirst" : True, "isend" : True})
|
||||
self.__fay.on_interact(intt)
|
||||
self.processing = False
|
||||
self.timer.cancel() # 取消之前的计时器任务
|
||||
|
||||
@@ -45,45 +45,52 @@ class StreamManager:
|
||||
|
||||
def get_Stream(self, username):
|
||||
"""
|
||||
获取指定用户ID的文本流,如果不存在则创建新的
|
||||
获取指定用户ID的文本流,如果不存在则创建新的(线程安全)
|
||||
:param username: 用户名
|
||||
:return: 对应的句子缓存对象
|
||||
"""
|
||||
need_start_thread = False
|
||||
stream = None
|
||||
nlp_stream = None
|
||||
|
||||
with self.lock:
|
||||
if username not in self.streams or username not in self.nlp_streams:
|
||||
self.streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
self.nlp_streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
need_start_thread = True
|
||||
# 注意:这个方法应该在已经获得锁的情况下调用
|
||||
# 如果从外部调用,需要先获得锁
|
||||
|
||||
if username not in self.streams or username not in self.nlp_streams:
|
||||
# 创建新的流缓存
|
||||
self.streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
self.nlp_streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
|
||||
# 启动监听线程(如果还没有)
|
||||
if username not in self.listener_threads:
|
||||
stream = self.streams[username]
|
||||
nlp_stream = self.nlp_streams[username]
|
||||
|
||||
if need_start_thread:
|
||||
thread = MyThread(target=self.listen, args=(username, stream, nlp_stream), daemon=True)
|
||||
self.listener_threads[username] = thread
|
||||
thread.start()
|
||||
else:
|
||||
stream = self.streams[username]
|
||||
nlp_stream = self.nlp_streams[username]
|
||||
|
||||
return stream, nlp_stream
|
||||
thread = MyThread(target=self.listen, args=(username, stream, nlp_stream), daemon=True)
|
||||
self.listener_threads[username] = thread
|
||||
thread.start()
|
||||
|
||||
return self.streams[username], self.nlp_streams[username]
|
||||
|
||||
def write_sentence(self, username, sentence):
|
||||
"""
|
||||
写入句子到指定用户的文本流
|
||||
写入句子到指定用户的文本流(线程安全)
|
||||
:param username: 用户名
|
||||
:param sentence: 要写入的句子
|
||||
:return: 写入是否成功
|
||||
"""
|
||||
# 检查句子长度,防止过大的句子导致内存问题
|
||||
if len(sentence) > 10240: # 10KB限制
|
||||
sentence = sentence[:10240]
|
||||
|
||||
if sentence.endswith('_<isfirst>'):
|
||||
self.clear_Stream(username)
|
||||
Stream, nlp_Stream = self.get_Stream(username)
|
||||
success = Stream.write(sentence)
|
||||
nlp_success = nlp_Stream.write(sentence)
|
||||
return success and nlp_success
|
||||
|
||||
# 使用锁保护获取和写入操作
|
||||
with self.lock:
|
||||
try:
|
||||
Stream, nlp_Stream = self.get_Stream(username)
|
||||
success = Stream.write(sentence)
|
||||
nlp_success = nlp_Stream.write(sentence)
|
||||
return success and nlp_success
|
||||
except Exception as e:
|
||||
print(f"写入句子时出错: {e}")
|
||||
return False
|
||||
|
||||
def clear_Stream(self, username):
|
||||
"""
|
||||
@@ -111,12 +118,11 @@ class StreamManager:
|
||||
:param sentence: 要处理的句子
|
||||
"""
|
||||
fay_core = fay_booter.feiFei
|
||||
# 处理普通消息,区分是否是会话的第一句
|
||||
|
||||
is_first = "_<isfirst>" in sentence
|
||||
is_end = "_<isend>" in sentence
|
||||
sentence = sentence.replace("_<isfirst>", "").replace("_<isend>", "")
|
||||
|
||||
|
||||
if sentence or is_first or is_end :
|
||||
interact = Interact("stream", 1, {"user": username, "msg": sentence, "isfirst" : is_first, "isend" : is_end})
|
||||
fay_core.say(interact, sentence) # 调用核心处理模块进行响应
|
||||
|
||||
@@ -345,6 +345,11 @@ def start():
|
||||
from llm.nlp_cognitive_stream import init_memory_scheduler
|
||||
init_memory_scheduler()
|
||||
|
||||
#初始化知识库
|
||||
util.log(1, '初始化本地知识库...')
|
||||
from llm.nlp_cognitive_stream import init_knowledge_base
|
||||
init_knowledge_base()
|
||||
|
||||
#开启录音服务
|
||||
record = config_util.config['source']['record']
|
||||
if record['enabled']:
|
||||
|
||||
BIN
llm/data/测试.docx
Normal file
BIN
llm/data/测试.docx
Normal file
Binary file not shown.
@@ -11,6 +11,28 @@ from pydantic import create_model
|
||||
from langchain.tools import StructuredTool
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
# 新增:本地知识库相关导入
|
||||
import re
|
||||
from pathlib import Path
|
||||
import docx
|
||||
from docx.document import Document
|
||||
from docx.oxml.table import CT_Tbl
|
||||
from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import _Cell, Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
try:
|
||||
from pptx import Presentation
|
||||
PPTX_AVAILABLE = True
|
||||
except ImportError:
|
||||
PPTX_AVAILABLE = False
|
||||
|
||||
# 用于处理 .doc 文件的库
|
||||
try:
|
||||
import win32com.client
|
||||
WIN32COM_AVAILABLE = True
|
||||
except ImportError:
|
||||
WIN32COM_AVAILABLE = False
|
||||
|
||||
from utils import util
|
||||
import utils.config_util as cfg
|
||||
from genagents.genagents import GenerativeAgent
|
||||
@@ -21,7 +43,7 @@ from core import stream_manager
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||
os.environ["LANGCHAIN_API_KEY"] = "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"
|
||||
os.environ["LANGCHAIN_PROJECT"] = "fay3.0.0_github"
|
||||
os.environ["LANGCHAIN_PROJECT"] = "fay3.1.2_github"
|
||||
|
||||
# 加载配置
|
||||
cfg.load_config()
|
||||
@@ -83,6 +105,395 @@ def get_current_time_step(username=None):
|
||||
util.log(1, f"获取time_step时出错: {str(e)},使用0代替")
|
||||
return 0
|
||||
|
||||
# 新增:本地知识库相关函数
|
||||
# Runs of at least four printable ASCII bytes (plus tab, LF, CR).
_DOC_PRINTABLE_RUN = re.compile(rb"[\t\n\r\x20-\x7e]{4,}")


def _read_doc_via_word(file_path):
    """Extract the text of a .doc file by automating Microsoft Word over COM.

    Returns the stripped document text, or "" when Word reports no text.
    Errors raised while opening or reading propagate to the caller, after
    the document, the Word application and the COM library have each been
    released exactly once.
    """
    import pythoncom

    # Initialise before the try block so that CoUninitialize below only runs
    # when CoInitialize actually succeeded.
    pythoncom.CoInitialize()
    word = None
    doc = None
    try:
        word = win32com.client.Dispatch("Word.Application")
        word.Visible = False
        doc = word.Documents.Open(file_path)
        content = doc.Content.Text
        return content.strip() if content else ""
    finally:
        # Cleanup is best-effort: a failure to close must not hide the text
        # that was already extracted, so it is logged rather than raised.
        if doc is not None:
            try:
                doc.Close()
            except Exception as close_e:
                util.log(1, f"关闭Word文档时出错: {str(close_e)}")
        if word is not None:
            try:
                word.Quit()
            except Exception as close_e:
                util.log(1, f"关闭Word应用程序时出错: {str(close_e)}")
        pythoncom.CoUninitialize()


def _extract_printable_text(file_path):
    """Fallback extraction: collect printable ASCII runs from the raw bytes.

    Keeps runs of four or more printable bytes, strips them, and drops
    fragments that are short, start with 'Microsoft', consist only of
    digits/'-'/'_'/'.', or use three or fewer distinct characters.
    Returns the surviving fragments joined by newlines ("" when none).
    NOTE(review): only ASCII text can be recovered this way; non-ASCII
    (e.g. Chinese) document text is not extracted by this fallback.
    """
    with open(file_path, 'rb') as f:
        raw_data = f.read()

    fragments = (
        run.decode('ascii').strip()
        for run in _DOC_PRINTABLE_RUN.findall(raw_data)
    )
    kept = [
        part for part in fragments
        if len(part) > 5
        and not part.startswith('Microsoft')
        and not all(c in '0123456789-_.' for c in part)
        and len(set(part)) > 3
    ]
    return '\n'.join(kept)


def read_doc_file(file_path):
    """
    Read the text content of a legacy .doc file.

    Tries Microsoft Word automation first (only when win32com is
    available), then falls back to scanning the raw bytes for printable
    text. Every failure is logged and swallowed.

    Args:
        file_path: path of the .doc file

    Returns:
        str: the extracted text, or "" when nothing could be read
    """
    try:
        # Method 1: Word over COM (Windows only).
        if WIN32COM_AVAILABLE:
            try:
                content = _read_doc_via_word(file_path)
                if content:
                    return content
            except Exception as e:
                util.log(1, f"使用 win32com 读取 .doc 文件失败: {str(e)}")

        # Method 2: crude printable-text scan of the binary file.
        try:
            content = _extract_printable_text(file_path)
            if content:
                return content
        except Exception as e:
            util.log(1, f"使用二进制方法读取 .doc 文件失败: {str(e)}")

        util.log(1, f"无法读取 .doc 文件 {file_path},建议转换为 .docx 格式")
        return ""

    except Exception as e:
        util.log(1, f"读取doc文件 {file_path} 时出错: {str(e)}")
        return ""
|
||||
|
||||
def read_docx_file(file_path):
    """
    Read the text content of a .docx file.

    Walks the document body in order. Each non-empty paragraph becomes one
    output line; each table row becomes one line with its non-empty cell
    texts joined by " | ". Any error is logged and yields "".

    Args:
        file_path: path of the .docx file

    Returns:
        str: the collected lines joined by newlines, or "" on failure
    """
    try:
        document = docx.Document(file_path)
        lines = []

        for node in document.element.body:
            if isinstance(node, CT_P):
                paragraph_text = Paragraph(node, document).text.strip()
                if paragraph_text:
                    lines.append(paragraph_text)
                continue

            if not isinstance(node, CT_Tbl):
                # Other body elements carry no text we collect.
                continue

            for row in Table(node, document).rows:
                stripped_cells = (cell.text.strip() for cell in row.cells)
                row_cells = [cell_text for cell_text in stripped_cells if cell_text]
                if row_cells:
                    lines.append(" | ".join(row_cells))

        return "\n".join(lines)
    except Exception as e:
        util.log(1, f"读取docx文件 {file_path} 时出错: {str(e)}")
        return ""
|
||||
|
||||
def read_pptx_file(file_path):
    """
    Read the text content of a .pptx file.

    Each slide that has at least one shape with non-empty text contributes
    a section: a page header line followed by the stripped text of those
    shapes. Sections are separated by a blank line. Returns "" when
    python-pptx is not installed or reading fails (both cases are logged).

    Args:
        file_path: path of the .pptx file

    Returns:
        str: the presentation text
    """
    if not PPTX_AVAILABLE:
        util.log(1, "python-pptx 库未安装,无法读取 PowerPoint 文件")
        return ""

    try:
        deck = Presentation(file_path)
        sections = []

        for page_no, slide in enumerate(deck.slides, start=1):
            shape_texts = [
                shape.text.strip()
                for shape in slide.shapes
                if hasattr(shape, "text") and shape.text.strip()
            ]
            # Slides without any text are left out entirely.
            if shape_texts:
                sections.append("\n".join([f"第{page_no}页:"] + shape_texts))

        return "\n\n".join(sections)
    except Exception as e:
        util.log(1, f"读取pptx文件 {file_path} 时出错: {str(e)}")
        return ""
|
||||
|
||||
def _read_text_file(file_path):
    """Decode a plain-text file as UTF-8 (a leading BOM is dropped), else GBK.

    Returns the decoded text, or None when neither encoding can decode it.
    """
    # 'utf-8-sig' reads plain UTF-8 as well, and additionally strips a BOM so
    # '\ufeff' does not leak into the knowledge-base content.
    for encoding in ("utf-8-sig", "gbk"):
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    return None


def load_local_knowledge_base():
    """
    Load every file in the ``data`` directory next to this module.

    .docx, .doc and .pptx files go through their dedicated readers; every
    other file is read as text. Files whose content is empty or cannot be
    decoded are skipped, and a failure on one file is logged without
    affecting the others.

    Returns:
        dict: mapping of file name to its text content
    """
    knowledge_base = {}

    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(current_dir, "data")

    # isdir (not exists): a stray regular file named "data" must not make
    # iterdir() raise.
    if not os.path.isdir(data_dir):
        util.log(1, f"知识库目录不存在: {data_dir}")
        return knowledge_base

    # Suffix -> reader; anything not listed is treated as plain text.
    readers = {
        '.docx': read_docx_file,
        '.doc': read_doc_file,
        '.pptx': read_pptx_file,
    }

    for file_path in Path(data_dir).iterdir():
        if not file_path.is_file():
            continue

        file_name = file_path.name

        try:
            reader = readers.get(file_path.suffix.lower())
            if reader is not None:
                content = reader(str(file_path))
            else:
                content = _read_text_file(file_path)
                if content is None:
                    util.log(1, f"无法解码文件: {file_name}")
                    continue

            if content.strip():
                knowledge_base[file_name] = content
                util.log(1, f"成功加载知识库文件: {file_name} ({len(content)} 字符)")

        except Exception as e:
            util.log(1, f"加载知识库文件 {file_name} 时出错: {str(e)}")

    return knowledge_base
|
||||
|
||||
def search_knowledge_base(query, knowledge_base, max_results=3):
    """
    Keyword search over the loaded knowledge base.

    The lower-cased query is split into ``\\w+`` tokens. Each document is
    split into sentences, and a sentence scores one point per query token
    it contains (case-insensitive substring match). A document's score is
    the sum of its sentence scores; its result carries up to five of its
    best-scoring sentences.

    Args:
        query: query text
        knowledge_base: dict of file name -> content
        max_results: maximum number of documents returned

    Returns:
        list: dicts with 'file_name', 'score' and 'content', best first
    """
    if not knowledge_base:
        return []

    keywords = re.findall(r'\w+', query.lower())
    hits = []

    for file_name, content in knowledge_base.items():
        scored = []  # (stripped sentence, number of matching keywords)

        for raw_sentence in re.split(r'[。!?\n]', content):
            sentence = raw_sentence.strip()
            if not sentence:
                continue

            lowered = raw_sentence.lower()
            matches = sum(1 for keyword in keywords if keyword in lowered)
            if matches > 0:
                scored.append((sentence, matches))

        total = sum(matches for _, matches in scored)
        if total <= 0:
            continue

        # Stable sort: equally scored sentences keep document order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        best = [sentence for sentence, _ in scored[:5] if sentence]
        if best:
            hits.append({
                'file_name': file_name,
                'score': total,
                'content': '\n'.join(best)
            })

    hits.sort(key=lambda hit: hit['score'], reverse=True)
    return hits[:max_results]
|
||||
|
||||
# 全局知识库缓存
|
||||
_knowledge_base_cache = None
|
||||
_knowledge_base_load_time = None
|
||||
_knowledge_base_file_times = {} # 存储文件的最后修改时间
|
||||
|
||||
def check_knowledge_base_changes():
    """
    Report whether the knowledge-base files changed since the last check.

    Takes a snapshot of {file name: modification time} for every regular
    file in the ``data`` directory next to this module and compares it with
    the snapshot kept in ``_knowledge_base_file_times``. Added, removed and
    modified files all count as a change; when a change is found the stored
    snapshot is replaced by the new one.

    Returns:
        bool: True if the snapshot differs from the stored one, else False
              (also False when the data directory does not exist)
    """
    global _knowledge_base_file_times

    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(current_dir, "data")

    if not os.path.isdir(data_dir):
        return False

    # Every regular file is tracked, because the loader reads every regular
    # file; filtering by extension here would leave edits to other text
    # files unnoticed and the cache stale.
    current_file_times = {}
    for file_path in Path(data_dir).iterdir():
        if not file_path.is_file():
            continue
        try:
            current_file_times[file_path.name] = os.path.getmtime(str(file_path))
        except OSError:
            # The file vanished or is unreadable right now; leave it out.
            continue

    # One dict comparison covers added, removed and modified files. It also
    # means an empty directory compares equal to the initial empty snapshot
    # instead of being reported as changed on every call.
    if current_file_times == _knowledge_base_file_times:
        return False

    _knowledge_base_file_times = current_file_times
    return True
|
||||
|
||||
def init_knowledge_base():
    """
    Load the knowledge base into the module-level cache.

    Reads all knowledge-base files, records the load time, and runs one
    change check so the file-modification snapshot starts from the current
    state of the directory. Called at system start-up.
    """
    global _knowledge_base_cache, _knowledge_base_load_time

    util.log(1, "初始化本地知识库...")

    loaded = load_local_knowledge_base()
    _knowledge_base_cache = loaded
    _knowledge_base_load_time = time.time()

    # Return value deliberately ignored: this call only seeds the snapshot.
    check_knowledge_base_changes()

    util.log(1, f"知识库初始化完成,共 {len(loaded)} 个文件")
|
||||
|
||||
def get_knowledge_base():
    """
    Return the cached knowledge base, refreshing it when files changed.

    On first use the cache is initialised. On later calls the files are
    checked for changes and, if any are found, the whole knowledge base is
    loaded again before the cache is returned.

    Returns:
        dict: mapping of file name to content
    """
    global _knowledge_base_cache, _knowledge_base_load_time

    if _knowledge_base_cache is None:
        # Nothing cached yet: a full initialisation also seeds the
        # change-tracking snapshot.
        init_knowledge_base()
    elif check_knowledge_base_changes():
        util.log(1, "检测到知识库文件变化,正在重新加载...")
        _knowledge_base_cache = load_local_knowledge_base()
        _knowledge_base_load_time = time.time()
        util.log(1, f"知识库重新加载完成,共 {len(_knowledge_base_cache)} 个文件")

    return _knowledge_base_cache
|
||||
|
||||
|
||||
# 定时保存记忆的线程
|
||||
def memory_scheduler_thread():
|
||||
"""
|
||||
@@ -361,6 +772,21 @@ def question(content, username, observation=None):
|
||||
except Exception as e:
|
||||
util.log(1, f"获取相关记忆时出错: {str(e)}")
|
||||
|
||||
# 新增:搜索本地知识库
|
||||
knowledge_context = ""
|
||||
try:
|
||||
knowledge_base = get_knowledge_base()
|
||||
if knowledge_base:
|
||||
knowledge_results = search_knowledge_base(content, knowledge_base, max_results=3)
|
||||
if knowledge_results:
|
||||
knowledge_context = "**本地知识库相关信息**:\n"
|
||||
for result in knowledge_results:
|
||||
knowledge_context += f"来源文件:{result['file_name']}\n"
|
||||
knowledge_context += f"{result['content']}\n\n"
|
||||
util.log(1, f"找到 {len(knowledge_results)} 条相关知识库信息")
|
||||
except Exception as e:
|
||||
util.log(1, f"搜索知识库时出错: {str(e)}")
|
||||
|
||||
# 使用文件开头定义的llm对象进行流式请求
|
||||
observation = "**还观察的情况**:" + observation + "\n" if observation else ""
|
||||
|
||||
@@ -437,12 +863,60 @@ def question(content, username, observation=None):
|
||||
if react_response_text and react_response_text.strip():
|
||||
if is_agent_think_start:
|
||||
react_response_text = "</think>" + react_response_text
|
||||
stream_manager.new_instance().write_sentence(username, react_response_text)
|
||||
# 对React Agent的最终回复也进行分句处理
|
||||
accumulated_text += react_response_text
|
||||
# 使用安全的流式文本处理器和状态管理器
|
||||
from utils.stream_text_processor import get_processor
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
|
||||
processor = get_processor()
|
||||
state_manager = get_state_manager()
|
||||
|
||||
# 确保有活跃会话
|
||||
if not state_manager.is_session_active(username):
|
||||
state_manager.start_new_session(username, "react_agent")
|
||||
|
||||
# 如果累积文本达到一定长度,进行处理
|
||||
if len(accumulated_text) >= 20: # 设置一个合理的阈值
|
||||
# 找到最后一个标点符号的位置
|
||||
last_punct_pos = -1
|
||||
for punct in processor.punctuation_marks:
|
||||
pos = accumulated_text.rfind(punct)
|
||||
if pos > last_punct_pos:
|
||||
last_punct_pos = pos
|
||||
|
||||
if last_punct_pos > 10: # 确保有足够的内容发送
|
||||
sentence_text = accumulated_text[:last_punct_pos + 1]
|
||||
# 使用状态管理器准备句子
|
||||
marked_text, _, _ = state_manager.prepare_sentence(username, sentence_text)
|
||||
stream_manager.new_instance().write_sentence(username, marked_text)
|
||||
accumulated_text = accumulated_text[last_punct_pos + 1:].lstrip()
|
||||
|
||||
except (KeyError, IndexError, AttributeError):
|
||||
react_response_text = f"抱歉,我现在太忙了,休息一会,请稍后再试。"
|
||||
if is_first_sentence:
|
||||
react_response_text = "_<isfirst>" + react_response_text
|
||||
is_first_sentence = False
|
||||
stream_manager.new_instance().write_sentence(username, react_response_text)
|
||||
|
||||
full_response_text += react_response_text
|
||||
|
||||
# 确保React Agent最后一段文本也被发送,并标记为结束
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
state_manager = get_state_manager()
|
||||
|
||||
if accumulated_text:
|
||||
# 使用状态管理器准备最后的文本,强制标记为结束
|
||||
marked_text, _, _ = state_manager.prepare_sentence(username, accumulated_text, force_end=True)
|
||||
stream_manager.new_instance().write_sentence(username, marked_text)
|
||||
else:
|
||||
# 如果没有剩余文本,检查是否需要发送结束标记
|
||||
session_info = state_manager.get_session_info(username)
|
||||
if session_info and not session_info.get('is_end_sent', False):
|
||||
# 发送一个空的结束标记
|
||||
marked_text, _, _ = state_manager.prepare_sentence(username, "", force_end=True)
|
||||
stream_manager.new_instance().write_sentence(username, marked_text)
|
||||
|
||||
|
||||
else:
|
||||
try:
|
||||
@@ -452,24 +926,50 @@ def question(content, username, observation=None):
|
||||
if not flush_text:
|
||||
continue
|
||||
accumulated_text += flush_text
|
||||
for mark in punctuation_marks:
|
||||
if mark in accumulated_text:
|
||||
last_punct_pos = max(accumulated_text.rfind(p) for p in punctuation_marks if p in accumulated_text)
|
||||
if last_punct_pos != -1:
|
||||
to_write = accumulated_text[:last_punct_pos + 1]
|
||||
accumulated_text = accumulated_text[last_punct_pos + 1:]
|
||||
if is_first_sentence:
|
||||
to_write += "_<isfirst>"
|
||||
is_first_sentence = False
|
||||
stream_manager.new_instance().write_sentence(username, to_write)
|
||||
break
|
||||
# 使用安全的流式处理逻辑和状态管理器
|
||||
from utils.stream_text_processor import get_processor
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
|
||||
processor = get_processor()
|
||||
state_manager = get_state_manager()
|
||||
|
||||
# 确保有活跃会话
|
||||
if not state_manager.is_session_active(username):
|
||||
state_manager.start_new_session(username, "llm_stream")
|
||||
|
||||
# 如果累积文本达到一定长度,进行处理
|
||||
if len(accumulated_text) >= 20: # 设置一个合理的阈值
|
||||
# 找到最后一个标点符号的位置
|
||||
last_punct_pos = -1
|
||||
for punct in processor.punctuation_marks:
|
||||
pos = accumulated_text.rfind(punct)
|
||||
if pos > last_punct_pos:
|
||||
last_punct_pos = pos
|
||||
|
||||
if last_punct_pos > 10: # 确保有足够的内容发送
|
||||
sentence_text = accumulated_text[:last_punct_pos + 1]
|
||||
# 使用状态管理器准备句子
|
||||
marked_text, _, _ = state_manager.prepare_sentence(username, sentence_text)
|
||||
stream_manager.new_instance().write_sentence(username, marked_text)
|
||||
accumulated_text = accumulated_text[last_punct_pos + 1:].lstrip()
|
||||
|
||||
full_response_text += flush_text
|
||||
# 确保最后一段文本也被发送
|
||||
# 确保最后一段文本也被发送,并标记为结束
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
state_manager = get_state_manager()
|
||||
|
||||
if accumulated_text:
|
||||
if is_first_sentence: #相当于整个回复没有标点
|
||||
accumulated_text += "_<isfirst>"
|
||||
is_first_sentence = False
|
||||
stream_manager.new_instance().write_sentence(username, accumulated_text)
|
||||
# 使用状态管理器准备最后的文本,强制标记为结束
|
||||
marked_text, _, _ = state_manager.prepare_sentence(username, accumulated_text, force_end=True)
|
||||
stream_manager.new_instance().write_sentence(username, marked_text)
|
||||
else:
|
||||
# 如果没有剩余文本,检查是否需要发送结束标记
|
||||
session_info = state_manager.get_session_info(username)
|
||||
if session_info and not session_info.get('is_end_sent', False):
|
||||
# 发送一个空的结束标记
|
||||
marked_text, _, _ = state_manager.prepare_sentence(username, "", force_end=True)
|
||||
stream_manager.new_instance().write_sentence(username, marked_text)
|
||||
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
util.log(1, f"请求失败: {e}")
|
||||
@@ -477,8 +977,10 @@ def question(content, username, observation=None):
|
||||
stream_manager.new_instance().write_sentence(username, "_<isfirst>" + error_message + "_<isend>")
|
||||
full_response_text = error_message
|
||||
|
||||
# 发送结束标记
|
||||
stream_manager.new_instance().write_sentence(username, "_<isend>")
|
||||
# 结束会话(不再需要发送额外的结束标记)
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
state_manager = get_state_manager()
|
||||
state_manager.end_session(username)
|
||||
|
||||
# 在单独线程中记忆对话内容
|
||||
MyThread(target=remember_conversation_thread, args=(username, content, full_response_text.split("</think>")[-1])).start()
|
||||
|
||||
@@ -6,10 +6,8 @@ from tts import tts_voice
|
||||
from tts.tts_voice import EnumVoice
|
||||
from utils import util, config_util
|
||||
from utils import config_util as cfg
|
||||
import pygame
|
||||
import edge_tts
|
||||
from pydub import AudioSegment
|
||||
from scheduler.thread_manager import MyThread
|
||||
|
||||
class Speech:
|
||||
def __init__(self):
|
||||
|
||||
249
utils/stream_state_manager.py
Normal file
249
utils/stream_state_manager.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import threading
|
||||
import time
|
||||
from utils import util
|
||||
from enum import Enum
|
||||
|
||||
class StreamState(Enum):
    """States of a streamed reply."""

    # Idle, nothing being streamed.
    IDLE = "idle"
    # The first sentence.
    FIRST_SENTENCE = "first"
    # A sentence in the middle.
    MIDDLE_SENTENCE = "middle"
    # The last sentence.
    LAST_SENTENCE = "last"
    # Finished.
    COMPLETED = "completed"
|
||||
|
||||
class StreamStateManager:
|
||||
"""
|
||||
流式状态管理器 - 统一管理isfirst/isend标记
|
||||
解决多处设置标记导致的状态不一致问题
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.lock = threading.RLock()
|
||||
self.user_states = {} # 用户名 -> 状态信息
|
||||
self.session_counters = {} # 用户名 -> 会话计数器
|
||||
|
||||
def start_new_session(self, username, session_type="stream"):
|
||||
"""
|
||||
开始新的流式会话
|
||||
|
||||
参数:
|
||||
username: 用户名
|
||||
session_type: 会话类型 (stream, qa, auto_play等)
|
||||
|
||||
返回:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
with self.lock:
|
||||
if username not in self.session_counters:
|
||||
self.session_counters[username] = 0
|
||||
|
||||
self.session_counters[username] += 1
|
||||
session_id = f"{username}_{session_type}_{self.session_counters[username]}_{int(time.time())}"
|
||||
|
||||
self.user_states[username] = {
|
||||
'session_id': session_id,
|
||||
'session_type': session_type,
|
||||
'state': StreamState.IDLE,
|
||||
'sentence_count': 0,
|
||||
'start_time': time.time(),
|
||||
'last_update': time.time(),
|
||||
'is_first_sent': False,
|
||||
'is_end_sent': False
|
||||
}
|
||||
|
||||
util.log(1, f"开始新会话: {session_id}")
|
||||
return session_id
|
||||
|
||||
def prepare_sentence(self, username, text, force_first=False, force_end=False):
|
||||
"""
|
||||
准备发送句子,自动添加适当的标记
|
||||
|
||||
参数:
|
||||
username: 用户名
|
||||
text: 文本内容
|
||||
force_first: 强制设为第一句
|
||||
force_end: 强制设为最后一句
|
||||
|
||||
返回:
|
||||
tuple: (处理后的文本, 是否为第一句, 是否为最后一句)
|
||||
"""
|
||||
with self.lock:
|
||||
if username not in self.user_states:
|
||||
# 如果没有活跃会话,自动创建一个
|
||||
self.start_new_session(username)
|
||||
|
||||
state_info = self.user_states[username]
|
||||
state_info['last_update'] = time.time()
|
||||
|
||||
# 判断是否为第一句
|
||||
is_first = False
|
||||
if force_first or (not state_info['is_first_sent'] and state_info['sentence_count'] == 0):
|
||||
is_first = True
|
||||
state_info['is_first_sent'] = True
|
||||
state_info['state'] = StreamState.FIRST_SENTENCE
|
||||
elif state_info['sentence_count'] > 0:
|
||||
state_info['state'] = StreamState.MIDDLE_SENTENCE
|
||||
|
||||
# 判断是否为最后一句
|
||||
is_end = force_end
|
||||
if is_end:
|
||||
state_info['is_end_sent'] = True
|
||||
state_info['state'] = StreamState.LAST_SENTENCE
|
||||
|
||||
# 更新句子计数
|
||||
state_info['sentence_count'] += 1
|
||||
|
||||
# 构造带标记的文本
|
||||
marked_text = text
|
||||
if is_first and not marked_text.endswith('_<isfirst>'):
|
||||
marked_text += "_<isfirst>"
|
||||
if is_end and not marked_text.endswith('_<isend>'):
|
||||
marked_text += "_<isend>"
|
||||
return marked_text, is_first, is_end
|
||||
|
||||
def end_session(self, username):
|
||||
"""
|
||||
结束当前会话
|
||||
|
||||
参数:
|
||||
username: 用户名
|
||||
|
||||
返回:
|
||||
str: 空字符串(结束标记应该已经附加到最后一句话上)
|
||||
"""
|
||||
with self.lock:
|
||||
if username not in self.user_states:
|
||||
util.log(1, f"警告: 尝试结束不存在的会话 [{username}]")
|
||||
return ""
|
||||
|
||||
state_info = self.user_states[username]
|
||||
|
||||
# 标记会话为完成状态
|
||||
if state_info['state'] != StreamState.COMPLETED:
|
||||
state_info['state'] = StreamState.COMPLETED
|
||||
|
||||
session_duration = time.time() - state_info['start_time']
|
||||
|
||||
# 检查是否已经发送过结束标记
|
||||
if not state_info['is_end_sent']:
|
||||
util.log(1, f"警告: 会话结束但未发送过结束标记,可能存在逻辑问题")
|
||||
|
||||
return "" # 不再返回单独的_<isend>标记
|
||||
|
||||
def get_session_info(self, username):
|
||||
"""
|
||||
获取用户的会话信息
|
||||
|
||||
参数:
|
||||
username: 用户名
|
||||
|
||||
返回:
|
||||
dict: 会话信息
|
||||
"""
|
||||
with self.lock:
|
||||
if username in self.user_states:
|
||||
return self.user_states[username].copy()
|
||||
return None
|
||||
|
||||
def is_session_active(self, username):
|
||||
"""
|
||||
检查用户是否有活跃的会话
|
||||
|
||||
参数:
|
||||
username: 用户名
|
||||
|
||||
返回:
|
||||
bool: 是否有活跃会话
|
||||
"""
|
||||
with self.lock:
|
||||
if username not in self.user_states:
|
||||
return False
|
||||
|
||||
state_info = self.user_states[username]
|
||||
return state_info['state'] not in [StreamState.COMPLETED]
|
||||
|
||||
def cleanup_expired_sessions(self, timeout_seconds=300):
|
||||
"""
|
||||
清理过期的会话
|
||||
|
||||
参数:
|
||||
timeout_seconds: 超时时间(秒)
|
||||
"""
|
||||
with self.lock:
|
||||
current_time = time.time()
|
||||
expired_users = []
|
||||
|
||||
for username, state_info in self.user_states.items():
|
||||
if current_time - state_info['last_update'] > timeout_seconds:
|
||||
expired_users.append(username)
|
||||
|
||||
for username in expired_users:
|
||||
util.log(1, f"清理过期会话: {self.user_states[username]['session_id']}")
|
||||
del self.user_states[username]
|
||||
|
||||
def force_reset_user_state(self, username):
|
||||
"""
|
||||
强制重置用户状态(用于异常恢复)
|
||||
|
||||
参数:
|
||||
username: 用户名
|
||||
"""
|
||||
with self.lock:
|
||||
if username in self.user_states:
|
||||
old_session = self.user_states[username]['session_id']
|
||||
del self.user_states[username]
|
||||
util.log(1, f"强制重置用户状态: {username}, 旧会话: {old_session}")
|
||||
|
||||
def get_all_active_sessions(self):
|
||||
"""
|
||||
获取所有活跃会话的信息
|
||||
|
||||
返回:
|
||||
dict: 用户名 -> 会话信息
|
||||
"""
|
||||
with self.lock:
|
||||
active_sessions = {}
|
||||
for username, state_info in self.user_states.items():
|
||||
if state_info['state'] != StreamState.COMPLETED:
|
||||
active_sessions[username] = state_info.copy()
|
||||
return active_sessions
|
||||
|
||||
# Module-wide singleton and the lock that guards its creation
_state_manager_instance = None
_state_manager_lock = threading.Lock()


def get_state_manager():
    """
    Return the shared StreamStateManager, creating it on first use.

    The instance is checked once without the lock and once more after
    acquiring it, so at most one manager is ever constructed.

    Returns:
        StreamStateManager: the single manager instance for this module.
    """
    global _state_manager_instance
    if _state_manager_instance is not None:
        return _state_manager_instance

    with _state_manager_lock:
        # Another thread may have created it while we waited for the lock.
        if _state_manager_instance is None:
            _state_manager_instance = StreamStateManager()
    return _state_manager_instance
|
||||
|
||||
# Periodic removal of expired sessions
def start_cleanup_thread():
    """
    Spawn a daemon thread that purges expired sessions once a minute.

    Every call starts another thread; the module itself calls this exactly
    once, right below the definition.
    """

    def cleanup_worker():
        # Runs for the life of the process; errors are logged, never raised,
        # so a single failed sweep does not stop the housekeeping.
        while True:
            try:
                time.sleep(60)  # one sweep per minute
                get_state_manager().cleanup_expired_sessions()
            except Exception as e:
                util.log(1, f"清理过期会话时出错: {str(e)}")

    threading.Thread(target=cleanup_worker, daemon=True).start()
    util.log(1, "流式状态管理器清理线程已启动")


# Start the cleanup thread as soon as this module is imported
start_cleanup_thread()
|
||||
183
utils/stream_text_processor.py
Normal file
183
utils/stream_text_processor.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import time
|
||||
from utils import util
|
||||
from core import stream_manager
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
|
||||
class StreamTextProcessor:
    """
    Split reply text at punctuation and push it to the stream sentence by
    sentence, with guards against endless loops and oversized input.
    """

    def __init__(self, min_length=10, max_iterations=100, timeout_seconds=30, max_cache_size=10240):
        """
        Configure the processor.

        Args:
            min_length: minimum number of characters a chunk must have
                before it is sent on its own.
            max_iterations: upper bound on splitting-loop iterations.
            timeout_seconds: wall-clock budget for the splitting loop.
            max_cache_size: maximum number of characters processed; longer
                input is truncated.
        """
        self.min_length = min_length
        self.max_iterations = max_iterations
        self.timeout_seconds = timeout_seconds
        self.max_cache_size = max_cache_size
        # NOTE(review): ",", "!" and "?" each appear twice; presumably one of
        # each pair was a full-width form in the upstream file - confirm.
        self.punctuation_marks = [",", ",", "。", "、", "!", "?", ".", "!", "?", "\n"]

    def process_stream_text(self, text, username, is_qa=False, session_type="stream"):
        """
        Split ``text`` into sentences and send them to ``username``.

        Args:
            text: text to process; empty or whitespace-only text is a no-op.
            username: recipient of the stream.
            is_qa: Q&A-mode flag; passed through but not used by the
                splitting logic in this class.
            session_type: label for a session started here.

        Returns:
            bool: True on success (or nothing to do), False if splitting
            raised and the whole text was sent through the fallback path.
        """
        if not text or not text.strip():
            return True

        # Reuse a running session; otherwise open a new one.
        state_manager = get_state_manager()
        if not state_manager.is_session_active(username):
            state_manager.start_new_session(username, session_type)

        try:
            return self._safe_process_text(text, username, is_qa, state_manager)
        except Exception as e:
            util.log(1, f"流式文本处理出错: {str(e)}")
            # Best effort: deliver the complete text in one piece instead.
            self._send_fallback_text(text, username, state_manager)
            return False

    def _safe_process_text(self, text, username, is_qa, state_manager):
        """
        Core splitting loop with truncation, iteration and timeout guards.

        Repeatedly sends the shortest leading chunk that ends at a
        punctuation mark and is at least ``min_length`` characters long.
        Whatever is left is sent with the end marker, then the session ends.

        Returns:
            bool: always True; failures surface as exceptions.
        """
        accumulated_text = text
        iteration_count = 0
        start_time = time.time()

        # Oversized input is truncated up front. The text only shrinks in the
        # loop below, so no further size check is needed there.
        if len(accumulated_text) > self.max_cache_size:
            util.log(1, f"文本缓存溢出,长度: {len(accumulated_text)}, 限制: {self.max_cache_size}")
            accumulated_text = accumulated_text[:self.max_cache_size]
            util.log(1, f"文本已截断到: {len(accumulated_text)} 字符")

        while accumulated_text and iteration_count < self.max_iterations:
            if time.time() - start_time > self.timeout_seconds:
                util.log(1, f"流式处理超时,剩余文本长度: {len(accumulated_text)}")
                break

            iteration_count += 1

            punct_indices = self._find_punctuation_indices(accumulated_text)
            if not punct_indices:
                # No punctuation left: the remainder goes out as one piece.
                break

            sent_successfully = False
            for punct_index in punct_indices:
                sentence_text = accumulated_text[:punct_index + 1]
                if len(sentence_text) < self.min_length:
                    # Too short on its own; extend to the next mark.
                    continue

                marked_text, is_first, is_end = state_manager.prepare_sentence(
                    username, sentence_text, force_first=False, force_end=False
                )
                if stream_manager.new_instance().write_sentence(username, marked_text):
                    accumulated_text = accumulated_text[punct_index + 1:].lstrip()
                    sent_successfully = True
                    break
                util.log(1, f"发送句子失败: {marked_text[:50]}...")

            # Nothing could be sent this round: stop instead of spinning.
            if not sent_successfully:
                break

        if accumulated_text:
            # The remainder is the final sentence and carries the end marker.
            marked_text, _, _ = state_manager.prepare_sentence(
                username, accumulated_text, force_first=False, force_end=True
            )
            stream_manager.new_instance().write_sentence(username, marked_text)
        else:
            # Everything was sent already; emit a bare end marker if none
            # has gone out for this session yet.
            session_info = state_manager.get_session_info(username)
            if session_info and not session_info.get('is_end_sent', False):
                marked_text, _, _ = state_manager.prepare_sentence(
                    username, "", force_first=False, force_end=True
                )
                stream_manager.new_instance().write_sentence(username, marked_text)

        state_manager.end_session(username)

        if iteration_count >= self.max_iterations:
            util.log(1, f"流式处理达到最大迭代次数限制: {self.max_iterations}")

        return True

    def _find_punctuation_indices(self, text):
        """
        Return the sorted, de-duplicated positions of all punctuation marks.

        Every occurrence of every mark is reported, so the caller can fall
        through to a later split point when an earlier one would give a
        chunk shorter than ``min_length``.

        Args:
            text: text to scan.

        Returns:
            list[int]: ascending start indices; empty if none are found or
            the scan fails.
        """
        try:
            positions = set()
            for punct in self.punctuation_marks:
                if not punct:
                    # An empty mark would match at every offset.
                    continue
                try:
                    index = text.find(punct)
                    while index != -1:
                        positions.add(index)
                        index = text.find(punct, index + 1)
                except Exception as e:
                    util.log(1, f"查找标点符号 '{punct}' 时出错: {str(e)}")
                    continue
            return sorted(positions)
        except Exception as e:
            util.log(1, f"查找标点符号时出错: {str(e)}")
            return []

    def _send_fallback_text(self, text, username, state_manager):
        """
        Fallback path: send the complete text as a single first-and-last
        sentence. Errors are logged and swallowed because this already is
        the error path.
        """
        try:
            marked_text, _, _ = state_manager.prepare_sentence(
                username, text, force_first=True, force_end=True
            )
            stream_manager.new_instance().write_sentence(username, marked_text)
            util.log(1, "使用备用方案发送完整文本")
        except Exception as e:
            util.log(1, f"备用发送方案也失败: {str(e)}")
|
||||
|
||||
# Lazily created, module-wide processor instance
_processor_instance = None


def get_processor():
    """
    Return the shared StreamTextProcessor, building it with default
    settings the first time it is requested.
    """
    global _processor_instance
    if _processor_instance is not None:
        return _processor_instance
    _processor_instance = StreamTextProcessor()
    return _processor_instance
|
||||
Reference in New Issue
Block a user