1. Add audio caching to lower TTS cost;
2. Improve the passthrough API's streaming sentence segmentation so that domain names, version numbers, and the like are no longer split;
3. Improve the streaming text output order of the digital-human API;
4. Hook the LLM passthrough feature into LangSmith; once the environment variables are configured, prompts can be tuned on the LangSmith platform;
5. Improve the config-center loading logic; the persona configuration is still preserved.
guo zebin
2026-03-04 17:36:42 +08:00
parent 3b4552d50e
commit 38db690ae8
12 changed files with 1490 additions and 838 deletions

View File

@@ -1,190 +1,206 @@
from threading import Thread
from threading import Lock
import websocket
import json
import time
import ssl
import wave
import _thread as thread
from aliyunsdkcore.client import AcsClient
from aliyunsdkcore.request import CommonRequest
from core import wsa_server
from scheduler.thread_manager import MyThread
from utils import util
from utils import config_util as cfg
from core.authorize_tb import Authorize_Tb
__running = True
__my_thread = None
_token = ''
def __post_token():
    global _token
    # Bail out early when credentials are not configured
    if not cfg.key_ali_nls_key_id or not cfg.key_ali_nls_key_secret:
        util.log(2, "AliNLS 凭据未配置,跳过 token 刷新。")
        return False
    __client = AcsClient(
        cfg.key_ali_nls_key_id,
        cfg.key_ali_nls_key_secret,
        "cn-shanghai"
    )
__request = CommonRequest()
__request.set_method('POST')
__request.set_domain('nls-meta.cn-shanghai.aliyuncs.com')
__request.set_version('2019-02-28')
__request.set_action_name('CreateToken')
info = json.loads(__client.do_action_with_exception(__request))
_token = info['Token']['Id']
authorize = Authorize_Tb()
authorize_info = authorize.find_by_userid(cfg.key_ali_nls_key_id)
if authorize_info is not None:
authorize.update_by_userid(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
else:
authorize.add(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
try:
__client = AcsClient(
cfg.key_ali_nls_key_id,
cfg.key_ali_nls_key_secret,
"cn-shanghai"
)
__request = CommonRequest()
__request.set_method('POST')
__request.set_domain('nls-meta.cn-shanghai.aliyuncs.com')
__request.set_version('2019-02-28')
__request.set_action_name('CreateToken')
info = json.loads(__client.do_action_with_exception(__request))
_token = info['Token']['Id']
authorize = Authorize_Tb()
authorize_info = authorize.find_by_userid(cfg.key_ali_nls_key_id)
if authorize_info is not None:
authorize.update_by_userid(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
else:
authorize.add(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
util.log(1, "AliNLS token刷新成功")
return True
except Exception as e:
util.log(2, f"AliNLS token刷新失败: {str(e)}")
return False
def __runnable():
while __running:
__post_token()
time.sleep(60 * 60 * 12)
if __post_token():
time.sleep(60 * 60 * 12)
else:
time.sleep(60)
def start():
MyThread(target=__runnable).start()
class ALiNls:
    # Initialization
def __init__(self, username):
self.__URL = 'wss://nls-gateway-cn-shenzhen.aliyuncs.com/ws/v1'
self.__ws = None
self.__frames = []
self.started = False
self.__closing = False
self.__task_id = ''
self.done = False
self.finalResults = ""
self.username = username
self.data = b''
self.__endding = False
self.__is_close = False
self.lock = Lock()
def __create_header(self, name):
if name == 'StartTranscription':
self.__task_id = util.random_hex(32)
header = {
"appkey": cfg.key_ali_nls_app_key,
"message_id": util.random_hex(32),
"task_id": self.__task_id,
"namespace": "SpeechTranscriber",
"name": name
}
return header
    # Handle incoming websocket messages
def on_message(self, ws, message):
try:
data = json.loads(message)
header = data['header']
name = header['name']
if name == 'TranscriptionStarted':
self.started = True
if name == 'SentenceEnd':
self.done = True
self.finalResults = data['payload']['result']
if wsa_server.get_web_instance().is_connected(self.username):
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
if wsa_server.get_instance().is_connected(self.username):
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
wsa_server.get_instance().add_cmd(content)
ws.close()#TODO
elif name == 'TranscriptionResultChanged':
self.finalResults = data['payload']['result']
if wsa_server.get_web_instance().is_connected(self.username):
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
if wsa_server.get_instance().is_connected(self.username):
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
wsa_server.get_instance().add_cmd(content)
except Exception as e:
print(e)
# print("### message:", message)
    # Handle the websocket close event
def on_close(self, ws, code, msg):
self.__endding = True
self.__is_close = True
    # Handle websocket errors
def on_error(self, ws, error):
print("aliyun asr error:", error)
        self.started = True  # avoid the recorder waiting forever for the start state when aliyun asr fails
    # Handle websocket connection established
def on_open(self, ws):
self.__endding = False
        # Drain remaining frames so multi-channel ASR shutdown still delivers buffered data
def run(*args):
while self.__endding == False:
try:
if len(self.__frames) > 0:
with self.lock:
frame = self.__frames.pop(0)
if isinstance(frame, dict):
ws.send(json.dumps(frame))
elif isinstance(frame, bytes):
ws.send(frame, websocket.ABNF.OPCODE_BINARY)
self.data += frame
else:
                        time.sleep(0.001)  # avoid busy-waiting
except Exception as e:
print(e)
break
if self.__is_close == False:
for frame in self.__frames:
ws.send(frame, websocket.ABNF.OPCODE_BINARY)
frame = {"header": self.__create_header('StopTranscription')}
ws.send(json.dumps(frame))
thread.start_new_thread(run, ())
def __connect(self):
if not _token:
util.log(2, "AliNLS token为空本次语音识别连接跳过")
self.started = True
return
self.finalResults = ""
self.done = False
with self.lock:
self.__frames.clear()
self.__ws = websocket.WebSocketApp(self.__URL + '?token=' + _token, on_message=self.on_message)
self.__ws.on_open = self.on_open
self.__ws.on_error = self.on_error
self.__ws.on_close = self.on_close
self.__ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def send(self, buf):
with self.lock:
self.__frames.append(buf)
def start(self):
Thread(target=self.__connect, args=[]).start()
data = {
'header': self.__create_header('StartTranscription'),
"payload": {
"format": "pcm",
"sample_rate": 16000,
"enable_intermediate_result": True,
"enable_punctuation_prediction": False,
"enable_inverse_text_normalization": True,
"speech_noise_threshold": -1
}
}
self.send(data)
def end(self):
self.__endding = True
with wave.open('cache_data/input2.wav', 'wb') as wf:
            # Set audio parameters
            n_channels = 1  # mono
            sampwidth = 2   # 16-bit audio: 2 bytes per sample
wf.setnchannels(n_channels)
wf.setsampwidth(sampwidth)
wf.setframerate(16000)
wf.writeframes(self.data)
self.data = b''
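For orientation, here is a minimal driver for the class above. It is a sketch, not code from the repo: the PCM file path and frame size are illustrative, and it assumes 16 kHz / 16-bit mono input as configured in `start()`.

```python
# Hypothetical usage of ALiNls (not part of the commit); feed 16 kHz 16-bit mono PCM.
import time

asr = ALiNls("User")
asr.start()                        # opens the websocket, queues StartTranscription
while not asr.started:             # wait for TranscriptionStarted (or the error fallback)
    time.sleep(0.01)
with open("cache_data/input.pcm", "rb") as f:   # illustrative path
    while chunk := f.read(3200):   # ~100 ms of audio per frame
        asr.send(chunk)
asr.end()                          # flushes buffered frames, sends StopTranscription
while not asr.done:
    time.sleep(0.01)
print(asr.finalResults)
```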

View File

@@ -17,7 +17,7 @@ from bionicmemory.algorithms.newton_cooling_helper import NewtonCoolingHelper, C
from bionicmemory.core.chroma_service import ChromaService
from bionicmemory.services.summary_service import SummaryService
from bionicmemory.algorithms.clustering_suppression import ClusteringSuppression
from utils.api_embedding_service import get_embedding_service
# Use the unified logging configuration
from bionicmemory.utils.logging_config import get_logger
@@ -238,7 +238,8 @@ class LongShortTermMemorySystem:
logger.info(f"[仿生记忆] _prepare_document_data: embedding生成完成, 类型={type(embedding)}")
except Exception as e:
logger.error(f"生成embedding失败: {e}")
embedding = []
            # Return None instead of an empty list so the caller can tell this is a failure case
embedding = None
        # Prepare metadata
current_time = datetime.now().isoformat()
@@ -398,12 +399,15 @@ class LongShortTermMemorySystem:
embedding_list = embedding.tolist()
else:
embedding_list = embedding
            # Check length and dimensionality
if len(embedding_list) > 0:
embeddings_param = [embedding_list]
logger.debug(f"添加embedding到长期记忆: doc_id={doc_id}, 维度={len(embedding_list)}")
else:
logger.warning(f"⚠️ Embedding为空列表: doc_id={doc_id}")
embeddings_param = None
else:
logger.warning(f"⚠️ Embedding为None: doc_id={doc_id}, 将不包含向量信息")
embeddings_param = None
self.chroma_service.add_documents(
@@ -1046,6 +1050,16 @@ class LongShortTermMemorySystem:
)
logger.info(f"[仿生记忆] 步骤1完成: doc_id={doc_id}, user_embedding类型={type(user_embedding)}")
        # Check whether the embedding was generated successfully
if user_embedding is None:
logger.warning("[仿生记忆] embedding生成失败跳过记忆检索使用默认响应")
return {
"system_prompt": "你是一个智能助手。由于技术问题,暂时无法访问历史记忆,但我会尽力帮助你。",
"retrieved_memories": [],
"user_doc_id": None,
"assistant_doc_id": None
}
        # Retrieve using the user embedding
logger.info("[仿生记忆] 步骤2: 使用用户embedding进行检索")
query_embedding = user_embedding
@@ -1485,4 +1499,4 @@ if __name__ == "__main__":
stats_after['short_term_memory']['total_records'] == 0:
logger.info(f"✅ 用户 {target_user_id} (key: {target_key[:10]}...) 历史记录清除成功!")
else:
logger.warning(f"⚠️ 用户 {target_user_id} (key: {target_key[:10]}...) 历史记录可能未完全清除")
logger.warning(f"⚠️ 用户 {target_user_id} (key: {target_key[:10]}...) 历史记录可能未完全清除")

View File

@@ -31,7 +31,9 @@ from queue import Queue
import re  # regex module, used to filter emoji
import uuid
import hashlib
from urllib.parse import urlparse, urljoin
@@ -136,16 +138,24 @@ else:
import platform
if platform.system() == "Windows":
import sys
sys.path.append("test/ovr_lipsync")
from test_olipsync import LipSyncGenerator
if platform.system() == "Windows":
import sys
_fay_runtime_dir = os.path.abspath(os.path.dirname(__file__))
if hasattr(sys, "_MEIPASS"):
_fay_runtime_dir = os.path.abspath(sys._MEIPASS)
else:
_fay_runtime_dir = os.path.abspath(os.path.join(_fay_runtime_dir, ".."))
_lipsync_dir = os.path.join(_fay_runtime_dir, "test", "ovr_lipsync")
if _lipsync_dir not in map(os.path.abspath, sys.path):
sys.path.insert(0, _lipsync_dir)
from test_olipsync import LipSyncGenerator
@@ -237,7 +247,15 @@ class FeiFei:
self.think_display_limit = 400
        self.user_conv_map = {}  # per-user conversation id and sentence-stream sequence number, keyed by (username, conversation_id)
        self.pending_isfirst = {}  # isfirst flags deferred because prestart content was filtered, keyed by username
self.tts_cache = {}
self.tts_cache_limit = 1000
self.tts_cache_lock = threading.Lock()
        self.user_audio_conv_map = {}  # contiguous sequence numbers for audio segments only (gaps in the text sequence would otherwise cause reordering or missing packets)
self.human_audio_order_map = {}
self.human_audio_order_lock = threading.Lock()
self.human_audio_reorder_wait_seconds = 0.2
self.human_audio_first_wait_seconds = 1.2
@@ -868,7 +886,7 @@ class FeiFei:
    # Get the voice matching the current mood
    def __get_mood_voice(self):
voice = tts_voice.get_voice_of(config_util.config["attribute"]["voice"])
@@ -883,18 +901,168 @@ class FeiFei:
        styleList = voice.value["styleList"]
        sayType = styleList["calm"]
        return sayType
def __build_tts_cache_key(self, text, style):
tts_module = str(getattr(cfg, "tts_module", "") or "")
style_str = str(style or "")
voice_name = ""
try:
voice_name = str(config_util.config.get("attribute", {}).get("voice", "") or "")
except Exception:
voice_name = ""
if tts_module == "volcano":
try:
volcano_voice = str(getattr(cfg, "volcano_tts_voice_type", "") or "")
if volcano_voice:
voice_name = volcano_voice
except Exception:
pass
raw = f"{tts_module}|{voice_name}|{style_str}|{text}"
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
def __get_tts_cache(self, key):
with self.tts_cache_lock:
file_url = self.tts_cache.get(key)
if not file_url:
return None
if os.path.exists(file_url):
return file_url
with self.tts_cache_lock:
if key in self.tts_cache:
del self.tts_cache[key]
return None
def __set_tts_cache(self, key, file_url):
if not file_url:
return
with self.tts_cache_lock:
self.tts_cache[key] = file_url
while len(self.tts_cache) > self.tts_cache_limit:
try:
self.tts_cache.pop(next(iter(self.tts_cache)))
except Exception:
break
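Two properties of this cache are worth noting: the SHA-1 key is stable for identical (tts_module, voice, style, text) tuples, and eviction is FIFO, because Python dicts preserve insertion order and `next(iter(cache))` yields the oldest key. A standalone sketch of the eviction behavior:

```python
# FIFO eviction as used by __set_tts_cache, in isolation.
cache, limit = {}, 3
for i in range(5):
    cache[f"key{i}"] = f"file{i}.mp3"
    while len(cache) > limit:
        cache.pop(next(iter(cache)))   # drops the oldest inserted entry
print(list(cache))                     # ['key2', 'key3', 'key4']
```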
def __send_human_audio_ordered(self, content, username, conversation_id, conversation_msg_no, is_end=False):
now = time.time()
sent_messages = []
data = content.get("Data", {}) if isinstance(content, dict) else {}
has_audio_payload = bool(data.get("Value")) or bool(data.get("HttpValue"))
is_end_marker_only = bool(is_end or data.get("IsEnd", 0)) and (not has_audio_payload)
seq = None
try:
if conversation_msg_no is not None:
seq = int(conversation_msg_no)
except Exception:
seq = None
# Fallback to direct send for legacy paths without sequence metadata.
if (not conversation_id) or (seq is None):
if is_end_marker_only:
return 0
wsa_server.get_instance().add_cmd(content)
return 1
key = (username or "User", conversation_id)
with self.human_audio_order_lock:
state = self.human_audio_order_map.get(key)
if state is None:
state = {
"next_seq": None,
"buffer": {},
"last_progress_time": now,
"first_wait_start": now,
"start_known": False,
"end_seq": None,
"pending_end_seq": None,
}
self.human_audio_order_map[key] = state
next_seq = state.get("next_seq")
if (next_seq is not None) and (seq < next_seq):
return 0
def _mark_buffer_end(target_seq):
existed = state["buffer"].get(target_seq)
if isinstance(existed, dict):
existed_data = existed.get("Data", {})
if isinstance(existed_data, dict):
existed_data["IsEnd"] = 1
return True
return False
if is_end_marker_only:
target_seq = None
if seq in state["buffer"]:
target_seq = seq
elif (seq - 1) in state["buffer"]:
target_seq = seq - 1
elif state["buffer"]:
target_seq = max(state["buffer"].keys())
if (target_seq is not None) and _mark_buffer_end(target_seq):
end_seq = state.get("end_seq")
state["end_seq"] = target_seq if end_seq is None else max(end_seq, target_seq)
state["pending_end_seq"] = None
else:
state["pending_end_seq"] = seq
else:
if seq in state["buffer"]:
return 0
state["buffer"][seq] = content
pending_end_seq = state.get("pending_end_seq")
if pending_end_seq is not None:
if (seq == pending_end_seq) or (seq == pending_end_seq - 1):
if _mark_buffer_end(seq):
end_seq = state.get("end_seq")
state["end_seq"] = seq if end_seq is None else max(end_seq, seq)
state["pending_end_seq"] = None
if is_end:
end_seq = state.get("end_seq")
state["end_seq"] = seq if end_seq is None else max(end_seq, seq)
is_first_flag = bool(data.get("IsFirst", 0))
if (not state["start_known"]) and is_first_flag:
state["start_known"] = True
state["next_seq"] = seq
state["last_progress_time"] = now
elif (not state["start_known"]) and (seq == 0):
state["start_known"] = True
state["next_seq"] = 0
state["last_progress_time"] = now
elif (not state["start_known"]):
first_elapsed = now - state.get("first_wait_start", now)
if (first_elapsed >= self.human_audio_first_wait_seconds) and (0 in state["buffer"]):
state["start_known"] = True
state["next_seq"] = 0
state["last_progress_time"] = now
def _flush_contiguous():
flush_count = 0
while (state["next_seq"] is not None) and (state["next_seq"] in state["buffer"]):
sent_messages.append(state["buffer"].pop(state["next_seq"]))
state["next_seq"] += 1
state["last_progress_time"] = now
flush_count += 1
return flush_count
_flush_contiguous()
end_seq = state.get("end_seq")
if (end_seq is not None) and (state.get("next_seq") is not None) and (state["next_seq"] > end_seq) and (not state["buffer"]):
self.human_audio_order_map.pop(key, None)
for message in sent_messages:
wsa_server.get_instance().add_cmd(message)
return len(sent_messages)
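The core of the method above is a contiguous flush: a message is forwarded only once every lower sequence number has been sent. A stripped-down model (state bookkeeping, end markers, and wait timeouts omitted):

```python
# Minimal model of the reordering in __send_human_audio_ordered.
buffer, next_seq, sent = {}, 0, []
for seq, msg in [(2, "c"), (0, "a"), (1, "b")]:   # out-of-order arrival
    buffer[seq] = msg
    while next_seq in buffer:                      # flush only contiguous runs
        sent.append(buffer.pop(next_seq))
        next_seq += 1
print(sent)  # ['a', 'b', 'c'] -- seq 2 is held until 0 and 1 have arrived
```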
def say(self, interact, text, type = ""):
@@ -1085,25 +1253,27 @@ class FeiFei:
        # If conv_map_key is missing, fall back to looking up by username
        if not conv_info and text and text.strip():
            # Scan all sessions that match this username
            for (u, c), info in list(self.user_conv_map.items()):
                if u == username and info.get("content_id", 0) > 0:
                    content_id = info.get("content_id", 0)
                    conv = info.get("conversation_id", c)
                    conv_map_key = (username, conv)
                    conv_info = info
                    util.log(1, f"警告:使用备用会话 ({u}, {c}) 的 content_id={content_id},原 key=({username}, {conv})")
                    break
@@ -1157,10 +1327,27 @@ class FeiFei:
        # Pin the current session's sequence numbers, so async audio threads don't fall back to 0 after the session mapping is cleaned up
current_conv_info = self.user_conv_map.get(conv_map_key, {})
if (not current_conv_info) and (not conv):
for (u, c), info in list(self.user_conv_map.items()):
if u == username and info.get("conversation_id", ""):
current_conv_info = info
conv = info.get("conversation_id", c)
conv_map_key = (username, conv)
break
if current_conv_info:
interact.data["conversation_id"] = current_conv_info.get("conversation_id", conv)
interact.data["conversation_msg_no"] = current_conv_info.get("conversation_msg_no", 0)
else:
if conv:
interact.data["conversation_id"] = conv
interact.data["conversation_msg_no"] = interact.data.get("conversation_msg_no", 0)
        # Clean up the matching user_conv_map entry when the session ends, to avoid memory leaks
        if is_end and conv_map_key in self.user_conv_map:
            del self.user_conv_map[conv_map_key]
@@ -1394,7 +1581,15 @@ class FeiFei:
tm = time.time()
result = self.sp.to_sample(filtered_text, self.__get_mood_voice())
mood_voice = self.__get_mood_voice()
cache_key = self.__build_tts_cache_key(filtered_text, mood_voice)
cache_result = self.__get_tts_cache(cache_key)
if cache_result is not None:
result = cache_result
util.printInfo(1, interact.data.get('user'), 'TTS cache hit')
else:
result = self.sp.to_sample(filtered_text, mood_voice)
self.__set_tts_cache(cache_key, result)
        # After synthesis, re-check that the session is still valid, to avoid emitting results for a stale session
@@ -1439,7 +1634,22 @@ class FeiFei:
            # Maintain a separate contiguous sequence number for digital-human audio, so holes in
            # conversation_msg_no (segments that produce no audio) cannot create gaps
audio_conv_id = interact.data.get("conversation_id", "") or ""
audio_conv_key = (username, audio_conv_id)
audio_msg_no = None
if result is not None:
audio_msg_no = self.user_audio_conv_map.get(audio_conv_key, -1) + 1
self.user_audio_conv_map[audio_conv_key] = audio_msg_no
elif is_end:
audio_msg_no = self.user_audio_conv_map.get(audio_conv_key, None)
if audio_conv_key in self.user_audio_conv_map:
del self.user_audio_conv_map[audio_conv_key]
interact.data["audio_conversation_msg_no"] = audio_msg_no
if is_end and audio_conv_key in self.user_audio_conv_map:
del self.user_audio_conv_map[audio_conv_key]
if result is not None or is_first or is_end:
            # prestart content must not enter the audio-processing pipeline
@@ -1490,7 +1700,26 @@ class FeiFei:
        # Normalize the URL if needed, then fetch the WAV file over HTTP GET
if url is None:
return None
url = str(url).strip()
if not url:
return None
if os.path.isfile(url):
return url
parsed_url = urlparse(url)
if not parsed_url.scheme:
if url.startswith('//'):
url = 'http:' + url
else:
base_url = str(getattr(cfg, "fay_url", "") or "").strip()
if base_url:
url = urljoin(base_url.rstrip('/') + '/', url.lstrip('/'))
response = requests.get(url, stream=True)
        response.raise_for_status()  # raise if the request failed
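The normalization above turns scheme-relative and relative audio paths into absolute URLs before the download. A few illustrative cases, assuming `cfg.fay_url` is `http://127.0.0.1:5000` (host and file names are made up):

```python
from urllib.parse import urljoin

base = "http://127.0.0.1:5000"
# scheme-relative URL: just prepend "http:"
print("http:" + "//cdn.example.com/a.wav")            # http://cdn.example.com/a.wav
# relative path: resolve against the configured fay_url
print(urljoin(base.rstrip("/") + "/", "audio/a.wav")) # http://127.0.0.1:5000/audio/a.wav
```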
@@ -1946,61 +2175,115 @@ class FeiFei:
        # Send audio to the digital-human interface
        if file_url is not None and wsa_server.get_instance().get_client_output(interact.data.get("user")):
            # Fetch session info keyed by (username, conversation_id)
audio_username = interact.data.get("user", "User")
audio_conv_id = interact.data.get("conversation_id") or ""
audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {})
content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : audio_conv_info.get("conversation_id", ""), 'CONV_MSG_NO' : audio_conv_info.get("conversation_msg_no", 0) }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
            # Compute lip-sync visemes
if platform.system() == "Windows":
try:
lip_sync_generator = LipSyncGenerator()
viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url))
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
content["Data"]["Lips"] = consolidated_visemes
except Exception as e:
print(e)
util.printInfo(1, interact.data.get("user"), "唇型数据生成失败")
wsa_server.get_instance().add_cmd(content)
util.printInfo(1, interact.data.get("user"), "数字人接口发送音频数据成功")
        # Send audio to the digital-human interface
        if wsa_server.get_instance().get_client_output(interact.data.get("user")):
            # Fetch session info keyed by (username, conversation_id)
audio_username = interact.data.get("user", "User")
audio_conv_id = interact.data.get("conversation_id") or ""
audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {})
msg_no_from_interact = interact.data.get("audio_conversation_msg_no", None)
conv_id_for_send = audio_conv_id if audio_conv_id else audio_conv_info.get("conversation_id", "")
if msg_no_from_interact is None:
fallback_no = interact.data.get("conversation_msg_no", None)
if fallback_no is None:
conv_msg_no_for_send = audio_conv_info.get("conversation_msg_no", 0)
else:
conv_msg_no_for_send = fallback_no
else:
conv_msg_no_for_send = msg_no_from_interact
if file_url is not None:
content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : conv_id_for_send, 'CONV_MSG_NO' : conv_msg_no_for_send }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
                # Compute lip-sync visemes
if platform.system() == "Windows":
try:
lip_sync_generator = LipSyncGenerator()
viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url))
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
content["Data"]["Lips"] = consolidated_visemes
except Exception as e:
print(e)
util.printInfo(1, interact.data.get("user"), "唇型数据生成失败")
sent_count = self.__send_human_audio_ordered(
content,
audio_username,
conv_id_for_send,
conv_msg_no_for_send,
is_end=bool(interact.data.get("isend", False)),
)
if sent_count > 0:
util.printInfo(1, interact.data.get("user"), "digital human audio sent")
else:
util.printInfo(1, interact.data.get("user"), "digital human audio queued")
elif bool(interact.data.get("isend", False)):
            # Even without an audio file, send the end marker to the digital human so the client doesn't wait forever
end_target_seq = conv_msg_no_for_send
try:
end_target_seq = int(conv_msg_no_for_send)
except Exception:
end_target_seq = conv_msg_no_for_send
end_content = {
'Topic': 'human',
'Data': {
'Key': 'audio',
'Value': '',
'HttpValue': '',
'Text': text,
'Time': 0,
'Type': interact.interleaver,
'IsFirst': 1 if interact.data.get("isfirst", False) else 0,
'IsEnd': 1,
'CONV_ID': conv_id_for_send,
'CONV_MSG_NO': end_target_seq
},
'Username': interact.data.get('user'),
'robot': f'{cfg.fay_url}/robot/Speaking.jpg'
}
sent_count = self.__send_human_audio_ordered(
end_content,
audio_username,
conv_id_for_send,
end_target_seq,
is_end=True,
)
if sent_count > 0:
util.printInfo(1, interact.data.get("user"), "digital human audio end sent")
else:
util.printInfo(1, interact.data.get("user"), "digital human audio end queued")

View File

@@ -1,288 +1,318 @@
TOOLS: list[Tool] = [
Tool(
name="broadcast_message",
description="通过 Fay 的 /transparent-pass 广播文本/音频SSE 服务器)",
inputSchema={
"type": "object",
"properties": {
"text": {"type": "string", "description": "要广播的文本audio_url为空时必填"},
"audio_url": {"type": "string", "description": "可选音频 URL"},
"user": {"type": "string", "description": "目标用户名,默认 FAY_BROADCAST_USER 或 User"},
"speaker": {
"type": "string",
"description": "\u5e7f\u64ad\u4eba\u663e\u793a\u540d\uff0c\u8f93\u51fa\u4e3a\"{speaker}\u8bf4\uff1a{text}\"",
},
},
"required": [],
},
)
]
def _parse_arguments(arguments: Dict[str, Any]) -> Tuple[str, str, str, str]:
text = str(arguments.get("text", "") or "").strip()
audio_url = str(arguments.get("audio_url", "") or "").strip()
user = str(arguments.get("user", "") or "").strip() or DEFAULT_USER
speaker = str(arguments.get("speaker", "") or "").strip() or DEFAULT_SPEAKER
return text, audio_url, user, speaker
async def _send_broadcast(payload: Dict[str, Any]) -> Tuple[bool, str]:
def _post() -> Tuple[bool, str]:
body = json.dumps(payload, ensure_ascii=True).encode("utf-8")
resp = requests.post(
DEFAULT_API_URL,
data=body,
headers={"Content-Type": "application/json; charset=utf-8"},
timeout=REQUEST_TIMEOUT,
)
try:
data = resp.json()
except Exception:
data = None
if resp.ok:
if isinstance(data, dict):
msg = data.get("message") or data.get("msg") or ""
code = data.get("code")
if isinstance(code, int) and code >= 400:
return False, msg or f"Broadcast failed with code {code}"
return True, msg or "Broadcast sent via Fay."
return True, "Broadcast sent via Fay."
err_detail = ""
if isinstance(data, dict):
err_detail = data.get("message") or data.get("error") or data.get("msg") or ""
if not err_detail:
err_detail = resp.text
return False, f"HTTP {resp.status_code}: {err_detail}"
try:
return await asyncio.to_thread(_post)
except Exception as e:
return False, f"{type(e).__name__}: {e}"
async def _handle_call_tool(name: str, arguments: Dict[str, Any]) -> list[TextContent]:
    # Local broadcast
if name == "broadcast_message":
text, audio_url, user, speaker = _parse_arguments(arguments or {})
if not text and not audio_url:
return [_text_content("Either 'text' or 'audio_url' must be provided.")]
payload: Dict[str, Any] = {"user": user}
if text:
payload["text"] = f"{speaker}\u8bf4\uff1a{text}"
if audio_url:
payload["audio"] = audio_url
ok, message = await _send_broadcast(payload)
prefix = "success" if ok else "error"
return [_text_content(f"{prefix}: {message}")]
    # Aggregated remote MCP tools
target = _aggregated_index.get(name)
if not target:
return [_text_content(f"Unknown tool: {name}")]
server_id, tool_name = target
try:
success, result = await asyncio.to_thread(mcp_service.call_mcp_tool, server_id, tool_name, arguments or {})
if not success:
return [_text_content(f"error: {result}")]
return _normalize_result(result)
except Exception as e:
return [_text_content(f"error: {type(e).__name__}: {e}")]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fay broadcast MCP server (SSE transport).
暴露 `broadcast_message` 工具,将文本/音频透传到 Fay 的 `/transparent-pass`。
环境变量:
- FAY_BROADCAST_API 默认 http://127.0.0.1:5000/transparent-pass
- FAY_BROADCAST_USER 默认 User
- FAY_BROADCAST_TIMEOUT 默认 10
- FAY_MCP_SSE_HOST 默认 0.0.0.0
- FAY_MCP_SSE_PORT 默认 8765
- FAY_MCP_SSE_PATH SSE 路径(默认 /sse
- FAY_MCP_MSG_PATH 消息 POST 路径(默认 /messages
"""
import asyncio
import logging
import os
import sys
import json
from typing import Any, Dict, Tuple, List, Optional
try:
from mcp.server import Server
from mcp.types import Tool, TextContent
from mcp.server.sse import SseServerTransport
from faymcp import tool_registry
from faymcp import mcp_service
except ImportError:
print("缺少 mcp 库请先安装pip install mcp", file=sys.stderr)
sys.exit(1)
try:
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route
except ImportError:
print("缺少 starlette请先安装pip install starlette sse-starlette", file=sys.stderr)
sys.exit(1)
try:
import uvicorn
except ImportError:
print("缺少 uvicorn请先安装pip install uvicorn", file=sys.stderr)
sys.exit(1)
try:
import requests
except ImportError:
print("缺少 requests请先安装pip install requests", file=sys.stderr)
sys.exit(1)
log = logging.getLogger("fay_mcp_server")
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
SERVER_NAME = "fay_broadcast"
DEFAULT_API_URL = os.environ.get("FAY_BROADCAST_API", "http://127.0.0.1:5000/transparent-pass")
DEFAULT_USER = os.environ.get("FAY_BROADCAST_USER", "User")
DEFAULT_SPEAKER = os.environ.get("FAY_BROADCAST_SPEAKER", "\u5e7f\u64ad\u6d88\u606f")
REQUEST_TIMEOUT = float(os.environ.get("FAY_BROADCAST_TIMEOUT", "10"))
HOST = os.environ.get("FAY_MCP_SSE_HOST", "0.0.0.0")
PORT = int(os.environ.get("FAY_MCP_SSE_PORT", "8765"))
SSE_PATH = os.environ.get("FAY_MCP_SSE_PATH", "/sse")
MSG_PATH = os.environ.get("FAY_MCP_MSG_PATH", "/messages")
server = None # Removed global singleton
sse_transport = SseServerTransport(MSG_PATH)
# Aggregated tool index: namespaced_tool_name -> (server_id, tool_name)
_aggregated_index: Dict[str, Tuple[int, str]] = {}
def _text_content(text: str) -> TextContent:
try:
return TextContent(type="text", text=text)
except Exception:
return {"type": "text", "text": text} # type: ignore[return-value]
TOOLS: list[Tool] = [
Tool(
name="broadcast_message",
description="通过 Fay 的 /transparent-pass 透传文本/音频。",
inputSchema={
"type": "object",
"properties": {
"text": {"type": "string", "description": "要广播的文本audio_url为空时必填"},
"audio_url": {"type": "string", "description": "可选音频 URL"},
"user": {"type": "string", "description": "用户标识名称,默认 FAY_BROADCAST_USER 或 User"},
"speaker": {
"type": "string",
"description": "发言人显示名,输出为\"{speaker}说:{text}\"",
},
"queue": {"type": "boolean", "description": "是否走队列播放,默认 false"},
"queue_playback": {"type": "boolean", "description": "兼容参数,等同 queue"},
"enqueue": {"type": "boolean", "description": "兼容参数,等同 queue"},
"mode": {"type": "string", "description": "兼容参数,值为 queue 时启用队列播放"},
},
"required": [],
},
)
]
async def _handle_list_tools() -> list[Tool]:
    # Local broadcast tool + aggregated view of Fay's currently online MCP tools (namespaced)
aggregated = []
try:
aggregated = _build_aggregated_tools()
except Exception as e:
log.warning(f"Failed to build aggregated tools: {e}")
return TOOLS + aggregated
def _as_bool(value: Any) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
if isinstance(value, (int, float)):
return value != 0
if isinstance(value, str):
v = value.strip().lower()
if v == "":
return False
return v in {"1", "true", "yes", "on", "y", "queue"}
return bool(value)
def _parse_arguments(arguments: Dict[str, Any]) -> Tuple[str, str, str, str, bool]:
text = str(arguments.get("text", "") or "").strip()
audio_url = str(arguments.get("audio_url", "") or "").strip()
user = str(arguments.get("user", "") or "").strip() or DEFAULT_USER
speaker = str(arguments.get("speaker", "") or "").strip() or DEFAULT_SPEAKER
if "queue" in arguments:
queue = _as_bool(arguments.get("queue"))
elif "queue_playback" in arguments:
queue = _as_bool(arguments.get("queue_playback"))
elif "enqueue" in arguments:
queue = _as_bool(arguments.get("enqueue"))
elif "mode" in arguments:
queue = str(arguments.get("mode", "") or "").strip().lower() == "queue"
else:
queue = False
return text, audio_url, user, speaker, queue
def _build_aggregated_tools() -> List[Tool]:
"""
将 Fay 已连接的 MCP 工具聚合,对外暴露为 namespaced 名称:
<server_id>:<tool_name>
"""
tools: List[Tool] = []
_aggregated_index.clear()
server_name_map = {s.get("id"): s.get("name", f"Server{s.get('id')}") for s in mcp_service.mcp_servers or []}
for entry in tool_registry.get_enabled_tools():
server_id = entry.get("server_id")
tool_name = entry.get("name")
if server_id is None or not tool_name:
continue
agg_name = f"{server_id}:{tool_name}"
desc = entry.get("description", "")
server_label = server_name_map.get(server_id, f"Server {server_id}")
agg_desc = f"{desc} [via {server_label}]"
input_schema = entry.get("inputSchema") or {}
tool = Tool(
name=agg_name,
description=agg_desc,
inputSchema=input_schema if isinstance(input_schema, dict) else {},
)
tools.append(tool)
_aggregated_index[agg_name] = (server_id, tool_name)
return tools
async def _send_broadcast(payload: Dict[str, Any]) -> Tuple[bool, str]:
def _post() -> Tuple[bool, str]:
body = json.dumps(payload, ensure_ascii=True).encode("utf-8")
resp = requests.post(
DEFAULT_API_URL,
data=body,
headers={"Content-Type": "application/json; charset=utf-8"},
timeout=REQUEST_TIMEOUT,
)
try:
data = resp.json()
except Exception:
data = None
if resp.ok:
if isinstance(data, dict):
msg = data.get("message") or data.get("msg") or ""
code = data.get("code")
if isinstance(code, int) and code >= 400:
return False, msg or f"透传失败HTTP码 {code}"
return True, msg or "已发送透传请求。"
return True, "已发送透传请求。"
err_detail = ""
if isinstance(data, dict):
err_detail = data.get("message") or data.get("error") or data.get("msg") or ""
if not err_detail:
err_detail = resp.text
return False, f"HTTP {resp.status_code}: {err_detail}"
try:
return await asyncio.to_thread(_post)
except Exception as e:
return False, f"{type(e).__name__}: {e}"
async def _handle_call_tool(name: str, arguments: Dict[str, Any]) -> list[TextContent]:
    # Local broadcast
if name == "broadcast_message":
text, audio_url, user, speaker, queue = _parse_arguments(arguments or {})
if not text and not audio_url:
return [_text_content("text 或 audio_url 至少需提供一个。")]
payload: Dict[str, Any] = {"user": user}
if text:
payload["text"] = f"{speaker}\u8bf4\uff1a{text}"
if audio_url:
payload["audio"] = audio_url
if queue:
payload["queue"] = True
payload["queue_playback"] = True
payload["mode"] = "queue"
ok, message = await _send_broadcast(payload)
prefix = "成功" if ok else "失败"
return [_text_content(f"{prefix}: {message}")]
target = _aggregated_index.get(name)
if not target:
return [_text_content(f"未知工具: {name}")]
server_id, tool_name = target
try:
success, result = await asyncio.to_thread(mcp_service.call_mcp_tool, server_id, tool_name, arguments or {})
if not success:
return [_text_content(f"error: {result}")]
return _normalize_result(result)
except Exception as e:
return [_text_content(f"error: {type(e).__name__}: {e}")]
def _normalize_result(result: Any) -> List[TextContent]:
"""
将上游返回的任意对象转换为 MCP 文本内容列表。
"""
    # If it is already TextContent (or a list), return it directly
try:
from mcp.types import TextContent
if isinstance(result, TextContent):
return [result]
except Exception:
pass
if isinstance(result, list):
contents: List[TextContent] = []
for item in result:
try:
if hasattr(item, "type") and getattr(item, "type", "") == "text" and hasattr(item, "text"):
contents.append(item)
continue
except Exception:
pass
try:
if isinstance(item, dict) and item.get("type") == "text":
contents.append(TextContent(type="text", text=str(item.get("text", "")))) # type: ignore
continue
except Exception:
pass
contents.append(_text_content(str(item)))
return contents
return [_text_content(str(result))]
async def sse_endpoint(request):
    # Create a new Server instance per connection to support concurrency
local_server = Server(SERVER_NAME)
    # Register tool handlers
local_server.list_tools()(_handle_list_tools)
local_server.call_tool()(_handle_call_tool)
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
await local_server.run(read_stream, write_stream, local_server.create_initialization_options())
    # Return an empty response when the client disconnects, avoiding NoneType issues
return Response()
routes = [
Route(SSE_PATH, sse_endpoint, methods=["GET"]),
Mount(MSG_PATH, app=sse_transport.handle_post_message),
]
app = Starlette(routes=routes)
def main():
log.info(f"SSE MCP server started at http://{HOST}:{PORT}{SSE_PATH}")
log.info(f"Message endpoint mounted at {MSG_PATH}")
uvicorn.run(app, host=HOST, port=PORT, log_level="info")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
pass
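For completeness, a hedged example of invoking the broadcast tool from an MCP client. This assumes the official `mcp` Python SDK and the default host/port above; the text and speaker values are illustrative:

```python
# Illustrative MCP client call against the SSE server above.
import asyncio
from mcp import ClientSession
from mcp.client.sse import sse_client

async def main():
    async with sse_client("http://127.0.0.1:8765/sse") as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool(
                "broadcast_message",
                {"text": "hello", "speaker": "系统", "queue": True},  # made-up values
            )
            print(result)

asyncio.run(main())
```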

Binary image file not shown (before: 8.6 KiB, after: 8.2 KiB).

View File

@@ -10,8 +10,17 @@ from flask_cors import CORS
import requests
import datetime
import pytz
import logging
import uuid
from urllib.parse import urlparse, urljoin
try:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
except Exception:
ChatOpenAI = None
HumanMessage = None
SystemMessage = None
AIMessage = None
import fay_booter
from tts import tts_voice
@@ -101,21 +110,77 @@ def _as_bool(value):
return value.strip().lower() in ("1", "true", "yes", "y", "on")
return False
def _build_llm_url(base_url: str) -> str:
if not base_url:
return ""
url = base_url.rstrip("/")
if url.endswith("/chat/completions"):
return url
if url.endswith("/v1"):
return url + "/chat/completions"
return url + "/v1/chat/completions"
def _normalize_openai_content(content):
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
text = item.get("text")
if text is not None:
parts.append(str(text))
continue
if "content" in item:
parts.append(_normalize_openai_content(item.get("content")))
continue
parts.append(str(item))
return "".join(parts)
if isinstance(content, dict):
if "text" in content:
return _normalize_openai_content(content.get("text"))
if "content" in content:
return _normalize_openai_content(content.get("content"))
return str(content)
def _build_langchain_messages(messages):
normalized = []
if isinstance(messages, str):
normalized.append(HumanMessage(content=messages))
return normalized
if not isinstance(messages, list):
return normalized
for msg in messages:
if not isinstance(msg, dict):
continue
role = str(msg.get("role", "user")).strip().lower()
content = _normalize_openai_content(msg.get("content"))
if content is None:
content = ""
if role == "system":
normalized.append(SystemMessage(content=content))
elif role == "assistant":
normalized.append(AIMessage(content=content))
else:
normalized.append(HumanMessage(content=content))
return normalized
def _safe_text_from_chunk(chunk):
if chunk is None:
return ""
value = getattr(chunk, "content", "")
return _normalize_openai_content(value)
def _build_embedding_url(base_url: str) -> str:
if not base_url:
return ""
url = base_url.rstrip("/")
if url.endswith("/v1/embeddings") or url.endswith("/embeddings"):
return url
@@ -123,9 +188,20 @@ def _build_embedding_url(base_url: str) -> str:
return url[:-len("/v1/chat/completions")] + "/v1/embeddings"
if url.endswith("/chat/completions"):
return url[:-len("/chat/completions")] + "/embeddings"
if url.endswith("/v1"):
return url + "/embeddings"
return url + "/v1/embeddings"
if url.endswith("/v1"):
return url + "/embeddings"
return url + "/v1/embeddings"
def _build_langchain_base_url(base_url: str) -> str:
if not base_url:
return ""
url = base_url.rstrip("/")
if url.endswith("/v1/chat/completions"):
return url[:-len("/chat/completions")]
if url.endswith("/chat/completions"):
return url[:-len("/chat/completions")]
return url
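Taken together, the three helpers normalize whatever base URL is configured. Expected behavior on illustrative inputs (the host is made up):

```python
# Sanity checks for the URL helpers above; inputs are illustrative.
assert _build_llm_url("https://api.example.com/v1") == "https://api.example.com/v1/chat/completions"
assert _build_llm_url("https://api.example.com/v1/chat/completions") == "https://api.example.com/v1/chat/completions"
assert _build_embedding_url("https://api.example.com/v1/chat/completions") == "https://api.example.com/v1/embeddings"
assert _build_langchain_base_url("https://api.example.com/v1/chat/completions") == "https://api.example.com/v1"
```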
@__app.route('/api/submit', methods=['post'])
def api_submit():
@@ -371,61 +447,111 @@ def api_send_v1_chat_completions():
if not data:
return jsonify({'error': 'missing request body'})
try:
model = data.get('model', 'fay')
if model == 'llm':
try:
config_util.load_config()
llm_url = _build_llm_url(config_util.gpt_base_url)
api_key = config_util.key_gpt_api_key
model_engine = config_util.gpt_model_engine
except Exception as exc:
return jsonify({'error': f'LLM config load failed: {exc}'}), 500
if not llm_url:
return jsonify({'error': 'LLM base_url is not configured'}), 500
payload = dict(data)
if payload.get('model') == 'llm' and model_engine:
payload['model'] = model_engine
stream_requested = _as_bool(payload.get('stream', False))
headers = {'Content-Type': 'application/json'}
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
try:
if stream_requested:
resp = requests.post(llm_url, headers=headers, json=payload, stream=True)
def generate():
try:
for chunk in resp.iter_content(chunk_size=8192):
if not chunk:
continue
yield chunk
finally:
resp.close()
content_type = resp.headers.get("Content-Type", "text/event-stream")
if "charset=" not in content_type.lower():
content_type = f"{content_type}; charset=utf-8"
return Response(
generate(),
status=resp.status_code,
content_type=content_type,
)
resp = requests.post(llm_url, headers=headers, json=payload, timeout=60)
content_type = resp.headers.get("Content-Type", "application/json")
if "charset=" not in content_type.lower():
content_type = f"{content_type}; charset=utf-8"
return Response(
resp.content,
status=resp.status_code,
content_type=content_type,
)
except Exception as exc:
return jsonify({'error': f'LLM request failed: {exc}'}), 500
model = data.get('model', 'fay')
if model == 'llm':
if ChatOpenAI is None or HumanMessage is None:
return jsonify({'error': 'langchain_openai or langchain_core is not available'}), 500
try:
config_util.load_config()
api_key = config_util.key_gpt_api_key
model_engine = config_util.gpt_model_engine
base_url = _build_langchain_base_url(config_util.gpt_base_url)
except Exception as exc:
return jsonify({'error': f'LLM config load failed: {exc}'}), 500
if not base_url:
return jsonify({'error': 'LLM base_url is not configured'}), 500
payload = dict(data)
stream_requested = _as_bool(payload.get('stream', False))
model_name = model_engine or payload.get('model')
lc_messages = _build_langchain_messages(payload.get('messages', []))
if not lc_messages:
return jsonify({'error': 'messages is required'}), 400
llm_kwargs = {
"model": model_name,
"base_url": base_url,
"api_key": api_key,
"streaming": bool(stream_requested),
}
if payload.get("temperature") is not None:
llm_kwargs["temperature"] = payload.get("temperature")
if payload.get("max_tokens") is not None:
llm_kwargs["max_tokens"] = payload.get("max_tokens")
model_kwargs = {}
if payload.get("top_p") is not None:
model_kwargs["top_p"] = payload.get("top_p")
if model_kwargs:
llm_kwargs["model_kwargs"] = model_kwargs
try:
llm_client = ChatOpenAI(**llm_kwargs)
run_cfg = {
"tags": ["fay", "api", "model-llm"],
"metadata": {"entrypoint": "api_send_v1_chat_completions", "model_alias": "llm"},
}
if stream_requested:
stream_id = "chatcmpl-" + str(uuid.uuid4())
def generate():
try:
for chunk in llm_client.stream(lc_messages, config=run_cfg):
text_piece = _safe_text_from_chunk(chunk)
if text_piece is None or text_piece == "":
continue
message = {
"id": stream_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"delta": {"content": text_piece},
"index": 0,
"finish_reason": None
}
]
}
yield f"data: {json.dumps(message, ensure_ascii=False)}\n\n"
final_message = {
"id": stream_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"delta": {},
"index": 0,
"finish_reason": "stop"
}
]
}
yield f"data: {json.dumps(final_message, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
finally:
pass
return Response(generate(), content_type="text/event-stream; charset=utf-8")
ai_resp = llm_client.invoke(lc_messages, config=run_cfg)
answer_text = _normalize_openai_content(getattr(ai_resp, "content", ""))
return jsonify({
"id": "chatcmpl-" + str(uuid.uuid4()),
"object": "chat.completion",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": answer_text
},
"finish_reason": "stop"
}
]
})
except Exception as exc:
return jsonify({'error': f'LLM request failed: {exc}'}), 500
last_content = ""
username = "User"
@@ -1099,31 +1225,97 @@ def api_toggle_microphone():
# Message passthrough endpoint
@__app.route('/transparent-pass', methods=['post'])
def transparent_pass():
try:
data = request.form.get('data')
if data is None:
data = request.get_json()
else:
data = json.loads(data)
username = data.get('user', 'User')
response_text = data.get('text', None)
audio_url = data.get('audio', None)
if response_text or audio_url:
            # A new message arrived: immediately interrupt all of this user's previous processing (text stream + audio queue)
util.printInfo(1, username, f'[API中断] 新消息到达,完整中断用户 {username} 之前的所有处理')
util.printInfo(1, username, f'[API中断] 用户 {username} 的文本流和音频队列已清空,准备处理新消息')
interact = Interact('transparent_pass', 2, {'user': username, 'text': response_text, 'audio': audio_url, 'isend':True, 'isfirst':True})
util.printInfo(1, username, '透传播放:{}{}'.format(response_text, audio_url), time.time())
success = fay_booter.feiFei.on_interact(interact)
if (success == 'success'):
return jsonify({'code': 200, 'message' : '成功'})
return jsonify({'code': 500, 'message' : '未知原因出错'})
except Exception as e:
return jsonify({'code': 500, 'message': f'出错: {e}'}), 500
@__app.route('/transparent-pass', methods=['post'])
def transparent_pass():
try:
data = request.form.get('data')
if data is None:
data = request.get_json(silent=True) or {}
else:
data = json.loads(data)
if isinstance(data, dict):
nested_data = data.get('data')
if isinstance(nested_data, dict):
data = nested_data
elif isinstance(nested_data, str):
nested_data = nested_data.strip()
if nested_data:
try:
data = json.loads(nested_data)
except Exception:
pass
if not isinstance(data, dict):
data = {}
username = data.get('user', 'User')
response_text = data.get('text', None)
audio_url = data.get('audio', None)
if isinstance(audio_url, str):
audio_url = audio_url.strip()
if audio_url:
parsed_audio = urlparse(audio_url)
if not parsed_audio.scheme:
if audio_url.startswith('//'):
audio_url = 'http:' + audio_url
else:
base_url = ''
origin = (request.headers.get('Origin') or '').strip()
referer = (request.headers.get('Referer') or '').strip()
if origin:
parsed_origin = urlparse(origin)
if parsed_origin.scheme and parsed_origin.netloc:
base_url = f'{parsed_origin.scheme}://{parsed_origin.netloc}/'
if (not base_url) and referer:
parsed_referer = urlparse(referer)
if parsed_referer.scheme and parsed_referer.netloc:
base_url = f'{parsed_referer.scheme}://{parsed_referer.netloc}/'
if not base_url:
base_url = request.host_url
audio_url = urljoin(base_url, audio_url)
else:
audio_url = None
queue_mode = _as_bool(data.get('queue', False))
if not queue_mode:
queue_mode = _as_bool(data.get('queue_playback', data.get('enqueue', False)))
if not queue_mode:
queue_mode = str(data.get('mode', '')).strip().lower() == 'queue'
if not queue_mode:
            queue_mode = _as_bool(data.get('qutue', False))  # also accept the misspelled 'qutue' key, presumably for older clients
if response_text or audio_url:
if queue_mode:
interact = Interact('transparent_pass', 2, {
'user': username,
'text': response_text,
'audio': audio_url,
'isend': True,
'isfirst': True,
'no_reply': True,
'queue': True,
'queue_playback': True
})
else:
            util.printInfo(1, username, f'[API中断] 新消息到达,完整中断用户 {username} 之前的所有处理')
            util.printInfo(1, username, f'[API中断] 用户 {username} 的文本流和音频队列已清空,准备处理新消息')
interact = Interact('transparent_pass', 2, {
'user': username,
'text': response_text,
'audio': audio_url,
'isend': True,
'isfirst': True
})
        util.printInfo(1, username, '透传播放:{},{}'.format(response_text, audio_url), time.time())
success = fay_booter.feiFei.on_interact(interact)
if success == 'success':
                return jsonify({'code': 200, 'message': '成功'})
            return jsonify({'code': 500, 'message': '未知原因出错'})
    except Exception as e:
        return jsonify({'code': 500, 'message': f'出错: {e}'}), 500
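A hedged client-side example of the endpoint above, using queue playback; URL and text are illustrative:

```python
import requests

resp = requests.post(
    "http://127.0.0.1:5000/transparent-pass",
    json={"user": "User", "text": "大家好", "queue": True},
    timeout=10,
)
print(resp.json())   # expected on success: {'code': 200, 'message': '成功'}
```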
# Clear-memory API
@__app.route('/api/clear-memory', methods=['POST'])
def api_clear_memory():
try:

Binary image file not shown (before: 8.6 KiB, after: 8.2 KiB).

View File

@@ -2,7 +2,13 @@
import os
import sys
os.environ['PATH'] += os.pathsep + os.path.join(os.getcwd(), "test", "ovr_lipsync", "ffmpeg", "bin")
def _resolve_runtime_dir():
if hasattr(sys, "_MEIPASS"):
return os.path.abspath(sys._MEIPASS)
return os.path.abspath(os.path.dirname(__file__))
_RUNTIME_DIR = _resolve_runtime_dir()
os.environ['PATH'] += os.pathsep + os.path.join(_RUNTIME_DIR, "test", "ovr_lipsync", "ffmpeg", "bin")
def _preload_config_center(argv):
for i, arg in enumerate(argv):

View File

@@ -1,35 +1,35 @@
requests
numpy
pyaudio~=0.2.11
websockets~=10.4
ws4py~=0.5.1
PyQt5==5.15.10
PyQt5-sip==12.13.0
PyQtWebEngine==5.15.6
flask~=3.0.0
openpyxl~=3.0.9
flask_cors~=3.0.10
websocket-client
azure-cognitiveservices-speech
aliyun-python-sdk-core
simhash
pytz
gevent
edge_tts
pydub
tenacity==8.2.3
pygame
scipy
flask-httpauth
opencv-python
psutil
langchain
langchain_openai
langgraph
bs4
schedule
mcp
python-docx
python-pptx
chromadb
requests
numpy
pyaudio~=0.2.11
websockets~=10.4
ws4py~=0.5.1
#PyQt5==5.15.10
#PyQt5-sip==12.13.0
#PyQtWebEngine==5.15.6
flask~=3.0.0
openpyxl~=3.0.9
flask_cors~=3.0.10
websocket-client
azure-cognitiveservices-speech
aliyun-python-sdk-core
simhash
pytz
gevent
edge_tts
pydub
tenacity==8.2.3
pygame
scipy
flask-httpauth
opencv-python
psutil
langchain
langchain_openai
langgraph
bs4
schedule
mcp
python-docx
python-pptx
chromadb
sentence_transformers

View File

@@ -1,91 +1,95 @@
import subprocess
import time
import os
os.environ['PATH'] += os.pathsep + os.path.join(os.getcwd(), "test", "ovr_lipsync", "ffmpeg", "bin")
import sys
_RUNTIME_DIR = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
if hasattr(sys, "_MEIPASS"):
_RUNTIME_DIR = os.path.abspath(sys._MEIPASS)
os.environ['PATH'] += os.pathsep + os.path.join(_RUNTIME_DIR, "test", "ovr_lipsync", "ffmpeg", "bin")
from pydub import AudioSegment
import json
def list_files(dir_path):
for root, dirs, files in os.walk(dir_path):
for file in files:
print(os.path.join(root, file))
class LipSyncGenerator:
def list_files(dir_path):
for root, dirs, files in os.walk(dir_path):
for file in files:
print(os.path.join(root, file))
class LipSyncGenerator:
def __init__(self):
self.viseme_em = [
"sil", "PP", "FF", "TH", "DD",
"kk", "CH", "SS", "nn", "RR",
"aa", "E", "ih", "oh", "ou"]
self.viseme = []
self.exe_path = os.path.join(os.getcwd(), "test", "ovr_lipsync", "ovr_lipsync_exe", "ProcessWAV.exe")
def run_exe_and_get_output(self, arguments):
process = subprocess.Popen([self.exe_path] + arguments, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
while True:
output = process.stdout.readline()
if output == b'' and process.poll() is not None:
break
if output:
self.viseme.append(output.strip().decode())
rc = process.poll()
return rc
def filter(self, viseme):
new_viseme = []
for v in viseme:  # iterate the passed-in list (was self.viseme, which ignored the argument)
if v in self.viseme_em:
new_viseme.append(v)
return new_viseme
def generate_visemes(self, wav_filepath):
if wav_filepath.endswith(".mp3"):
wav_filepath = self.convert_mp3_to_wav(wav_filepath)
arguments = ["--print-viseme-name", wav_filepath]
self.run_exe_and_get_output(arguments)
return self.filter(self.viseme)
def consolidate_visemes(self, viseme_list):
if not viseme_list:
return []
result = []
current_viseme = viseme_list[0]
count = 1
for viseme in viseme_list[1:]:
if viseme == current_viseme:
count += 1
else:
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
current_viseme = viseme
count = 1
# Add the last viseme to the result
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
new_data = []
for i in range(len(result)):
if result[i]['Time'] < 30:
if len(new_data) > 0:
new_data[-1]['Time'] += result[i]['Time']
else:
new_data.append(result[i])
return new_data
def convert_mp3_to_wav(self, mp3_filepath):
audio = AudioSegment.from_mp3(mp3_filepath)
# 使用 set_frame_rate 方法设置采样率
audio = audio.set_frame_rate(44100)
wav_filepath = mp3_filepath.rsplit(".", 1)[0] + ".wav"
audio.export(wav_filepath, format="wav")
return wav_filepath
if __name__ == "__main__":
start_time = time.time()
lip_sync_generator = LipSyncGenerator()
viseme_list = lip_sync_generator.generate_visemes("E:\\github\\Fay\\samples\\fay-man.wav")
print(viseme_list)
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
print(json.dumps(consolidated_visemes))
print(time.time() - start_time)
self.exe_path = os.path.join(_RUNTIME_DIR, "test", "ovr_lipsync", "ovr_lipsync_exe", "ProcessWAV.exe")
def run_exe_and_get_output(self, arguments):
process = subprocess.Popen([self.exe_path] + arguments, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
while True:
output = process.stdout.readline()
if output == b'' and process.poll() is not None:
break
if output:
self.viseme.append(output.strip().decode())
rc = process.poll()
return rc
def filter(self, viseme):
new_viseme = []
for v in viseme:  # iterate the passed-in list (was self.viseme, which ignored the argument)
if v in self.viseme_em:
new_viseme.append(v)
return new_viseme
def generate_visemes(self, wav_filepath):
if wav_filepath.endswith(".mp3"):
wav_filepath = self.convert_mp3_to_wav(wav_filepath)
arguments = ["--print-viseme-name", wav_filepath]
self.run_exe_and_get_output(arguments)
return self.filter(self.viseme)
def consolidate_visemes(self, viseme_list):
if not viseme_list:
return []
result = []
current_viseme = viseme_list[0]
count = 1
for viseme in viseme_list[1:]:
if viseme == current_viseme:
count += 1
else:
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
current_viseme = viseme
count = 1
# Add the last viseme to the result
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
new_data = []
for i in range(len(result)):
if result[i]['Time'] < 30:
if len(new_data) > 0:
new_data[-1]['Time'] += result[i]['Time']
else:
new_data.append(result[i])
return new_data
def convert_mp3_to_wav(self, mp3_filepath):
audio = AudioSegment.from_mp3(mp3_filepath)
# 使用 set_frame_rate 方法设置采样率
audio = audio.set_frame_rate(44100)
wav_filepath = mp3_filepath.rsplit(".", 1)[0] + ".wav"
audio.export(wav_filepath, format="wav")
return wav_filepath
if __name__ == "__main__":
start_time = time.time()
lip_sync_generator = LipSyncGenerator()
viseme_list = lip_sync_generator.generate_visemes("E:\\github\\Fay\\samples\\fay-man.wav")
print(viseme_list)
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
print(json.dumps(consolidated_visemes))
print(time.time() - start_time)
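
A worked example of the run-length consolidation performed by consolidate_visemes above, as a standalone sketch (each viseme frame spans roughly 33 ms):

def consolidate(frames, frame_ms=33):
    # Collapse runs of identical visemes into {Lip, Time} entries.
    result = []
    for v in frames:
        if result and result[-1]["Lip"] == v:
            result[-1]["Time"] += frame_ms
        else:
            result.append({"Lip": v, "Time": frame_ms})
    return result

print(consolidate(["sil", "sil", "aa", "aa", "aa", "E"]))
# [{'Lip': 'sil', 'Time': 66}, {'Lip': 'aa', 'Time': 99}, {'Lip': 'E', 'Time': 33}]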

View File

@@ -278,11 +278,11 @@ def load_config(force_reload=False):
root_config_json_exists = os.path.exists(default_config_json_path)
root_config_complete = root_system_conf_exists and root_config_json_exists
# 构建system.conf和config.json的完整路径
config_center_fallback = False
if using_config_center:
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
# 构建system.conf和config.json相关路径.
config_center_fallback = False
if using_config_center:
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
if (
system_conf_path is None
@@ -308,29 +308,34 @@ def load_config(force_reload=False):
loaded_from_api = False
api_attempted = False
if using_config_center:
if explicit_config_center:
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
else:
util.log(1, f"未检测到本地system.conf或config.json尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
if explicit_config_center:
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
else:
util.log(1, f"未检测到本地system.conf或config.json尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
api_attempted = True
if api_config:
util.log(1, "成功从配置中心加载配置")
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
if config_center_fallback:
_bootstrap_loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
forced_loaded = True
# 将配置中心配置缓存到本地文件.
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(
api_config,
system_conf_path,
config_json_path,
save_config_json=not os.path.exists(config_json_path)
)
forced_loaded = True
_warn_public_config_once()
else:
util.log(2, "配置中心加载失败,尝试使用缓存配置")
util.log(2, "配置中心加载失败,尝试使用缓存配置")
sys_conf_exists = os.path.exists(system_conf_path)
config_json_exists = os.path.exists(config_json_path)
@@ -339,38 +344,48 @@ def load_config(force_reload=False):
if (not sys_conf_exists or not config_json_exists) and not forced_loaded:
if using_config_center:
if not api_attempted:
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
api_attempted = True
if api_config:
util.log(1, "成功从配置中心加载配置")
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
if config_center_fallback:
_bootstrap_loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
# 将配置中心配置缓存到本地文件.
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(
api_config,
system_conf_path,
config_json_path,
save_config_json=not os.path.exists(config_json_path)
)
_warn_public_config_once()
else:
# 使用提取的项目ID或全局项目ID
util.log(1, f"本地配置文件不完整{system_conf_path if not sys_conf_exists else ''}{'' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在)尝试从API加载配置...")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
# 使用项目配置或全局项目配置作为回退来源.
util.log(1, f"本地配置文件不完整尝试从API加载配置...")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
if api_config:
util.log(1, "成功从配置中心加载配置")
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
# 将配置中心配置缓存到本地文件.
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(
api_config,
system_conf_path,
config_json_path,
save_config_json=not os.path.exists(config_json_path)
)
_warn_public_config_once()
@@ -378,17 +393,17 @@ def load_config(force_reload=False):
config_json_exists = os.path.exists(config_json_path)
if using_config_center and (not sys_conf_exists or not config_json_exists):
if _last_loaded_config is not None and _last_loaded_from_api:
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
return _last_loaded_config
if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists):
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
if cache_ready:
util.log(2, "配置中心不可用,回退使用缓存配置")
util.log(2, "配置中心不可用,回退使用缓存配置")
using_config_center = False
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
util.log(2, "配置中心不可用,回退使用本地配置文件")
util.log(2, "配置中心不可用,回退使用本地配置文件")
using_config_center = False
system_conf_path = default_system_conf_path
config_json_path = default_config_json_path
@@ -497,31 +512,33 @@ def load_config(force_reload=False):
return config_dict
def save_api_config_to_local(api_config, system_conf_path, config_json_path):
"""
将API加载的配置保存到本地文件
def save_api_config_to_local(api_config, system_conf_path, config_json_path, save_config_json=True):
"""
Persist API config to local files.
Args:
api_config: API加载的配置字典
system_conf_path: system.conf文件路径
config_json_path: config.json文件路径
Args:
api_config: API response dict.
system_conf_path: Path to system.conf.
config_json_path: Path to config.json.
save_config_json: Whether to write config.json.
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(system_conf_path), exist_ok=True)
os.makedirs(os.path.dirname(config_json_path), exist_ok=True)
# 保存system.conf
with open(system_conf_path, 'w', encoding='utf-8') as f:
api_config['system_config'].write(f)
# 保存config.json
with codecs.open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(api_config['config'], f, ensure_ascii=False, indent=4)
util.log(1, f"已将配置中心配置缓存到本地文件: {system_conf_path}{config_json_path}")
except Exception as e:
util.log(2, f"保存配置中心配置缓存到本地文件时出错: {str(e)}")
# 确保目录存在.
os.makedirs(os.path.dirname(system_conf_path), exist_ok=True)
os.makedirs(os.path.dirname(config_json_path), exist_ok=True)
# 始终刷新 system.conf.
with open(system_conf_path, 'w', encoding='utf-8') as f:
api_config['system_config'].write(f)
# 默认只在首次下载时保存config.json.
if save_config_json:
with codecs.open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(api_config['config'], f, ensure_ascii=False, indent=4)
util.log(1, f"已将配置中心配置缓存到本地文件: {system_conf_path}{config_json_path}")
except Exception as e:
util.log(2, f"保存配置中心配置到本地文件时出错: {str(e)}")
@synchronized
def save_config(config_data):

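The save_config_json flag above implements commit note 5 (the persona configuration is preserved): system.conf always tracks the config center, while config.json is written only when no local copy exists yet. A condensed sketch of the policy (mirrors the callers, which pass save_config_json=not os.path.exists(config_json_path)):

import codecs
import json
import os

def cache_api_config(api_config, system_conf_path, config_json_path):
    # system.conf: always refreshed from the config center.
    with open(system_conf_path, 'w', encoding='utf-8') as f:
        api_config['system_config'].write(f)
    # config.json: first download only, so local persona edits survive refreshes.
    if not os.path.exists(config_json_path):
        with codecs.open(config_json_path, 'w', encoding='utf-8') as f:
            json.dump(api_config['config'], f, ensure_ascii=False, indent=4)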
View File

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import re
import time
from utils import util
from core import stream_manager
@@ -27,6 +28,20 @@ class StreamTextProcessor:
self.max_cache_size = max_cache_size
# 常用中英文分句标点(UTF-8)
self.punctuation_marks = ["。", "!", "?", "、", ",", ";", ":", ".", "!", "?", "\n"]
self.punctuation_mark_set = set(self.punctuation_marks)
self.url_regex_list = [
re.compile(
r"(?i)\b(?:https?://|ftp://|file://|www\.)[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]+"
),
re.compile(
r"(?i)\b[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?"
r"(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+"
r"(?::\d{2,5})?"
r"(?:/[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]*)?"
r"(?:\?[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]*)?"
r"(?:#[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]*)?"
),
]
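# Pattern 1 catches explicit URLs (scheme or "www." prefix); pattern 2 catches
# bare hosts such as "docs.example.com:8080/path?q=1#frag". Both exclude CJK
# punctuation, so a sentence-ending "。" right after a URL still splits.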
def process_stream_text(self, text, username, is_qa=False, session_type="stream"):
"""
@@ -87,7 +102,7 @@ class StreamTextProcessor:
# 动态缓存大小检查
if len(accumulated_text) > self.max_cache_size:
util.log(1, f"处理过程中缓存溢出,强制发送剩余文本")
util.log(1, "处理过程中缓存溢出,强制发送剩余文本")
break
iteration_count += 1
@@ -178,20 +193,95 @@ class StreamTextProcessor:
"""
try:
indices = []
for punct in self.punctuation_marks:
try:
index = text.find(punct)
if index != -1:
indices.append(index)
except Exception as e:
util.log(1, f"查找标点符号 '{punct}' 时出错: {str(e)}")
return sorted([i for i in indices if i != -1])
url_spans = self._find_url_spans(text)
for index, ch in enumerate(text):
if ch not in self.punctuation_mark_set:
continue
if self._is_in_spans(index, url_spans):
continue
if ch == "." and self._is_protected_dot(text, index):
continue
indices.append(index)
return indices
except Exception as e:
util.log(1, f"查找标点符号时出错: {str(e)}")
return []
def _find_url_spans(self, text):
spans = []
for regex in self.url_regex_list:
for match in regex.finditer(text):
start, end = match.span()
if end > start:
spans.append((start, end))
if not spans:
return []
spans.sort(key=lambda item: item[0])
merged = [spans[0]]
for start, end in spans[1:]:
last_start, last_end = merged[-1]
if start <= last_end:
merged[-1] = (last_start, max(last_end, end))
else:
merged.append((start, end))
return merged
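# e.g. spans [(0, 5), (3, 9), (12, 15)] merge to [(0, 9), (12, 15)], so
# overlapping matches from the two patterns collapse into one span per URL.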
@staticmethod
def _is_in_spans(index, spans):
for start, end in spans:
if start <= index < end:
return True
if index < start:
break
return False
def _is_protected_dot(self, text, index):
"""
判断英文句点是否位于不应截断的 token 内(如版本号、URL、域名等)
"""
try:
prev_char = text[index - 1] if index > 0 else ""
next_char = text[index + 1] if (index + 1) < len(text) else ""
# 数字点位:1.2 / 3.14 / 192.168.1.1
if prev_char.isdigit() and next_char.isdigit():
return True
token = self._extract_token_around(text, index).lower()
if not token:
return False
# URL 写法
if "://" in token or token.startswith("www."):
return True
# 版本号写法:v1.2.3 / 1.2.3
if re.match(r"^[a-z]*\d+\.\d+(\.\d+)*[a-z]*$", token):
return True
# 域名/主机名(可带路径)
if re.match(r"^[a-z0-9-]+(\.[a-z0-9-]+)+(/\S*)?$", token):
return True
return False
except Exception:
return False
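# Example: in "版本v1.2.3已发布." the dots inside "v1.2.3" sit between digits and
# are protected, while the trailing "." belongs to a token that matches neither
# the version nor the domain pattern, so it remains a valid split point.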
def _extract_token_around(self, text, index):
"""
提取句点所在连续 token(按空白与中英文标点分隔)
"""
separators = set(" \t\r\n\"'()[]{}<>,!?;:" + "\uFF0C\u3002\uFF01\uFF1F\uFF1B\uFF1A\u3001")
left = index
right = index + 1
while left > 0 and text[left - 1] not in separators:
left -= 1
while right < len(text) and text[right] not in separators:
right += 1
return text[left:right]
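# End-to-end illustration of the index-collection loop above (hypothetical
# input): for "已更新至 v2.1.3。详见 https://docs.example.com/guide。" only the two
# "。" indices are returned; every "." lies between digits or inside a span
# reported by _find_url_spans.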
def _send_fallback_text(self, text, username, state_manager, conversation_id):
"""
备用发送方案:直接发送完整文本(含首尾标记)