mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
v4.3.1
1.增加音频缓存功能,降低tts费率; 2.优化透传接口流式断句逻辑,域名、版本号等不断句; 3.优化数字人接口流式文本输出顺序; 4.llm透传功能接入langsmith,配置环境变量后可通过langsmith平台调优prompt; 5.优化配置中心加载逻辑,人设配置依然保留。
This commit is contained in:
368
asr/ali_nls.py
368
asr/ali_nls.py
@@ -1,190 +1,206 @@
|
||||
from threading import Thread
|
||||
from threading import Lock
|
||||
import websocket
|
||||
import json
|
||||
import time
|
||||
import ssl
|
||||
import wave
|
||||
import _thread as thread
|
||||
from aliyunsdkcore.client import AcsClient
|
||||
from aliyunsdkcore.request import CommonRequest
|
||||
|
||||
from core import wsa_server
|
||||
from scheduler.thread_manager import MyThread
|
||||
from utils import util
|
||||
from utils import config_util as cfg
|
||||
from core.authorize_tb import Authorize_Tb
|
||||
|
||||
__running = True
|
||||
__my_thread = None
|
||||
_token = ''
|
||||
|
||||
|
||||
from threading import Thread
|
||||
from threading import Lock
|
||||
import websocket
|
||||
import json
|
||||
import time
|
||||
import ssl
|
||||
import wave
|
||||
import _thread as thread
|
||||
from aliyunsdkcore.client import AcsClient
|
||||
from aliyunsdkcore.request import CommonRequest
|
||||
|
||||
from core import wsa_server
|
||||
from scheduler.thread_manager import MyThread
|
||||
from utils import util
|
||||
from utils import config_util as cfg
|
||||
from core.authorize_tb import Authorize_Tb
|
||||
|
||||
__running = True
|
||||
__my_thread = None
|
||||
_token = ''
|
||||
|
||||
|
||||
def __post_token():
|
||||
global _token
|
||||
__client = AcsClient(
|
||||
cfg.key_ali_nls_key_id,
|
||||
cfg.key_ali_nls_key_secret,
|
||||
"cn-shanghai"
|
||||
)
|
||||
if not cfg.key_ali_nls_key_id or not cfg.key_ali_nls_key_secret:
|
||||
util.log(2, "AliNLS 凭据未配置,跳过 token 刷新。")
|
||||
return False
|
||||
|
||||
__request = CommonRequest()
|
||||
__request.set_method('POST')
|
||||
__request.set_domain('nls-meta.cn-shanghai.aliyuncs.com')
|
||||
__request.set_version('2019-02-28')
|
||||
__request.set_action_name('CreateToken')
|
||||
info = json.loads(__client.do_action_with_exception(__request))
|
||||
_token = info['Token']['Id']
|
||||
authorize = Authorize_Tb()
|
||||
authorize_info = authorize.find_by_userid(cfg.key_ali_nls_key_id)
|
||||
if authorize_info is not None:
|
||||
authorize.update_by_userid(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
|
||||
else:
|
||||
authorize.add(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
|
||||
try:
|
||||
__client = AcsClient(
|
||||
cfg.key_ali_nls_key_id,
|
||||
cfg.key_ali_nls_key_secret,
|
||||
"cn-shanghai"
|
||||
)
|
||||
|
||||
__request = CommonRequest()
|
||||
__request.set_method('POST')
|
||||
__request.set_domain('nls-meta.cn-shanghai.aliyuncs.com')
|
||||
__request.set_version('2019-02-28')
|
||||
__request.set_action_name('CreateToken')
|
||||
info = json.loads(__client.do_action_with_exception(__request))
|
||||
_token = info['Token']['Id']
|
||||
authorize = Authorize_Tb()
|
||||
authorize_info = authorize.find_by_userid(cfg.key_ali_nls_key_id)
|
||||
if authorize_info is not None:
|
||||
authorize.update_by_userid(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
|
||||
else:
|
||||
authorize.add(cfg.key_ali_nls_key_id, _token, info['Token']['ExpireTime']*1000)
|
||||
util.log(1, "AliNLS token刷新成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
util.log(2, f"AliNLS token刷新失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def __runnable():
|
||||
while __running:
|
||||
__post_token()
|
||||
time.sleep(60 * 60 * 12)
|
||||
|
||||
if __post_token():
|
||||
time.sleep(60 * 60 * 12)
|
||||
else:
|
||||
time.sleep(60)
|
||||
|
||||
def start():
|
||||
MyThread(target=__runnable).start()
|
||||
|
||||
|
||||
class ALiNls:
|
||||
# 初始化
|
||||
def __init__(self, username):
|
||||
self.__URL = 'wss://nls-gateway-cn-shenzhen.aliyuncs.com/ws/v1'
|
||||
self.__ws = None
|
||||
self.__frames = []
|
||||
self.started = False
|
||||
self.__closing = False
|
||||
self.__task_id = ''
|
||||
self.done = False
|
||||
self.finalResults = ""
|
||||
self.username = username
|
||||
self.data = b''
|
||||
self.__endding = False
|
||||
self.__is_close = False
|
||||
self.lock = Lock()
|
||||
|
||||
def __create_header(self, name):
|
||||
if name == 'StartTranscription':
|
||||
self.__task_id = util.random_hex(32)
|
||||
header = {
|
||||
"appkey": cfg.key_ali_nls_app_key,
|
||||
"message_id": util.random_hex(32),
|
||||
"task_id": self.__task_id,
|
||||
"namespace": "SpeechTranscriber",
|
||||
"name": name
|
||||
}
|
||||
return header
|
||||
|
||||
# 收到websocket消息的处理
|
||||
def on_message(self, ws, message):
|
||||
try:
|
||||
data = json.loads(message)
|
||||
header = data['header']
|
||||
name = header['name']
|
||||
if name == 'TranscriptionStarted':
|
||||
self.started = True
|
||||
if name == 'SentenceEnd':
|
||||
self.done = True
|
||||
self.finalResults = data['payload']['result']
|
||||
if wsa_server.get_web_instance().is_connected(self.username):
|
||||
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
|
||||
if wsa_server.get_instance().is_connected(self.username):
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
ws.close()#TODO
|
||||
elif name == 'TranscriptionResultChanged':
|
||||
self.finalResults = data['payload']['result']
|
||||
if wsa_server.get_web_instance().is_connected(self.username):
|
||||
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
|
||||
if wsa_server.get_instance().is_connected(self.username):
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# print("### message:", message)
|
||||
|
||||
# 收到websocket的关闭要求
|
||||
def on_close(self, ws, code, msg):
|
||||
self.__endding = True
|
||||
self.__is_close = True
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(self, ws, error):
|
||||
print("aliyun asr error:", error)
|
||||
self.started = True #避免在aliyun asr出错时,recorder一直等待start状态返回
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(self, ws):
|
||||
self.__endding = False
|
||||
#为了兼容多路asr,关闭过程数据
|
||||
def run(*args):
|
||||
while self.__endding == False:
|
||||
try:
|
||||
if len(self.__frames) > 0:
|
||||
with self.lock:
|
||||
frame = self.__frames.pop(0)
|
||||
if isinstance(frame, dict):
|
||||
ws.send(json.dumps(frame))
|
||||
elif isinstance(frame, bytes):
|
||||
ws.send(frame, websocket.ABNF.OPCODE_BINARY)
|
||||
self.data += frame
|
||||
else:
|
||||
time.sleep(0.001) # 避免忙等
|
||||
except Exception as e:
|
||||
print(e)
|
||||
break
|
||||
if self.__is_close == False:
|
||||
for frame in self.__frames:
|
||||
ws.send(frame, websocket.ABNF.OPCODE_BINARY)
|
||||
frame = {"header": self.__create_header('StopTranscription')}
|
||||
ws.send(json.dumps(frame))
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
|
||||
|
||||
class ALiNls:
|
||||
# 初始化
|
||||
def __init__(self, username):
|
||||
self.__URL = 'wss://nls-gateway-cn-shenzhen.aliyuncs.com/ws/v1'
|
||||
self.__ws = None
|
||||
self.__frames = []
|
||||
self.started = False
|
||||
self.__closing = False
|
||||
self.__task_id = ''
|
||||
self.done = False
|
||||
self.finalResults = ""
|
||||
self.username = username
|
||||
self.data = b''
|
||||
self.__endding = False
|
||||
self.__is_close = False
|
||||
self.lock = Lock()
|
||||
|
||||
def __create_header(self, name):
|
||||
if name == 'StartTranscription':
|
||||
self.__task_id = util.random_hex(32)
|
||||
header = {
|
||||
"appkey": cfg.key_ali_nls_app_key,
|
||||
"message_id": util.random_hex(32),
|
||||
"task_id": self.__task_id,
|
||||
"namespace": "SpeechTranscriber",
|
||||
"name": name
|
||||
}
|
||||
return header
|
||||
|
||||
# 收到websocket消息的处理
|
||||
def on_message(self, ws, message):
|
||||
try:
|
||||
data = json.loads(message)
|
||||
header = data['header']
|
||||
name = header['name']
|
||||
if name == 'TranscriptionStarted':
|
||||
self.started = True
|
||||
if name == 'SentenceEnd':
|
||||
self.done = True
|
||||
self.finalResults = data['payload']['result']
|
||||
if wsa_server.get_web_instance().is_connected(self.username):
|
||||
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
|
||||
if wsa_server.get_instance().is_connected(self.username):
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
ws.close()#TODO
|
||||
elif name == 'TranscriptionResultChanged':
|
||||
self.finalResults = data['payload']['result']
|
||||
if wsa_server.get_web_instance().is_connected(self.username):
|
||||
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
|
||||
if wsa_server.get_instance().is_connected(self.username):
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# print("### message:", message)
|
||||
|
||||
# 收到websocket的关闭要求
|
||||
def on_close(self, ws, code, msg):
|
||||
self.__endding = True
|
||||
self.__is_close = True
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(self, ws, error):
|
||||
print("aliyun asr error:", error)
|
||||
self.started = True #避免在aliyun asr出错时,recorder一直等待start状态返回
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(self, ws):
|
||||
self.__endding = False
|
||||
#为了兼容多路asr,关闭过程数据
|
||||
def run(*args):
|
||||
while self.__endding == False:
|
||||
try:
|
||||
if len(self.__frames) > 0:
|
||||
with self.lock:
|
||||
frame = self.__frames.pop(0)
|
||||
if isinstance(frame, dict):
|
||||
ws.send(json.dumps(frame))
|
||||
elif isinstance(frame, bytes):
|
||||
ws.send(frame, websocket.ABNF.OPCODE_BINARY)
|
||||
self.data += frame
|
||||
else:
|
||||
time.sleep(0.001) # 避免忙等
|
||||
except Exception as e:
|
||||
print(e)
|
||||
break
|
||||
if self.__is_close == False:
|
||||
for frame in self.__frames:
|
||||
ws.send(frame, websocket.ABNF.OPCODE_BINARY)
|
||||
frame = {"header": self.__create_header('StopTranscription')}
|
||||
ws.send(json.dumps(frame))
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
def __connect(self):
|
||||
if not _token:
|
||||
util.log(2, "AliNLS token为空,本次语音识别连接跳过")
|
||||
self.started = True
|
||||
return
|
||||
|
||||
self.finalResults = ""
|
||||
self.done = False
|
||||
with self.lock:
|
||||
self.__frames.clear()
|
||||
self.__ws = websocket.WebSocketApp(self.__URL + '?token=' + _token, on_message=self.on_message)
|
||||
self.__ws.on_open = self.on_open
|
||||
self.__ws.on_error = self.on_error
|
||||
self.__ws.on_close = self.on_close
|
||||
self.__ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def send(self, buf):
|
||||
with self.lock:
|
||||
self.__frames.append(buf)
|
||||
|
||||
def start(self):
|
||||
Thread(target=self.__connect, args=[]).start()
|
||||
data = {
|
||||
'header': self.__create_header('StartTranscription'),
|
||||
"payload": {
|
||||
"format": "pcm",
|
||||
"sample_rate": 16000,
|
||||
"enable_intermediate_result": True,
|
||||
"enable_punctuation_prediction": False,
|
||||
"enable_inverse_text_normalization": True,
|
||||
"speech_noise_threshold": -1
|
||||
}
|
||||
}
|
||||
self.send(data)
|
||||
|
||||
def end(self):
|
||||
self.__endding = True
|
||||
with wave.open('cache_data/input2.wav', 'wb') as wf:
|
||||
# 设置音频参数
|
||||
n_channels = 1 # 单声道
|
||||
sampwidth = 2 # 16 位音频,每个采样点 2 字节
|
||||
wf.setnchannels(n_channels)
|
||||
wf.setsampwidth(sampwidth)
|
||||
wf.setframerate(16000)
|
||||
wf.writeframes(self.data)
|
||||
self.data = b''
|
||||
self.__ws = websocket.WebSocketApp(self.__URL + '?token=' + _token, on_message=self.on_message)
|
||||
self.__ws.on_open = self.on_open
|
||||
self.__ws.on_error = self.on_error
|
||||
self.__ws.on_close = self.on_close
|
||||
self.__ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def send(self, buf):
|
||||
with self.lock:
|
||||
self.__frames.append(buf)
|
||||
|
||||
def start(self):
|
||||
Thread(target=self.__connect, args=[]).start()
|
||||
data = {
|
||||
'header': self.__create_header('StartTranscription'),
|
||||
"payload": {
|
||||
"format": "pcm",
|
||||
"sample_rate": 16000,
|
||||
"enable_intermediate_result": True,
|
||||
"enable_punctuation_prediction": False,
|
||||
"enable_inverse_text_normalization": True,
|
||||
"speech_noise_threshold": -1
|
||||
}
|
||||
}
|
||||
self.send(data)
|
||||
|
||||
def end(self):
|
||||
self.__endding = True
|
||||
with wave.open('cache_data/input2.wav', 'wb') as wf:
|
||||
# 设置音频参数
|
||||
n_channels = 1 # 单声道
|
||||
sampwidth = 2 # 16 位音频,每个采样点 2 字节
|
||||
wf.setnchannels(n_channels)
|
||||
wf.setsampwidth(sampwidth)
|
||||
wf.setframerate(16000)
|
||||
wf.writeframes(self.data)
|
||||
self.data = b''
|
||||
|
||||
@@ -17,7 +17,7 @@ from bionicmemory.algorithms.newton_cooling_helper import NewtonCoolingHelper, C
|
||||
from bionicmemory.core.chroma_service import ChromaService
|
||||
from bionicmemory.services.summary_service import SummaryService
|
||||
from bionicmemory.algorithms.clustering_suppression import ClusteringSuppression
|
||||
from utils.api_embedding_service import get_embedding_service
|
||||
from utils.api_embedding_service import get_embedding_service
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
@@ -238,7 +238,8 @@ class LongShortTermMemorySystem:
|
||||
logger.info(f"[仿生记忆] _prepare_document_data: embedding生成完成, 类型={type(embedding)}")
|
||||
except Exception as e:
|
||||
logger.error(f"生成embedding失败: {e}")
|
||||
embedding = []
|
||||
# 返回None而不是空列表,让调用方知道这是一个失败的情况
|
||||
embedding = None
|
||||
|
||||
# 准备元数据
|
||||
current_time = datetime.now().isoformat()
|
||||
@@ -398,12 +399,15 @@ class LongShortTermMemorySystem:
|
||||
embedding_list = embedding.tolist()
|
||||
else:
|
||||
embedding_list = embedding
|
||||
# 检查长度
|
||||
# 检查长度和维度
|
||||
if len(embedding_list) > 0:
|
||||
embeddings_param = [embedding_list]
|
||||
logger.debug(f"添加embedding到长期记忆: doc_id={doc_id}, 维度={len(embedding_list)}")
|
||||
else:
|
||||
logger.warning(f"⚠️ Embedding为空列表: doc_id={doc_id}")
|
||||
embeddings_param = None
|
||||
else:
|
||||
logger.warning(f"⚠️ Embedding为None: doc_id={doc_id}, 将不包含向量信息")
|
||||
embeddings_param = None
|
||||
|
||||
self.chroma_service.add_documents(
|
||||
@@ -1046,6 +1050,16 @@ class LongShortTermMemorySystem:
|
||||
)
|
||||
logger.info(f"[仿生记忆] 步骤1完成: doc_id={doc_id}, user_embedding类型={type(user_embedding)}")
|
||||
|
||||
# 检查embedding是否生成成功
|
||||
if user_embedding is None:
|
||||
logger.warning("[仿生记忆] embedding生成失败,跳过记忆检索,使用默认响应")
|
||||
return {
|
||||
"system_prompt": "你是一个智能助手。由于技术问题,暂时无法访问历史记忆,但我会尽力帮助你。",
|
||||
"retrieved_memories": [],
|
||||
"user_doc_id": None,
|
||||
"assistant_doc_id": None
|
||||
}
|
||||
|
||||
# 使用用户embedding进行检索
|
||||
logger.info("[仿生记忆] 步骤2: 使用用户embedding进行检索")
|
||||
query_embedding = user_embedding
|
||||
@@ -1485,4 +1499,4 @@ if __name__ == "__main__":
|
||||
stats_after['short_term_memory']['total_records'] == 0:
|
||||
logger.info(f"✅ 用户 {target_user_id} (key: {target_key[:10]}...) 历史记录清除成功!")
|
||||
else:
|
||||
logger.warning(f"⚠️ 用户 {target_user_id} (key: {target_key[:10]}...) 历史记录可能未完全清除")
|
||||
logger.warning(f"⚠️ 用户 {target_user_id} (key: {target_key[:10]}...) 历史记录可能未完全清除")
|
||||
|
||||
469
core/fay_core.py
469
core/fay_core.py
@@ -31,7 +31,9 @@ from queue import Queue
|
||||
import re # 添加正则表达式模块用于过滤表情符号
|
||||
|
||||
|
||||
import uuid
|
||||
import uuid
|
||||
import hashlib
|
||||
from urllib.parse import urlparse, urljoin
|
||||
|
||||
|
||||
|
||||
@@ -136,16 +138,24 @@ else:
|
||||
import platform
|
||||
|
||||
|
||||
if platform.system() == "Windows":
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append("test/ovr_lipsync")
|
||||
|
||||
|
||||
from test_olipsync import LipSyncGenerator
|
||||
if platform.system() == "Windows":
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
_fay_runtime_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
if hasattr(sys, "_MEIPASS"):
|
||||
_fay_runtime_dir = os.path.abspath(sys._MEIPASS)
|
||||
else:
|
||||
_fay_runtime_dir = os.path.abspath(os.path.join(_fay_runtime_dir, ".."))
|
||||
|
||||
_lipsync_dir = os.path.join(_fay_runtime_dir, "test", "ovr_lipsync")
|
||||
if _lipsync_dir not in map(os.path.abspath, sys.path):
|
||||
sys.path.insert(0, _lipsync_dir)
|
||||
|
||||
|
||||
from test_olipsync import LipSyncGenerator
|
||||
|
||||
|
||||
|
||||
@@ -237,7 +247,15 @@ class FeiFei:
|
||||
self.think_display_limit = 400
|
||||
self.user_conv_map = {} #存储用户对话id及句子流序号,key为(username, conversation_id)
|
||||
|
||||
self.pending_isfirst = {} # 存储因prestart被过滤而延迟的isfirst标记,key为username
|
||||
self.pending_isfirst = {} # 存储因prestart被过滤而延迟的isfirst标记,key为username
|
||||
self.tts_cache = {}
|
||||
self.tts_cache_limit = 1000
|
||||
self.tts_cache_lock = threading.Lock()
|
||||
self.user_audio_conv_map = {} # 仅用于音频片段的连续序号(避免文本序号空洞导致乱序/缺包)
|
||||
self.human_audio_order_map = {}
|
||||
self.human_audio_order_lock = threading.Lock()
|
||||
self.human_audio_reorder_wait_seconds = 0.2
|
||||
self.human_audio_first_wait_seconds = 1.2
|
||||
|
||||
|
||||
|
||||
@@ -868,7 +886,7 @@ class FeiFei:
|
||||
#获取不同情绪声音
|
||||
|
||||
|
||||
def __get_mood_voice(self):
|
||||
def __get_mood_voice(self):
|
||||
|
||||
|
||||
voice = tts_voice.get_voice_of(config_util.config["attribute"]["voice"])
|
||||
@@ -883,18 +901,168 @@ class FeiFei:
|
||||
styleList = voice.value["styleList"]
|
||||
|
||||
|
||||
sayType = styleList["calm"]
|
||||
|
||||
|
||||
return sayType
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 合成声音
|
||||
|
||||
|
||||
sayType = styleList["calm"]
|
||||
|
||||
|
||||
return sayType
|
||||
|
||||
def __build_tts_cache_key(self, text, style):
|
||||
tts_module = str(getattr(cfg, "tts_module", "") or "")
|
||||
style_str = str(style or "")
|
||||
voice_name = ""
|
||||
try:
|
||||
voice_name = str(config_util.config.get("attribute", {}).get("voice", "") or "")
|
||||
except Exception:
|
||||
voice_name = ""
|
||||
if tts_module == "volcano":
|
||||
try:
|
||||
volcano_voice = str(getattr(cfg, "volcano_tts_voice_type", "") or "")
|
||||
if volcano_voice:
|
||||
voice_name = volcano_voice
|
||||
except Exception:
|
||||
pass
|
||||
raw = f"{tts_module}|{voice_name}|{style_str}|{text}"
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
def __get_tts_cache(self, key):
|
||||
with self.tts_cache_lock:
|
||||
file_url = self.tts_cache.get(key)
|
||||
if not file_url:
|
||||
return None
|
||||
if os.path.exists(file_url):
|
||||
return file_url
|
||||
with self.tts_cache_lock:
|
||||
if key in self.tts_cache:
|
||||
del self.tts_cache[key]
|
||||
return None
|
||||
|
||||
def __set_tts_cache(self, key, file_url):
|
||||
if not file_url:
|
||||
return
|
||||
with self.tts_cache_lock:
|
||||
self.tts_cache[key] = file_url
|
||||
while len(self.tts_cache) > self.tts_cache_limit:
|
||||
try:
|
||||
self.tts_cache.pop(next(iter(self.tts_cache)))
|
||||
except Exception:
|
||||
break
|
||||
|
||||
def __send_human_audio_ordered(self, content, username, conversation_id, conversation_msg_no, is_end=False):
|
||||
now = time.time()
|
||||
sent_messages = []
|
||||
data = content.get("Data", {}) if isinstance(content, dict) else {}
|
||||
has_audio_payload = bool(data.get("Value")) or bool(data.get("HttpValue"))
|
||||
is_end_marker_only = bool(is_end or data.get("IsEnd", 0)) and (not has_audio_payload)
|
||||
|
||||
seq = None
|
||||
try:
|
||||
if conversation_msg_no is not None:
|
||||
seq = int(conversation_msg_no)
|
||||
except Exception:
|
||||
seq = None
|
||||
|
||||
# Fallback to direct send for legacy paths without sequence metadata.
|
||||
if (not conversation_id) or (seq is None):
|
||||
if is_end_marker_only:
|
||||
return 0
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
return 1
|
||||
|
||||
key = (username or "User", conversation_id)
|
||||
with self.human_audio_order_lock:
|
||||
state = self.human_audio_order_map.get(key)
|
||||
if state is None:
|
||||
state = {
|
||||
"next_seq": None,
|
||||
"buffer": {},
|
||||
"last_progress_time": now,
|
||||
"first_wait_start": now,
|
||||
"start_known": False,
|
||||
"end_seq": None,
|
||||
"pending_end_seq": None,
|
||||
}
|
||||
self.human_audio_order_map[key] = state
|
||||
|
||||
next_seq = state.get("next_seq")
|
||||
if (next_seq is not None) and (seq < next_seq):
|
||||
return 0
|
||||
|
||||
def _mark_buffer_end(target_seq):
|
||||
existed = state["buffer"].get(target_seq)
|
||||
if isinstance(existed, dict):
|
||||
existed_data = existed.get("Data", {})
|
||||
if isinstance(existed_data, dict):
|
||||
existed_data["IsEnd"] = 1
|
||||
return True
|
||||
return False
|
||||
|
||||
if is_end_marker_only:
|
||||
target_seq = None
|
||||
if seq in state["buffer"]:
|
||||
target_seq = seq
|
||||
elif (seq - 1) in state["buffer"]:
|
||||
target_seq = seq - 1
|
||||
elif state["buffer"]:
|
||||
target_seq = max(state["buffer"].keys())
|
||||
|
||||
if (target_seq is not None) and _mark_buffer_end(target_seq):
|
||||
end_seq = state.get("end_seq")
|
||||
state["end_seq"] = target_seq if end_seq is None else max(end_seq, target_seq)
|
||||
state["pending_end_seq"] = None
|
||||
else:
|
||||
state["pending_end_seq"] = seq
|
||||
else:
|
||||
if seq in state["buffer"]:
|
||||
return 0
|
||||
state["buffer"][seq] = content
|
||||
|
||||
pending_end_seq = state.get("pending_end_seq")
|
||||
if pending_end_seq is not None:
|
||||
if (seq == pending_end_seq) or (seq == pending_end_seq - 1):
|
||||
if _mark_buffer_end(seq):
|
||||
end_seq = state.get("end_seq")
|
||||
state["end_seq"] = seq if end_seq is None else max(end_seq, seq)
|
||||
state["pending_end_seq"] = None
|
||||
|
||||
if is_end:
|
||||
end_seq = state.get("end_seq")
|
||||
state["end_seq"] = seq if end_seq is None else max(end_seq, seq)
|
||||
|
||||
is_first_flag = bool(data.get("IsFirst", 0))
|
||||
if (not state["start_known"]) and is_first_flag:
|
||||
state["start_known"] = True
|
||||
state["next_seq"] = seq
|
||||
state["last_progress_time"] = now
|
||||
elif (not state["start_known"]) and (seq == 0):
|
||||
state["start_known"] = True
|
||||
state["next_seq"] = 0
|
||||
state["last_progress_time"] = now
|
||||
elif (not state["start_known"]):
|
||||
first_elapsed = now - state.get("first_wait_start", now)
|
||||
if (first_elapsed >= self.human_audio_first_wait_seconds) and (0 in state["buffer"]):
|
||||
state["start_known"] = True
|
||||
state["next_seq"] = 0
|
||||
state["last_progress_time"] = now
|
||||
|
||||
def _flush_contiguous():
|
||||
flush_count = 0
|
||||
while (state["next_seq"] is not None) and (state["next_seq"] in state["buffer"]):
|
||||
sent_messages.append(state["buffer"].pop(state["next_seq"]))
|
||||
state["next_seq"] += 1
|
||||
state["last_progress_time"] = now
|
||||
flush_count += 1
|
||||
return flush_count
|
||||
|
||||
_flush_contiguous()
|
||||
|
||||
end_seq = state.get("end_seq")
|
||||
if (end_seq is not None) and (state.get("next_seq") is not None) and (state["next_seq"] > end_seq) and (not state["buffer"]):
|
||||
self.human_audio_order_map.pop(key, None)
|
||||
|
||||
for message in sent_messages:
|
||||
wsa_server.get_instance().add_cmd(message)
|
||||
return len(sent_messages)
|
||||
|
||||
def say(self, interact, text, type = ""):
|
||||
|
||||
|
||||
@@ -1085,25 +1253,27 @@ class FeiFei:
|
||||
# 如果 conv_map_key 不存在,尝试使用 username 作为备用查找
|
||||
|
||||
|
||||
if not conv_info and text and text.strip():
|
||||
if not conv_info and text and text.strip():
|
||||
|
||||
|
||||
# 查找所有匹配用户名的会话
|
||||
|
||||
|
||||
for (u, c), info in list(self.user_conv_map.items()):
|
||||
for (u, c), info in list(self.user_conv_map.items()):
|
||||
|
||||
|
||||
if u == username and info.get("content_id", 0) > 0:
|
||||
if u == username and info.get("content_id", 0) > 0:
|
||||
|
||||
|
||||
content_id = info.get("content_id", 0)
|
||||
content_id = info.get("content_id", 0)
|
||||
|
||||
|
||||
conv_info = info
|
||||
conv = info.get("conversation_id", c)
|
||||
conv_map_key = (username, conv)
|
||||
|
||||
|
||||
conv_info = info
|
||||
|
||||
|
||||
util.log(1, f"警告:使用备用会话 ({u}, {c}) 的 content_id={content_id},原 key=({username}, {conv})")
|
||||
util.log(1, f"警告:使用备用会话 ({u}, {c}) 的 content_id={content_id},原 key=({username}, {conv})")
|
||||
|
||||
|
||||
break
|
||||
@@ -1157,10 +1327,27 @@ class FeiFei:
|
||||
|
||||
|
||||
|
||||
# 会话结束时清理 user_conv_map 中的对应条目,避免内存泄漏
|
||||
# 固化当前会话序号,避免异步音频线程读取时会话映射已被清理而回落为0
|
||||
current_conv_info = self.user_conv_map.get(conv_map_key, {})
|
||||
if (not current_conv_info) and (not conv):
|
||||
for (u, c), info in list(self.user_conv_map.items()):
|
||||
if u == username and info.get("conversation_id", ""):
|
||||
current_conv_info = info
|
||||
conv = info.get("conversation_id", c)
|
||||
conv_map_key = (username, conv)
|
||||
break
|
||||
if current_conv_info:
|
||||
interact.data["conversation_id"] = current_conv_info.get("conversation_id", conv)
|
||||
interact.data["conversation_msg_no"] = current_conv_info.get("conversation_msg_no", 0)
|
||||
else:
|
||||
if conv:
|
||||
interact.data["conversation_id"] = conv
|
||||
interact.data["conversation_msg_no"] = interact.data.get("conversation_msg_no", 0)
|
||||
|
||||
# 会话结束时清理 user_conv_map 中的对应条目,避免内存泄漏
|
||||
|
||||
|
||||
if is_end and conv_map_key in self.user_conv_map:
|
||||
if is_end and conv_map_key in self.user_conv_map:
|
||||
|
||||
|
||||
del self.user_conv_map[conv_map_key]
|
||||
@@ -1394,7 +1581,15 @@ class FeiFei:
|
||||
tm = time.time()
|
||||
|
||||
|
||||
result = self.sp.to_sample(filtered_text, self.__get_mood_voice())
|
||||
mood_voice = self.__get_mood_voice()
|
||||
cache_key = self.__build_tts_cache_key(filtered_text, mood_voice)
|
||||
cache_result = self.__get_tts_cache(cache_key)
|
||||
if cache_result is not None:
|
||||
result = cache_result
|
||||
util.printInfo(1, interact.data.get('user'), 'TTS cache hit')
|
||||
else:
|
||||
result = self.sp.to_sample(filtered_text, mood_voice)
|
||||
self.__set_tts_cache(cache_key, result)
|
||||
|
||||
|
||||
# 合成完成后再次检查会话是否仍有效,避免继续输出旧会话结果
|
||||
@@ -1439,7 +1634,22 @@ class FeiFei:
|
||||
|
||||
|
||||
|
||||
if result is not None or is_first or is_end:
|
||||
# 为数字人音频单独维护连续序号,避免 conversation_msg_no 因无音频片段产生空洞
|
||||
audio_conv_id = interact.data.get("conversation_id", "") or ""
|
||||
audio_conv_key = (username, audio_conv_id)
|
||||
audio_msg_no = None
|
||||
if result is not None:
|
||||
audio_msg_no = self.user_audio_conv_map.get(audio_conv_key, -1) + 1
|
||||
self.user_audio_conv_map[audio_conv_key] = audio_msg_no
|
||||
elif is_end:
|
||||
audio_msg_no = self.user_audio_conv_map.get(audio_conv_key, None)
|
||||
if audio_conv_key in self.user_audio_conv_map:
|
||||
del self.user_audio_conv_map[audio_conv_key]
|
||||
interact.data["audio_conversation_msg_no"] = audio_msg_no
|
||||
if is_end and audio_conv_key in self.user_audio_conv_map:
|
||||
del self.user_audio_conv_map[audio_conv_key]
|
||||
|
||||
if result is not None or is_first or is_end:
|
||||
|
||||
|
||||
# prestart 内容不需要进入音频处理流程
|
||||
@@ -1490,7 +1700,26 @@ class FeiFei:
|
||||
# 发送HTTP GET请求以获取WAV文件内容
|
||||
|
||||
|
||||
response = requests.get(url, stream=True)
|
||||
if url is None:
|
||||
return None
|
||||
|
||||
url = str(url).strip()
|
||||
if not url:
|
||||
return None
|
||||
|
||||
if os.path.isfile(url):
|
||||
return url
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme:
|
||||
if url.startswith('//'):
|
||||
url = 'http:' + url
|
||||
else:
|
||||
base_url = str(getattr(cfg, "fay_url", "") or "").strip()
|
||||
if base_url:
|
||||
url = urljoin(base_url.rstrip('/') + '/', url.lstrip('/'))
|
||||
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
|
||||
response.raise_for_status() # 检查请求是否成功
|
||||
@@ -1946,61 +2175,115 @@ class FeiFei:
|
||||
|
||||
|
||||
|
||||
#发送音频给数字人接口
|
||||
|
||||
|
||||
if file_url is not None and wsa_server.get_instance().get_client_output(interact.data.get("user")):
|
||||
|
||||
|
||||
# 使用 (username, conversation_id) 作为 key 获取会话信息
|
||||
|
||||
|
||||
audio_username = interact.data.get("user", "User")
|
||||
|
||||
|
||||
audio_conv_id = interact.data.get("conversation_id") or ""
|
||||
|
||||
|
||||
audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {})
|
||||
|
||||
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : audio_conv_info.get("conversation_id", ""), 'CONV_MSG_NO' : audio_conv_info.get("conversation_msg_no", 0) }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
|
||||
|
||||
|
||||
#计算lips
|
||||
|
||||
|
||||
if platform.system() == "Windows":
|
||||
|
||||
|
||||
try:
|
||||
|
||||
|
||||
lip_sync_generator = LipSyncGenerator()
|
||||
|
||||
|
||||
viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url))
|
||||
|
||||
|
||||
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
|
||||
|
||||
|
||||
content["Data"]["Lips"] = consolidated_visemes
|
||||
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
||||
print(e)
|
||||
|
||||
|
||||
util.printInfo(1, interact.data.get("user"), "唇型数据生成失败")
|
||||
|
||||
|
||||
wsa_server.get_instance().add_cmd(content)
|
||||
|
||||
|
||||
util.printInfo(1, interact.data.get("user"), "数字人接口发送音频数据成功")
|
||||
#发送音频给数字人接口
|
||||
|
||||
|
||||
if wsa_server.get_instance().get_client_output(interact.data.get("user")):
|
||||
|
||||
|
||||
# 使用 (username, conversation_id) 作为 key 获取会话信息
|
||||
|
||||
|
||||
audio_username = interact.data.get("user", "User")
|
||||
|
||||
|
||||
audio_conv_id = interact.data.get("conversation_id") or ""
|
||||
|
||||
|
||||
audio_conv_info = self.user_conv_map.get((audio_username, audio_conv_id), {})
|
||||
msg_no_from_interact = interact.data.get("audio_conversation_msg_no", None)
|
||||
conv_id_for_send = audio_conv_id if audio_conv_id else audio_conv_info.get("conversation_id", "")
|
||||
if msg_no_from_interact is None:
|
||||
fallback_no = interact.data.get("conversation_msg_no", None)
|
||||
if fallback_no is None:
|
||||
conv_msg_no_for_send = audio_conv_info.get("conversation_msg_no", 0)
|
||||
else:
|
||||
conv_msg_no_for_send = fallback_no
|
||||
else:
|
||||
conv_msg_no_for_send = msg_no_from_interact
|
||||
|
||||
if file_url is not None:
|
||||
|
||||
|
||||
content = {'Topic': 'human', 'Data': {'Key': 'audio', 'Value': os.path.abspath(file_url), 'HttpValue': f'{cfg.fay_url}/audio/' + os.path.basename(file_url), 'Text': text, 'Time': audio_length, 'Type': interact.interleaver, 'IsFirst': 1 if interact.data.get("isfirst", False) else 0, 'IsEnd': 1 if interact.data.get("isend", False) else 0, 'CONV_ID' : conv_id_for_send, 'CONV_MSG_NO' : conv_msg_no_for_send }, 'Username' : interact.data.get('user'), 'robot': f'{cfg.fay_url}/robot/Speaking.jpg'}
|
||||
|
||||
|
||||
#计算lips
|
||||
|
||||
|
||||
if platform.system() == "Windows":
|
||||
|
||||
|
||||
try:
|
||||
|
||||
|
||||
lip_sync_generator = LipSyncGenerator()
|
||||
|
||||
|
||||
viseme_list = lip_sync_generator.generate_visemes(os.path.abspath(file_url))
|
||||
|
||||
|
||||
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
|
||||
|
||||
|
||||
content["Data"]["Lips"] = consolidated_visemes
|
||||
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
||||
print(e)
|
||||
|
||||
|
||||
util.printInfo(1, interact.data.get("user"), "唇型数据生成失败")
|
||||
|
||||
|
||||
sent_count = self.__send_human_audio_ordered(
|
||||
content,
|
||||
audio_username,
|
||||
conv_id_for_send,
|
||||
conv_msg_no_for_send,
|
||||
is_end=bool(interact.data.get("isend", False)),
|
||||
)
|
||||
if sent_count > 0:
|
||||
util.printInfo(1, interact.data.get("user"), "digital human audio sent")
|
||||
else:
|
||||
util.printInfo(1, interact.data.get("user"), "digital human audio queued")
|
||||
elif bool(interact.data.get("isend", False)):
|
||||
# 没有音频文件时,也要给数字人发送结束标记,避免客户端一直等待
|
||||
end_target_seq = conv_msg_no_for_send
|
||||
try:
|
||||
end_target_seq = int(conv_msg_no_for_send)
|
||||
except Exception:
|
||||
end_target_seq = conv_msg_no_for_send
|
||||
end_content = {
|
||||
'Topic': 'human',
|
||||
'Data': {
|
||||
'Key': 'audio',
|
||||
'Value': '',
|
||||
'HttpValue': '',
|
||||
'Text': text,
|
||||
'Time': 0,
|
||||
'Type': interact.interleaver,
|
||||
'IsFirst': 1 if interact.data.get("isfirst", False) else 0,
|
||||
'IsEnd': 1,
|
||||
'CONV_ID': conv_id_for_send,
|
||||
'CONV_MSG_NO': end_target_seq
|
||||
},
|
||||
'Username': interact.data.get('user'),
|
||||
'robot': f'{cfg.fay_url}/robot/Speaking.jpg'
|
||||
}
|
||||
sent_count = self.__send_human_audio_ordered(
|
||||
end_content,
|
||||
audio_username,
|
||||
conv_id_for_send,
|
||||
end_target_seq,
|
||||
is_end=True,
|
||||
)
|
||||
if sent_count > 0:
|
||||
util.printInfo(1, interact.data.get("user"), "digital human audio end sent")
|
||||
else:
|
||||
util.printInfo(1, interact.data.get("user"), "digital human audio end queued")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,288 +1,318 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Fay broadcast MCP server (SSE transport).
|
||||
|
||||
暴露 `broadcast_message` 工具,将文本/音频透传到 Fay 的 `/transparent-pass`。
|
||||
|
||||
环境变量:
|
||||
- FAY_BROADCAST_API 默认 http://127.0.0.1:5000/transparent-pass
|
||||
- FAY_BROADCAST_USER 默认 User
|
||||
- FAY_BROADCAST_TIMEOUT 默认 10
|
||||
- FAY_MCP_SSE_HOST 默认 0.0.0.0
|
||||
- FAY_MCP_SSE_PORT 默认 8765
|
||||
- FAY_MCP_SSE_PATH SSE 路径(默认 /sse)
|
||||
- FAY_MCP_MSG_PATH 消息 POST 路径(默认 /messages)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import Any, Dict, Tuple, List, Optional
|
||||
|
||||
try:
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from faymcp import tool_registry
|
||||
from faymcp import mcp_service
|
||||
except ImportError:
|
||||
print("缺少 mcp 库,请先安装:pip install mcp", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
except ImportError:
|
||||
print("缺少 starlette,请先安装:pip install starlette sse-starlette", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import uvicorn
|
||||
except ImportError:
|
||||
print("缺少 uvicorn,请先安装:pip install uvicorn", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
print("缺少 requests,请先安装:pip install requests", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
log = logging.getLogger("fay_mcp_server")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
||||
|
||||
SERVER_NAME = "fay_broadcast"
|
||||
|
||||
DEFAULT_API_URL = os.environ.get("FAY_BROADCAST_API", "http://127.0.0.1:5000/transparent-pass")
|
||||
DEFAULT_USER = os.environ.get("FAY_BROADCAST_USER", "User")
|
||||
DEFAULT_SPEAKER = os.environ.get("FAY_BROADCAST_SPEAKER", "\u5e7f\u64ad\u6d88\u606f")
|
||||
REQUEST_TIMEOUT = float(os.environ.get("FAY_BROADCAST_TIMEOUT", "10"))
|
||||
|
||||
HOST = os.environ.get("FAY_MCP_SSE_HOST", "0.0.0.0")
|
||||
PORT = int(os.environ.get("FAY_MCP_SSE_PORT", "8765"))
|
||||
SSE_PATH = os.environ.get("FAY_MCP_SSE_PATH", "/sse")
|
||||
MSG_PATH = os.environ.get("FAY_MCP_MSG_PATH", "/messages")
|
||||
|
||||
server = None # Removed global singleton
|
||||
sse_transport = SseServerTransport(MSG_PATH)
|
||||
|
||||
# 聚合工具索引:namespaced_tool_name -> (server_id, tool_name)
|
||||
_aggregated_index: Dict[str, Tuple[int, str]] = {}
|
||||
|
||||
|
||||
def _text_content(text: str) -> TextContent:
|
||||
try:
|
||||
return TextContent(type="text", text=text)
|
||||
except Exception:
|
||||
return {"type": "text", "text": text} # type: ignore[return-value]
|
||||
|
||||
|
||||
TOOLS: list[Tool] = [
|
||||
Tool(
|
||||
name="broadcast_message",
|
||||
description="通过 Fay 的 /transparent-pass 广播文本/音频(SSE 服务器)。",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要广播的文本(audio_url为空时必填)"},
|
||||
"audio_url": {"type": "string", "description": "可选音频 URL"},
|
||||
"user": {"type": "string", "description": "目标用户名,默认 FAY_BROADCAST_USER 或 User"},
|
||||
"speaker": {
|
||||
"type": "string",
|
||||
"description": "\u5e7f\u64ad\u4eba\u663e\u793a\u540d\uff0c\u8f93\u51fa\u4e3a\"{speaker}\u8bf4\uff1a{text}\"",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
async def _handle_list_tools() -> list[Tool]:
|
||||
# 本地广播工具 + Fay 当前在线 MCP 工具的聚合视图(namespaced)
|
||||
aggregated = []
|
||||
try:
|
||||
aggregated = _build_aggregated_tools()
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to build aggregated tools: {e}")
|
||||
return TOOLS + aggregated
|
||||
|
||||
|
||||
def _parse_arguments(arguments: Dict[str, Any]) -> Tuple[str, str, str, str]:
|
||||
text = str(arguments.get("text", "") or "").strip()
|
||||
audio_url = str(arguments.get("audio_url", "") or "").strip()
|
||||
user = str(arguments.get("user", "") or "").strip() or DEFAULT_USER
|
||||
speaker = str(arguments.get("speaker", "") or "").strip() or DEFAULT_SPEAKER
|
||||
return text, audio_url, user, speaker
|
||||
|
||||
|
||||
def _build_aggregated_tools() -> List[Tool]:
|
||||
"""
|
||||
将 Fay 已连接的 MCP 工具聚合,对外暴露为 namespaced 名称:
|
||||
<server_id>:<tool_name>
|
||||
"""
|
||||
tools: List[Tool] = []
|
||||
_aggregated_index.clear()
|
||||
|
||||
server_name_map = {s.get("id"): s.get("name", f"Server{s.get('id')}") for s in mcp_service.mcp_servers or []}
|
||||
|
||||
for entry in tool_registry.get_enabled_tools():
|
||||
server_id = entry.get("server_id")
|
||||
tool_name = entry.get("name")
|
||||
if server_id is None or not tool_name:
|
||||
continue
|
||||
agg_name = f"{server_id}:{tool_name}"
|
||||
desc = entry.get("description", "")
|
||||
server_label = server_name_map.get(server_id, f"Server {server_id}")
|
||||
agg_desc = f"{desc} [via {server_label}]"
|
||||
input_schema = entry.get("inputSchema") or {}
|
||||
tool = Tool(
|
||||
name=agg_name,
|
||||
description=agg_desc,
|
||||
inputSchema=input_schema if isinstance(input_schema, dict) else {},
|
||||
)
|
||||
tools.append(tool)
|
||||
_aggregated_index[agg_name] = (server_id, tool_name)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
async def _send_broadcast(payload: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
def _post() -> Tuple[bool, str]:
|
||||
body = json.dumps(payload, ensure_ascii=True).encode("utf-8")
|
||||
resp = requests.post(
|
||||
DEFAULT_API_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
data = None
|
||||
|
||||
if resp.ok:
|
||||
if isinstance(data, dict):
|
||||
msg = data.get("message") or data.get("msg") or ""
|
||||
code = data.get("code")
|
||||
if isinstance(code, int) and code >= 400:
|
||||
return False, msg or f"Broadcast failed with code {code}"
|
||||
return True, msg or "Broadcast sent via Fay."
|
||||
return True, "Broadcast sent via Fay."
|
||||
|
||||
err_detail = ""
|
||||
if isinstance(data, dict):
|
||||
err_detail = data.get("message") or data.get("error") or data.get("msg") or ""
|
||||
if not err_detail:
|
||||
err_detail = resp.text
|
||||
return False, f"HTTP {resp.status_code}: {err_detail}"
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_post)
|
||||
except Exception as e:
|
||||
return False, f"{type(e).__name__}: {e}"
|
||||
|
||||
|
||||
async def _handle_call_tool(name: str, arguments: Dict[str, Any]) -> list[TextContent]:
|
||||
# 本地广播
|
||||
if name == "broadcast_message":
|
||||
text, audio_url, user, speaker = _parse_arguments(arguments or {})
|
||||
if not text and not audio_url:
|
||||
return [_text_content("Either 'text' or 'audio_url' must be provided.")]
|
||||
|
||||
payload: Dict[str, Any] = {"user": user}
|
||||
if text:
|
||||
payload["text"] = f"{speaker}\u8bf4\uff1a{text}"
|
||||
if audio_url:
|
||||
payload["audio"] = audio_url
|
||||
|
||||
ok, message = await _send_broadcast(payload)
|
||||
prefix = "success" if ok else "error"
|
||||
return [_text_content(f"{prefix}: {message}")]
|
||||
|
||||
# 聚合的远端 MCP 工具
|
||||
target = _aggregated_index.get(name)
|
||||
if not target:
|
||||
return [_text_content(f"Unknown tool: {name}")]
|
||||
|
||||
server_id, tool_name = target
|
||||
try:
|
||||
success, result = await asyncio.to_thread(mcp_service.call_mcp_tool, server_id, tool_name, arguments or {})
|
||||
if not success:
|
||||
return [_text_content(f"error: {result}")]
|
||||
return _normalize_result(result)
|
||||
except Exception as e:
|
||||
return [_text_content(f"error: {type(e).__name__}: {e}")]
|
||||
|
||||
|
||||
def _normalize_result(result: Any) -> List[TextContent]:
|
||||
"""
|
||||
将上游返回的任意对象转换为 MCP 文本内容列表。
|
||||
"""
|
||||
# 如果已经是 TextContent 或列表,直接返回
|
||||
try:
|
||||
from mcp.types import TextContent
|
||||
if isinstance(result, TextContent):
|
||||
return [result]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(result, list):
|
||||
contents: List[TextContent] = []
|
||||
for item in result:
|
||||
try:
|
||||
if hasattr(item, "type") and getattr(item, "type", "") == "text" and hasattr(item, "text"):
|
||||
contents.append(item)
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
contents.append(TextContent(type="text", text=str(item.get("text", "")))) # type: ignore
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
contents.append(_text_content(str(item)))
|
||||
return contents
|
||||
|
||||
return [_text_content(str(result))]
|
||||
|
||||
|
||||
async def sse_endpoint(request):
|
||||
# 为每个连接创建新的 Server 实例以支持并发
|
||||
local_server = Server(SERVER_NAME)
|
||||
|
||||
# 注册工具处理程序
|
||||
local_server.list_tools()(_handle_list_tools)
|
||||
local_server.call_tool()(_handle_call_tool)
|
||||
|
||||
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
|
||||
await local_server.run(read_stream, write_stream, local_server.create_initialization_options())
|
||||
# 客户端断开时返回空响应,避免 NoneType 问题
|
||||
return Response()
|
||||
|
||||
|
||||
routes = [
|
||||
Route(SSE_PATH, sse_endpoint, methods=["GET"]),
|
||||
Mount(MSG_PATH, app=sse_transport.handle_post_message),
|
||||
]
|
||||
|
||||
app = Starlette(routes=routes)
|
||||
|
||||
|
||||
def main():
|
||||
log.info(f"SSE MCP server started at http://{HOST}:{PORT}{SSE_PATH}")
|
||||
log.info(f"Message endpoint mounted at {MSG_PATH}")
|
||||
uvicorn.run(app, host=HOST, port=PORT, log_level="info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Fay broadcast MCP server (SSE transport).
|
||||
|
||||
暴露 `broadcast_message` 工具,将文本/音频透传到 Fay 的 `/transparent-pass`。
|
||||
|
||||
环境变量:
|
||||
- FAY_BROADCAST_API 默认 http://127.0.0.1:5000/transparent-pass
|
||||
- FAY_BROADCAST_USER 默认 User
|
||||
- FAY_BROADCAST_TIMEOUT 默认 10
|
||||
- FAY_MCP_SSE_HOST 默认 0.0.0.0
|
||||
- FAY_MCP_SSE_PORT 默认 8765
|
||||
- FAY_MCP_SSE_PATH SSE 路径(默认 /sse)
|
||||
- FAY_MCP_MSG_PATH 消息 POST 路径(默认 /messages)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import Any, Dict, Tuple, List, Optional
|
||||
|
||||
try:
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from faymcp import tool_registry
|
||||
from faymcp import mcp_service
|
||||
except ImportError:
|
||||
print("缺少 mcp 库,请先安装:pip install mcp", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
except ImportError:
|
||||
print("缺少 starlette,请先安装:pip install starlette sse-starlette", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import uvicorn
|
||||
except ImportError:
|
||||
print("缺少 uvicorn,请先安装:pip install uvicorn", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
print("缺少 requests,请先安装:pip install requests", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
log = logging.getLogger("fay_mcp_server")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
||||
|
||||
SERVER_NAME = "fay_broadcast"
|
||||
|
||||
DEFAULT_API_URL = os.environ.get("FAY_BROADCAST_API", "http://127.0.0.1:5000/transparent-pass")
|
||||
DEFAULT_USER = os.environ.get("FAY_BROADCAST_USER", "User")
|
||||
DEFAULT_SPEAKER = os.environ.get("FAY_BROADCAST_SPEAKER", "\u5e7f\u64ad\u6d88\u606f")
|
||||
REQUEST_TIMEOUT = float(os.environ.get("FAY_BROADCAST_TIMEOUT", "10"))
|
||||
|
||||
HOST = os.environ.get("FAY_MCP_SSE_HOST", "0.0.0.0")
|
||||
PORT = int(os.environ.get("FAY_MCP_SSE_PORT", "8765"))
|
||||
SSE_PATH = os.environ.get("FAY_MCP_SSE_PATH", "/sse")
|
||||
MSG_PATH = os.environ.get("FAY_MCP_MSG_PATH", "/messages")
|
||||
|
||||
server = None # Removed global singleton
|
||||
sse_transport = SseServerTransport(MSG_PATH)
|
||||
|
||||
# 聚合工具索引:namespaced_tool_name -> (server_id, tool_name)
|
||||
_aggregated_index: Dict[str, Tuple[int, str]] = {}
|
||||
|
||||
|
||||
def _text_content(text: str) -> TextContent:
    """Build a text content item, degrading to a raw dict if the type is unusable."""
    try:
        item = TextContent(type="text", text=text)
    except Exception:
        # Some mcp versions may reject the constructor; a dict with the same
        # shape is accepted by most clients.
        item = {"type": "text", "text": text}  # type: ignore[assignment]
    return item
|
||||
|
||||
|
||||
TOOLS: list[Tool] = [
|
||||
Tool(
|
||||
name="broadcast_message",
|
||||
description="通过 Fay 的 /transparent-pass 透传文本/音频。",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要广播的文本(audio_url为空时必填)"},
|
||||
"audio_url": {"type": "string", "description": "可选音频 URL"},
|
||||
"user": {"type": "string", "description": "用户标识名称,默认 FAY_BROADCAST_USER 或 User"},
|
||||
"speaker": {
|
||||
"type": "string",
|
||||
"description": "发言人显示名,输出为\"{speaker}说:{text}\"",
|
||||
},
|
||||
"queue": {"type": "boolean", "description": "是否走队列播放,默认 false"},
|
||||
"queue_playback": {"type": "boolean", "description": "兼容参数,等同 queue"},
|
||||
"enqueue": {"type": "boolean", "description": "兼容参数,等同 queue"},
|
||||
"mode": {"type": "string", "description": "兼容参数,值为 queue 时启用队列播放"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
async def _handle_list_tools() -> list[Tool]:
    """Return the local broadcast tool plus a namespaced view of Fay's online MCP tools."""
    try:
        remote_tools = _build_aggregated_tools()
    except Exception as e:
        log.warning(f"Failed to build aggregated tools: {e}")
        remote_tools = []
    return TOOLS + remote_tools
|
||||
|
||||
|
||||
|
||||
def _as_bool(value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, str):
|
||||
v = value.strip().lower()
|
||||
if v == "":
|
||||
return False
|
||||
return v in {"1", "true", "yes", "on", "y", "queue"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _parse_arguments(arguments: Dict[str, Any]) -> Tuple[str, str, str, str, bool]:
    """Extract ``(text, audio_url, user, speaker, queue)`` from tool arguments.

    Blank ``user``/``speaker`` fields fall back to DEFAULT_USER /
    DEFAULT_SPEAKER.  The queue flag honours several compatible spellings,
    checked in priority order: ``queue``, ``queue_playback``, ``enqueue``,
    then ``mode == "queue"``.
    """
    def _clean(key: str) -> str:
        # Normalize missing/None values to "" and strip surrounding space.
        return str(arguments.get(key, "") or "").strip()

    text = _clean("text")
    audio_url = _clean("audio_url")
    user = _clean("user") or DEFAULT_USER
    speaker = _clean("speaker") or DEFAULT_SPEAKER

    queue = False
    for alias in ("queue", "queue_playback", "enqueue"):
        if alias in arguments:
            queue = _as_bool(arguments.get(alias))
            break
    else:
        if "mode" in arguments:
            queue = _clean("mode").lower() == "queue"

    return text, audio_url, user, speaker, queue
|
||||
def _build_aggregated_tools() -> List[Tool]:
    """Expose Fay's connected MCP tools under namespaced names.

    Each remote tool is published as ``<server_id>:<tool_name>`` and the
    mapping is recorded in the module-level ``_aggregated_index`` so calls
    can later be routed back to the owning server.
    """
    _aggregated_index.clear()
    aggregated: List[Tool] = []

    # server_id -> human-readable label, used only for the description suffix.
    labels = {
        srv.get("id"): srv.get("name", f"Server{srv.get('id')}")
        for srv in (mcp_service.mcp_servers or [])
    }

    for entry in tool_registry.get_enabled_tools():
        server_id = entry.get("server_id")
        tool_name = entry.get("name")
        if server_id is None or not tool_name:
            continue  # malformed registry entry; ignore

        namespaced = f"{server_id}:{tool_name}"
        label = labels.get(server_id, f"Server {server_id}")
        schema = entry.get("inputSchema") or {}
        aggregated.append(
            Tool(
                name=namespaced,
                description=f"{entry.get('description', '')} [via {label}]",
                inputSchema=schema if isinstance(schema, dict) else {},
            )
        )
        _aggregated_index[namespaced] = (server_id, tool_name)

    return aggregated
|
||||
|
||||
|
||||
async def _send_broadcast(payload: Dict[str, Any]) -> Tuple[bool, str]:
    """POST *payload* to Fay's transparent-pass endpoint.

    Runs the blocking HTTP request on a worker thread via asyncio.to_thread
    and returns ``(ok, message)``; never raises — transport errors are
    converted into a ``(False, "<ExcType>: <detail>")`` result.
    """
    def _post() -> Tuple[bool, str]:
        # ensure_ascii=True keeps the JSON body pure ASCII (non-ASCII text is
        # escaped), which avoids charset issues with intermediaries.
        body = json.dumps(payload, ensure_ascii=True).encode("utf-8")
        resp = requests.post(
            DEFAULT_API_URL,
            data=body,
            headers={"Content-Type": "application/json; charset=utf-8"},
            timeout=REQUEST_TIMEOUT,
        )
        try:
            data = resp.json()
        except Exception:
            data = None  # non-JSON body; fall back to raw resp.text below

        if resp.ok:
            if isinstance(data, dict):
                msg = data.get("message") or data.get("msg") or ""
                code = data.get("code")
                # Some upstreams report failures in-band (HTTP 200 with an
                # application-level error code >= 400).
                if isinstance(code, int) and code >= 400:
                    return False, msg or f"透传失败,HTTP码 {code}"
                return True, msg or "已发送透传请求。"
            return True, "已发送透传请求。"

        # Non-2xx: surface the most specific error detail available.
        err_detail = ""
        if isinstance(data, dict):
            err_detail = data.get("message") or data.get("error") or data.get("msg") or ""
        if not err_detail:
            err_detail = resp.text
        return False, f"HTTP {resp.status_code}: {err_detail}"

    try:
        # Keep the event loop responsive while requests blocks.
        return await asyncio.to_thread(_post)
    except Exception as e:
        return False, f"{type(e).__name__}: {e}"
|
||||
|
||||
|
||||
async def _handle_call_tool(name: str, arguments: Dict[str, Any]) -> list[TextContent]:
    """Dispatch a tool invocation: the local broadcast tool or an aggregated remote tool."""
    args = arguments or {}

    # Local broadcast tool.
    if name == "broadcast_message":
        text, audio_url, user, speaker, queue = _parse_arguments(args)
        if not text and not audio_url:
            return [_text_content("text 或 audio_url 至少需提供一个。")]

        payload: Dict[str, Any] = {"user": user}
        if text:
            payload["text"] = f"{speaker}\u8bf4\uff1a{text}"
        if audio_url:
            payload["audio"] = audio_url
        if queue:
            # Set every compatible spelling so any backend version honours it.
            payload["queue"] = True
            payload["queue_playback"] = True
            payload["mode"] = "queue"

        ok, message = await _send_broadcast(payload)
        return [_text_content(f"{'成功' if ok else '失败'}: {message}")]

    # Aggregated remote MCP tool, resolved through the namespaced index.
    target = _aggregated_index.get(name)
    if target is None:
        return [_text_content(f"未知工具: {name}")]

    server_id, tool_name = target
    try:
        success, result = await asyncio.to_thread(
            mcp_service.call_mcp_tool, server_id, tool_name, args
        )
        if not success:
            return [_text_content(f"error: {result}")]
        return _normalize_result(result)
    except Exception as e:
        return [_text_content(f"error: {type(e).__name__}: {e}")]
|
||||
|
||||
|
||||
def _normalize_result(result: Any) -> List[TextContent]:
    """Coerce an arbitrary upstream tool result into a list of MCP text contents."""
    # Already a TextContent instance: pass through as a one-element list.
    try:
        from mcp.types import TextContent
        if isinstance(result, TextContent):
            return [result]
    except Exception:
        pass

    if not isinstance(result, list):
        return [_text_content(str(result))]

    contents: List[TextContent] = []
    for item in result:
        # Objects that already look like text content pass through unchanged.
        try:
            if getattr(item, "type", "") == "text" and hasattr(item, "text"):
                contents.append(item)
                continue
        except Exception:
            pass
        # Dict-shaped text items are rebuilt as proper TextContent objects.
        try:
            if isinstance(item, dict) and item.get("type") == "text":
                contents.append(TextContent(type="text", text=str(item.get("text", ""))))  # type: ignore
                continue
        except Exception:
            pass
        # Anything else: stringify.
        contents.append(_text_content(str(item)))
    return contents
|
||||
|
||||
|
||||
async def sse_endpoint(request):
    """Starlette endpoint: serve one MCP session over SSE.

    A fresh ``Server`` is created per connection so concurrent clients do
    not share handler registration state.
    """
    # Create a new Server instance per connection to support concurrency.
    local_server = Server(SERVER_NAME)

    # Register the tool handlers on this connection's server.
    local_server.list_tools()(_handle_list_tools)
    local_server.call_tool()(_handle_call_tool)

    # NOTE(review): request._send is a private Starlette attribute that
    # SseServerTransport requires — confirm it still exists on Starlette upgrades.
    async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
        await local_server.run(read_stream, write_stream, local_server.create_initialization_options())
    # Return an empty response when the client disconnects, to avoid a
    # NoneType response problem.
    return Response()
|
||||
|
||||
|
||||
routes = [
|
||||
Route(SSE_PATH, sse_endpoint, methods=["GET"]),
|
||||
Mount(MSG_PATH, app=sse_transport.handle_post_message),
|
||||
]
|
||||
|
||||
app = Starlette(routes=routes)
|
||||
|
||||
|
||||
def main():
    """Start the SSE MCP server with uvicorn, logging the bound endpoints first."""
    endpoint = f"http://{HOST}:{PORT}{SSE_PATH}"
    log.info(f"SSE MCP server started at {endpoint}")
    log.info(f"Message endpoint mounted at {MSG_PATH}")
    uvicorn.run(app, host=HOST, port=PORT, log_level="info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 8.6 KiB After Width: | Height: | Size: 8.2 KiB |
@@ -10,8 +10,17 @@ from flask_cors import CORS
|
||||
import requests
|
||||
import datetime
|
||||
import pytz
|
||||
import logging
|
||||
import uuid
|
||||
import logging
|
||||
import uuid
|
||||
from urllib.parse import urlparse, urljoin
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
||||
except Exception:
|
||||
ChatOpenAI = None
|
||||
HumanMessage = None
|
||||
SystemMessage = None
|
||||
AIMessage = None
|
||||
|
||||
import fay_booter
|
||||
from tts import tts_voice
|
||||
@@ -101,21 +110,77 @@ def _as_bool(value):
|
||||
return value.strip().lower() in ("1", "true", "yes", "y", "on")
|
||||
return False
|
||||
|
||||
def _build_llm_url(base_url: str) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
def _build_llm_url(base_url: str) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
url = base_url.rstrip("/")
|
||||
if url.endswith("/chat/completions"):
|
||||
return url
|
||||
if url.endswith("/v1"):
|
||||
return url + "/chat/completions"
|
||||
return url + "/v1/chat/completions"
|
||||
|
||||
|
||||
|
||||
def _build_embedding_url(base_url: str) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
return url + "/v1/chat/completions"
|
||||
|
||||
|
||||
def _normalize_openai_content(content):
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if text is not None:
|
||||
parts.append(str(text))
|
||||
continue
|
||||
if "content" in item:
|
||||
parts.append(_normalize_openai_content(item.get("content")))
|
||||
continue
|
||||
parts.append(str(item))
|
||||
return "".join(parts)
|
||||
if isinstance(content, dict):
|
||||
if "text" in content:
|
||||
return _normalize_openai_content(content.get("text"))
|
||||
if "content" in content:
|
||||
return _normalize_openai_content(content.get("content"))
|
||||
return str(content)
|
||||
|
||||
|
||||
def _build_langchain_messages(messages):
    """Convert OpenAI-style chat messages into langchain message objects.

    A bare string becomes a single HumanMessage; any non-list input yields
    an empty list; non-dict entries are skipped; unknown roles default to
    HumanMessage.
    """
    if isinstance(messages, str):
        return [HumanMessage(content=messages)]
    if not isinstance(messages, list):
        return []

    # Role name -> langchain message class (default: HumanMessage).
    role_map = {"system": SystemMessage, "assistant": AIMessage}
    converted = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role = str(msg.get("role", "user")).strip().lower()
        content = _normalize_openai_content(msg.get("content"))
        if content is None:
            content = ""
        converted.append(role_map.get(role, HumanMessage)(content=content))
    return converted
|
||||
|
||||
|
||||
def _safe_text_from_chunk(chunk):
    """Extract the text of a streaming LLM chunk; empty string for None chunks."""
    if chunk is None:
        return ""
    return _normalize_openai_content(getattr(chunk, "content", ""))
|
||||
|
||||
|
||||
|
||||
def _build_embedding_url(base_url: str) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
url = base_url.rstrip("/")
|
||||
if url.endswith("/v1/embeddings") or url.endswith("/embeddings"):
|
||||
return url
|
||||
@@ -123,9 +188,20 @@ def _build_embedding_url(base_url: str) -> str:
|
||||
return url[:-len("/v1/chat/completions")] + "/v1/embeddings"
|
||||
if url.endswith("/chat/completions"):
|
||||
return url[:-len("/chat/completions")] + "/embeddings"
|
||||
if url.endswith("/v1"):
|
||||
return url + "/embeddings"
|
||||
return url + "/v1/embeddings"
|
||||
if url.endswith("/v1"):
|
||||
return url + "/embeddings"
|
||||
return url + "/v1/embeddings"
|
||||
|
||||
|
||||
def _build_langchain_base_url(base_url: str) -> str:
|
||||
if not base_url:
|
||||
return ""
|
||||
url = base_url.rstrip("/")
|
||||
if url.endswith("/v1/chat/completions"):
|
||||
return url[:-len("/chat/completions")]
|
||||
if url.endswith("/chat/completions"):
|
||||
return url[:-len("/chat/completions")]
|
||||
return url
|
||||
|
||||
@__app.route('/api/submit', methods=['post'])
|
||||
def api_submit():
|
||||
@@ -371,61 +447,111 @@ def api_send_v1_chat_completions():
|
||||
if not data:
|
||||
return jsonify({'error': 'missing request body'})
|
||||
try:
|
||||
model = data.get('model', 'fay')
|
||||
if model == 'llm':
|
||||
try:
|
||||
config_util.load_config()
|
||||
llm_url = _build_llm_url(config_util.gpt_base_url)
|
||||
api_key = config_util.key_gpt_api_key
|
||||
model_engine = config_util.gpt_model_engine
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM config load failed: {exc}'}), 500
|
||||
|
||||
if not llm_url:
|
||||
return jsonify({'error': 'LLM base_url is not configured'}), 500
|
||||
|
||||
payload = dict(data)
|
||||
if payload.get('model') == 'llm' and model_engine:
|
||||
payload['model'] = model_engine
|
||||
|
||||
stream_requested = _as_bool(payload.get('stream', False))
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
|
||||
try:
|
||||
if stream_requested:
|
||||
resp = requests.post(llm_url, headers=headers, json=payload, stream=True)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
if not chunk:
|
||||
continue
|
||||
yield chunk
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
content_type = resp.headers.get("Content-Type", "text/event-stream")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
generate(),
|
||||
status=resp.status_code,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
resp = requests.post(llm_url, headers=headers, json=payload, timeout=60)
|
||||
content_type = resp.headers.get("Content-Type", "application/json")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
resp.content,
|
||||
status=resp.status_code,
|
||||
content_type=content_type,
|
||||
)
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM request failed: {exc}'}), 500
|
||||
model = data.get('model', 'fay')
|
||||
if model == 'llm':
|
||||
if ChatOpenAI is None or HumanMessage is None:
|
||||
return jsonify({'error': 'langchain_openai or langchain_core is not available'}), 500
|
||||
try:
|
||||
config_util.load_config()
|
||||
api_key = config_util.key_gpt_api_key
|
||||
model_engine = config_util.gpt_model_engine
|
||||
base_url = _build_langchain_base_url(config_util.gpt_base_url)
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM config load failed: {exc}'}), 500
|
||||
|
||||
if not base_url:
|
||||
return jsonify({'error': 'LLM base_url is not configured'}), 500
|
||||
|
||||
payload = dict(data)
|
||||
stream_requested = _as_bool(payload.get('stream', False))
|
||||
model_name = model_engine or payload.get('model')
|
||||
lc_messages = _build_langchain_messages(payload.get('messages', []))
|
||||
if not lc_messages:
|
||||
return jsonify({'error': 'messages is required'}), 400
|
||||
|
||||
llm_kwargs = {
|
||||
"model": model_name,
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"streaming": bool(stream_requested),
|
||||
}
|
||||
if payload.get("temperature") is not None:
|
||||
llm_kwargs["temperature"] = payload.get("temperature")
|
||||
if payload.get("max_tokens") is not None:
|
||||
llm_kwargs["max_tokens"] = payload.get("max_tokens")
|
||||
model_kwargs = {}
|
||||
if payload.get("top_p") is not None:
|
||||
model_kwargs["top_p"] = payload.get("top_p")
|
||||
if model_kwargs:
|
||||
llm_kwargs["model_kwargs"] = model_kwargs
|
||||
|
||||
try:
|
||||
llm_client = ChatOpenAI(**llm_kwargs)
|
||||
run_cfg = {
|
||||
"tags": ["fay", "api", "model-llm"],
|
||||
"metadata": {"entrypoint": "api_send_v1_chat_completions", "model_alias": "llm"},
|
||||
}
|
||||
if stream_requested:
|
||||
stream_id = "chatcmpl-" + str(uuid.uuid4())
|
||||
def generate():
|
||||
try:
|
||||
for chunk in llm_client.stream(lc_messages, config=run_cfg):
|
||||
text_piece = _safe_text_from_chunk(chunk)
|
||||
if text_piece is None or text_piece == "":
|
||||
continue
|
||||
message = {
|
||||
"id": stream_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"delta": {"content": text_piece},
|
||||
"index": 0,
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f"data: {json.dumps(message, ensure_ascii=False)}\n\n"
|
||||
final_message = {
|
||||
"id": stream_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"delta": {},
|
||||
"index": 0,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f"data: {json.dumps(final_message, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
finally:
|
||||
pass
|
||||
return Response(generate(), content_type="text/event-stream; charset=utf-8")
|
||||
|
||||
ai_resp = llm_client.invoke(lc_messages, config=run_cfg)
|
||||
answer_text = _normalize_openai_content(getattr(ai_resp, "content", ""))
|
||||
return jsonify({
|
||||
"id": "chatcmpl-" + str(uuid.uuid4()),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": answer_text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
})
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM request failed: {exc}'}), 500
|
||||
|
||||
last_content = ""
|
||||
username = "User"
|
||||
@@ -1099,31 +1225,97 @@ def api_toggle_microphone():
|
||||
|
||||
|
||||
#消息透传接口
|
||||
@__app.route('/transparent-pass', methods=['post'])
|
||||
def transparent_pass():
|
||||
try:
|
||||
data = request.form.get('data')
|
||||
if data is None:
|
||||
data = request.get_json()
|
||||
else:
|
||||
data = json.loads(data)
|
||||
username = data.get('user', 'User')
|
||||
response_text = data.get('text', None)
|
||||
audio_url = data.get('audio', None)
|
||||
if response_text or audio_url:
|
||||
# 新消息到达,立即中断该用户之前的所有处理(文本流+音频队列)
|
||||
util.printInfo(1, username, f'[API中断] 新消息到达,完整中断用户 {username} 之前的所有处理')
|
||||
util.printInfo(1, username, f'[API中断] 用户 {username} 的文本流和音频队列已清空,准备处理新消息')
|
||||
interact = Interact('transparent_pass', 2, {'user': username, 'text': response_text, 'audio': audio_url, 'isend':True, 'isfirst':True})
|
||||
util.printInfo(1, username, '透传播放:{},{}'.format(response_text, audio_url), time.time())
|
||||
success = fay_booter.feiFei.on_interact(interact)
|
||||
if (success == 'success'):
|
||||
return jsonify({'code': 200, 'message' : '成功'})
|
||||
return jsonify({'code': 500, 'message' : '未知原因出错'})
|
||||
except Exception as e:
|
||||
return jsonify({'code': 500, 'message': f'出错: {e}'}), 500
|
||||
|
||||
# 清除记忆API
|
||||
@__app.route('/transparent-pass', methods=['post'])
|
||||
def transparent_pass():
|
||||
try:
|
||||
data = request.form.get('data')
|
||||
if data is None:
|
||||
data = request.get_json(silent=True) or {}
|
||||
else:
|
||||
data = json.loads(data)
|
||||
|
||||
if isinstance(data, dict):
|
||||
nested_data = data.get('data')
|
||||
if isinstance(nested_data, dict):
|
||||
data = nested_data
|
||||
elif isinstance(nested_data, str):
|
||||
nested_data = nested_data.strip()
|
||||
if nested_data:
|
||||
try:
|
||||
data = json.loads(nested_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
|
||||
username = data.get('user', 'User')
|
||||
response_text = data.get('text', None)
|
||||
audio_url = data.get('audio', None)
|
||||
if isinstance(audio_url, str):
|
||||
audio_url = audio_url.strip()
|
||||
if audio_url:
|
||||
parsed_audio = urlparse(audio_url)
|
||||
if not parsed_audio.scheme:
|
||||
if audio_url.startswith('//'):
|
||||
audio_url = 'http:' + audio_url
|
||||
else:
|
||||
base_url = ''
|
||||
origin = (request.headers.get('Origin') or '').strip()
|
||||
referer = (request.headers.get('Referer') or '').strip()
|
||||
if origin:
|
||||
parsed_origin = urlparse(origin)
|
||||
if parsed_origin.scheme and parsed_origin.netloc:
|
||||
base_url = f'{parsed_origin.scheme}://{parsed_origin.netloc}/'
|
||||
if (not base_url) and referer:
|
||||
parsed_referer = urlparse(referer)
|
||||
if parsed_referer.scheme and parsed_referer.netloc:
|
||||
base_url = f'{parsed_referer.scheme}://{parsed_referer.netloc}/'
|
||||
if not base_url:
|
||||
base_url = request.host_url
|
||||
audio_url = urljoin(base_url, audio_url)
|
||||
else:
|
||||
audio_url = None
|
||||
|
||||
queue_mode = _as_bool(data.get('queue', False))
|
||||
if not queue_mode:
|
||||
queue_mode = _as_bool(data.get('queue_playback', data.get('enqueue', False)))
|
||||
if not queue_mode:
|
||||
queue_mode = str(data.get('mode', '')).strip().lower() == 'queue'
|
||||
if not queue_mode:
|
||||
queue_mode = _as_bool(data.get('qutue', False))
|
||||
|
||||
if response_text or audio_url:
|
||||
if queue_mode:
|
||||
interact = Interact('transparent_pass', 2, {
|
||||
'user': username,
|
||||
'text': response_text,
|
||||
'audio': audio_url,
|
||||
'isend': True,
|
||||
'isfirst': True,
|
||||
'no_reply': True,
|
||||
'queue': True,
|
||||
'queue_playback': True
|
||||
})
|
||||
else:
|
||||
util.printInfo(1, username, f'[\u0041\u0050\u0049\u4e2d\u65ad] \u65b0\u6d88\u606f\u5230\u8fbe\uff0c\u5b8c\u6574\u4e2d\u65ad\u7528\u6237 {username} \u4e4b\u524d\u7684\u6240\u6709\u5904\u7406')
|
||||
util.printInfo(1, username, f'[\u0041\u0050\u0049\u4e2d\u65ad] \u7528\u6237 {username} \u7684\u6587\u672c\u6d41\u548c\u97f3\u9891\u961f\u5217\u5df2\u6e05\u7a7a\uff0c\u51c6\u5907\u5904\u7406\u65b0\u6d88\u606f')
|
||||
interact = Interact('transparent_pass', 2, {
|
||||
'user': username,
|
||||
'text': response_text,
|
||||
'audio': audio_url,
|
||||
'isend': True,
|
||||
'isfirst': True
|
||||
})
|
||||
|
||||
util.printInfo(1, username, '\u900f\u4f20\u64ad\u653e\uff1a{},{}'.format(response_text, audio_url), time.time())
|
||||
success = fay_booter.feiFei.on_interact(interact)
|
||||
if success == 'success':
|
||||
return jsonify({'code': 200, 'message': '\u6210\u529f'})
|
||||
|
||||
return jsonify({'code': 500, 'message': '\u672a\u77e5\u539f\u56e0\u51fa\u9519'})
|
||||
except Exception as e:
|
||||
return jsonify({'code': 500, 'message': f'\u51fa\u9519: {e}'}), 500
|
||||
@__app.route('/api/clear-memory', methods=['POST'])
|
||||
def api_clear_memory():
|
||||
try:
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 8.6 KiB After Width: | Height: | Size: 8.2 KiB |
8
main.py
8
main.py
@@ -2,7 +2,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
os.environ['PATH'] += os.pathsep + os.path.join(os.getcwd(), "test", "ovr_lipsync", "ffmpeg", "bin")
|
||||
def _resolve_runtime_dir():
|
||||
if hasattr(sys, "_MEIPASS"):
|
||||
return os.path.abspath(sys._MEIPASS)
|
||||
return os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
_RUNTIME_DIR = _resolve_runtime_dir()
|
||||
os.environ['PATH'] += os.pathsep + os.path.join(_RUNTIME_DIR, "test", "ovr_lipsync", "ffmpeg", "bin")
|
||||
|
||||
def _preload_config_center(argv):
|
||||
for i, arg in enumerate(argv):
|
||||
|
||||
@@ -1,35 +1,35 @@
|
||||
requests
|
||||
numpy
|
||||
pyaudio~=0.2.11
|
||||
websockets~=10.4
|
||||
ws4py~=0.5.1
|
||||
PyQt5==5.15.10
|
||||
PyQt5-sip==12.13.0
|
||||
PyQtWebEngine==5.15.6
|
||||
flask~=3.0.0
|
||||
openpyxl~=3.0.9
|
||||
flask_cors~=3.0.10
|
||||
websocket-client
|
||||
azure-cognitiveservices-speech
|
||||
aliyun-python-sdk-core
|
||||
simhash
|
||||
pytz
|
||||
gevent
|
||||
edge_tts
|
||||
pydub
|
||||
tenacity==8.2.3
|
||||
pygame
|
||||
scipy
|
||||
flask-httpauth
|
||||
opencv-python
|
||||
psutil
|
||||
langchain
|
||||
langchain_openai
|
||||
langgraph
|
||||
bs4
|
||||
schedule
|
||||
mcp
|
||||
python-docx
|
||||
python-pptx
|
||||
chromadb
|
||||
requests
|
||||
numpy
|
||||
pyaudio~=0.2.11
|
||||
websockets~=10.4
|
||||
ws4py~=0.5.1
|
||||
#PyQt5==5.15.10
|
||||
#PyQt5-sip==12.13.0
|
||||
#PyQtWebEngine==5.15.6
|
||||
flask~=3.0.0
|
||||
openpyxl~=3.0.9
|
||||
flask_cors~=3.0.10
|
||||
websocket-client
|
||||
azure-cognitiveservices-speech
|
||||
aliyun-python-sdk-core
|
||||
simhash
|
||||
pytz
|
||||
gevent
|
||||
edge_tts
|
||||
pydub
|
||||
tenacity==8.2.3
|
||||
pygame
|
||||
scipy
|
||||
flask-httpauth
|
||||
opencv-python
|
||||
psutil
|
||||
langchain
|
||||
langchain_openai
|
||||
langgraph
|
||||
bs4
|
||||
schedule
|
||||
mcp
|
||||
python-docx
|
||||
python-pptx
|
||||
chromadb
|
||||
sentence_transformers
|
||||
@@ -1,91 +1,95 @@
|
||||
import subprocess
|
||||
import time
|
||||
import os
|
||||
os.environ['PATH'] += os.pathsep + os.path.join(os.getcwd(), "test", "ovr_lipsync", "ffmpeg", "bin")
|
||||
import sys
|
||||
_RUNTIME_DIR = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
if hasattr(sys, "_MEIPASS"):
|
||||
_RUNTIME_DIR = os.path.abspath(sys._MEIPASS)
|
||||
os.environ['PATH'] += os.pathsep + os.path.join(_RUNTIME_DIR, "test", "ovr_lipsync", "ffmpeg", "bin")
|
||||
from pydub import AudioSegment
|
||||
import json
|
||||
|
||||
def list_files(dir_path):
|
||||
for root, dirs, files in os.walk(dir_path):
|
||||
for file in files:
|
||||
print(os.path.join(root, file))
|
||||
|
||||
class LipSyncGenerator:
|
||||
|
||||
def list_files(dir_path):
|
||||
for root, dirs, files in os.walk(dir_path):
|
||||
for file in files:
|
||||
print(os.path.join(root, file))
|
||||
|
||||
class LipSyncGenerator:
|
||||
def __init__(self):
|
||||
self.viseme_em = [
|
||||
"sil", "PP", "FF", "TH", "DD",
|
||||
"kk", "CH", "SS", "nn", "RR",
|
||||
"aa", "E", "ih", "oh", "ou"]
|
||||
self.viseme = []
|
||||
self.exe_path = os.path.join(os.getcwd(), "test", "ovr_lipsync", "ovr_lipsync_exe", "ProcessWAV.exe")
|
||||
|
||||
def run_exe_and_get_output(self, arguments):
|
||||
process = subprocess.Popen([self.exe_path] + arguments, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
|
||||
while True:
|
||||
output = process.stdout.readline()
|
||||
if output == b'' and process.poll() is not None:
|
||||
break
|
||||
if output:
|
||||
self.viseme.append(output.strip().decode())
|
||||
rc = process.poll()
|
||||
return rc
|
||||
|
||||
def filter(self, viseme):
|
||||
new_viseme = []
|
||||
for v in self.viseme:
|
||||
if v in self.viseme_em:
|
||||
new_viseme.append(v)
|
||||
return new_viseme
|
||||
|
||||
def generate_visemes(self, wav_filepath):
|
||||
if wav_filepath.endswith(".mp3"):
|
||||
wav_filepath = self.convert_mp3_to_wav(wav_filepath)
|
||||
arguments = ["--print-viseme-name", wav_filepath]
|
||||
self.run_exe_and_get_output(arguments)
|
||||
|
||||
return self.filter(self.viseme)
|
||||
|
||||
def consolidate_visemes(self, viseme_list):
|
||||
if not viseme_list:
|
||||
return []
|
||||
|
||||
result = []
|
||||
current_viseme = viseme_list[0]
|
||||
count = 1
|
||||
|
||||
for viseme in viseme_list[1:]:
|
||||
if viseme == current_viseme:
|
||||
count += 1
|
||||
else:
|
||||
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
|
||||
current_viseme = viseme
|
||||
count = 1
|
||||
|
||||
# Add the last viseme to the result
|
||||
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
|
||||
|
||||
new_data = []
|
||||
for i in range(len(result)):
|
||||
if result[i]['Time'] < 30:
|
||||
if len(new_data) > 0:
|
||||
new_data[-1]['Time'] += result[i]['Time']
|
||||
else:
|
||||
new_data.append(result[i])
|
||||
return new_data
|
||||
|
||||
def convert_mp3_to_wav(self, mp3_filepath):
|
||||
audio = AudioSegment.from_mp3(mp3_filepath)
|
||||
# 使用 set_frame_rate 方法设置采样率
|
||||
audio = audio.set_frame_rate(44100)
|
||||
wav_filepath = mp3_filepath.rsplit(".", 1)[0] + ".wav"
|
||||
audio.export(wav_filepath, format="wav")
|
||||
return wav_filepath
|
||||
if __name__ == "__main__":
|
||||
start_time = time.time()
|
||||
lip_sync_generator = LipSyncGenerator()
|
||||
viseme_list = lip_sync_generator.generate_visemes("E:\\github\\Fay\\samples\\fay-man.wav")
|
||||
print(viseme_list)
|
||||
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
|
||||
print(json.dumps(consolidated_visemes))
|
||||
print(time.time() - start_time)
|
||||
self.exe_path = os.path.join(_RUNTIME_DIR, "test", "ovr_lipsync", "ovr_lipsync_exe", "ProcessWAV.exe")
|
||||
|
||||
def run_exe_and_get_output(self, arguments):
|
||||
process = subprocess.Popen([self.exe_path] + arguments, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
|
||||
while True:
|
||||
output = process.stdout.readline()
|
||||
if output == b'' and process.poll() is not None:
|
||||
break
|
||||
if output:
|
||||
self.viseme.append(output.strip().decode())
|
||||
rc = process.poll()
|
||||
return rc
|
||||
|
||||
def filter(self, viseme):
|
||||
new_viseme = []
|
||||
for v in self.viseme:
|
||||
if v in self.viseme_em:
|
||||
new_viseme.append(v)
|
||||
return new_viseme
|
||||
|
||||
def generate_visemes(self, wav_filepath):
|
||||
if wav_filepath.endswith(".mp3"):
|
||||
wav_filepath = self.convert_mp3_to_wav(wav_filepath)
|
||||
arguments = ["--print-viseme-name", wav_filepath]
|
||||
self.run_exe_and_get_output(arguments)
|
||||
|
||||
return self.filter(self.viseme)
|
||||
|
||||
def consolidate_visemes(self, viseme_list):
|
||||
if not viseme_list:
|
||||
return []
|
||||
|
||||
result = []
|
||||
current_viseme = viseme_list[0]
|
||||
count = 1
|
||||
|
||||
for viseme in viseme_list[1:]:
|
||||
if viseme == current_viseme:
|
||||
count += 1
|
||||
else:
|
||||
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
|
||||
current_viseme = viseme
|
||||
count = 1
|
||||
|
||||
# Add the last viseme to the result
|
||||
result.append({"Lip": current_viseme, "Time": count*33}) # Multiply by 10 for duration in ms
|
||||
|
||||
new_data = []
|
||||
for i in range(len(result)):
|
||||
if result[i]['Time'] < 30:
|
||||
if len(new_data) > 0:
|
||||
new_data[-1]['Time'] += result[i]['Time']
|
||||
else:
|
||||
new_data.append(result[i])
|
||||
return new_data
|
||||
|
||||
def convert_mp3_to_wav(self, mp3_filepath):
|
||||
audio = AudioSegment.from_mp3(mp3_filepath)
|
||||
# 使用 set_frame_rate 方法设置采样率
|
||||
audio = audio.set_frame_rate(44100)
|
||||
wav_filepath = mp3_filepath.rsplit(".", 1)[0] + ".wav"
|
||||
audio.export(wav_filepath, format="wav")
|
||||
return wav_filepath
|
||||
if __name__ == "__main__":
|
||||
start_time = time.time()
|
||||
lip_sync_generator = LipSyncGenerator()
|
||||
viseme_list = lip_sync_generator.generate_visemes("E:\\github\\Fay\\samples\\fay-man.wav")
|
||||
print(viseme_list)
|
||||
consolidated_visemes = lip_sync_generator.consolidate_visemes(viseme_list)
|
||||
print(json.dumps(consolidated_visemes))
|
||||
print(time.time() - start_time)
|
||||
|
||||
@@ -278,11 +278,11 @@ def load_config(force_reload=False):
|
||||
root_config_json_exists = os.path.exists(default_config_json_path)
|
||||
root_config_complete = root_system_conf_exists and root_config_json_exists
|
||||
|
||||
# 构建system.conf和config.json的完整路径
|
||||
config_center_fallback = False
|
||||
if using_config_center:
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
# 构建system.conf和config.json相关路径.
|
||||
config_center_fallback = False
|
||||
if using_config_center:
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
if (
|
||||
system_conf_path is None
|
||||
@@ -308,29 +308,34 @@ def load_config(force_reload=False):
|
||||
loaded_from_api = False
|
||||
api_attempted = False
|
||||
if using_config_center:
|
||||
if explicit_config_center:
|
||||
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
else:
|
||||
util.log(1, f"未检测到本地system.conf或config.json,尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
if explicit_config_center:
|
||||
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
else:
|
||||
util.log(1, f"未检测到本地system.conf或config.json,尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
if config_center_fallback:
|
||||
_bootstrap_loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
forced_loaded = True
|
||||
# 将配置中心配置缓存到本地文件.
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(
|
||||
api_config,
|
||||
system_conf_path,
|
||||
config_json_path,
|
||||
save_config_json=not os.path.exists(config_json_path)
|
||||
)
|
||||
forced_loaded = True
|
||||
|
||||
_warn_public_config_once()
|
||||
else:
|
||||
util.log(2, "配置中心加载失败,尝试使用缓存配置")
|
||||
util.log(2, "配置中心加载失败,尝试使用缓存配置")
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
@@ -339,38 +344,48 @@ def load_config(force_reload=False):
|
||||
if (not sys_conf_exists or not config_json_exists) and not forced_loaded:
|
||||
if using_config_center:
|
||||
if not api_attempted:
|
||||
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
|
||||
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
if config_center_fallback:
|
||||
_bootstrap_loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
# 将配置中心配置缓存到本地文件.
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(
|
||||
api_config,
|
||||
system_conf_path,
|
||||
config_json_path,
|
||||
save_config_json=not os.path.exists(config_json_path)
|
||||
)
|
||||
|
||||
_warn_public_config_once()
|
||||
else:
|
||||
# 使用提取的项目ID或全局项目ID
|
||||
util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'和' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在),尝试从API加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
# 使用项目配置或全局项目配置作为回退来源.
|
||||
util.log(1, f"本地配置文件不完整,尝试从API加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
# 将配置中心配置缓存到本地文件.
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(
|
||||
api_config,
|
||||
system_conf_path,
|
||||
config_json_path,
|
||||
save_config_json=not os.path.exists(config_json_path)
|
||||
)
|
||||
|
||||
_warn_public_config_once()
|
||||
|
||||
@@ -378,17 +393,17 @@ def load_config(force_reload=False):
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
if using_config_center and (not sys_conf_exists or not config_json_exists):
|
||||
if _last_loaded_config is not None and _last_loaded_from_api:
|
||||
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
|
||||
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
|
||||
return _last_loaded_config
|
||||
if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists):
|
||||
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
|
||||
if cache_ready:
|
||||
util.log(2, "配置中心不可用,回退使用缓存配置")
|
||||
util.log(2, "配置中心不可用,回退使用缓存配置")
|
||||
using_config_center = False
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
util.log(2, "配置中心不可用,回退使用本地配置文件")
|
||||
util.log(2, "配置中心不可用,回退使用本地配置文件")
|
||||
using_config_center = False
|
||||
system_conf_path = default_system_conf_path
|
||||
config_json_path = default_config_json_path
|
||||
@@ -497,31 +512,33 @@ def load_config(force_reload=False):
|
||||
|
||||
return config_dict
|
||||
|
||||
def save_api_config_to_local(api_config, system_conf_path, config_json_path):
|
||||
"""
|
||||
将API加载的配置保存到本地文件
|
||||
def save_api_config_to_local(api_config, system_conf_path, config_json_path, save_config_json=True):
|
||||
"""
|
||||
Persist API config to local files.
|
||||
|
||||
Args:
|
||||
api_config: API加载的配置字典
|
||||
system_conf_path: system.conf文件路径
|
||||
config_json_path: config.json文件路径
|
||||
Args:
|
||||
api_config: API response dict.
|
||||
system_conf_path: Path to system.conf.
|
||||
config_json_path: Path to config.json.
|
||||
save_config_json: Whether to write config.json.
|
||||
"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(system_conf_path), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(config_json_path), exist_ok=True)
|
||||
|
||||
# 保存system.conf
|
||||
with open(system_conf_path, 'w', encoding='utf-8') as f:
|
||||
api_config['system_config'].write(f)
|
||||
|
||||
# 保存config.json
|
||||
with codecs.open(config_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(api_config['config'], f, ensure_ascii=False, indent=4)
|
||||
|
||||
util.log(1, f"已将配置中心配置缓存到本地文件: {system_conf_path} 和 {config_json_path}")
|
||||
except Exception as e:
|
||||
util.log(2, f"保存配置中心配置缓存到本地文件时出错: {str(e)}")
|
||||
# 确保目录存在.
|
||||
os.makedirs(os.path.dirname(system_conf_path), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(config_json_path), exist_ok=True)
|
||||
|
||||
# 始终刷新 system.conf.
|
||||
with open(system_conf_path, 'w', encoding='utf-8') as f:
|
||||
api_config['system_config'].write(f)
|
||||
|
||||
# 默认只在首次下载时保存config.json.
|
||||
if save_config_json:
|
||||
with codecs.open(config_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(api_config['config'], f, ensure_ascii=False, indent=4)
|
||||
|
||||
util.log(1, f"已将配置中心配置缓存到本地文件: {system_conf_path} 和 {config_json_path}")
|
||||
except Exception as e:
|
||||
util.log(2, f"保存配置中心配置到本地文件时出错: {str(e)}")
|
||||
|
||||
@synchronized
|
||||
def save_config(config_data):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import time
|
||||
from utils import util
|
||||
from core import stream_manager
|
||||
@@ -27,6 +28,20 @@ class StreamTextProcessor:
|
||||
self.max_cache_size = max_cache_size
|
||||
# 常用中英文分句标点(UTF-8)
|
||||
self.punctuation_marks = [",", "。", ";", ":", "、", "!", "?", ".", "!", "?", "\n"]
|
||||
self.punctuation_mark_set = set(self.punctuation_marks)
|
||||
self.url_regex_list = [
|
||||
re.compile(
|
||||
r"(?i)\b(?:https?://|ftp://|file://|www\.)[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]+"
|
||||
),
|
||||
re.compile(
|
||||
r"(?i)\b[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?"
|
||||
r"(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+"
|
||||
r"(?::\d{2,5})?"
|
||||
r"(?:/[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]*)?"
|
||||
r"(?:\?[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]*)?"
|
||||
r"(?:#[^\s\u3002\uff01\uff1f\u3001\uff0c\uff1b\uff1a<>'\"\[\]\(\)\{\}]*)?"
|
||||
),
|
||||
]
|
||||
|
||||
def process_stream_text(self, text, username, is_qa=False, session_type="stream"):
|
||||
"""
|
||||
@@ -87,7 +102,7 @@ class StreamTextProcessor:
|
||||
|
||||
# 动态缓存大小检查
|
||||
if len(accumulated_text) > self.max_cache_size:
|
||||
util.log(1, f"处理过程中缓存溢出,强制发送剩余文本")
|
||||
util.log(1, "处理过程中缓存溢出,强制发送剩余文本")
|
||||
break
|
||||
|
||||
iteration_count += 1
|
||||
@@ -178,20 +193,95 @@ class StreamTextProcessor:
|
||||
"""
|
||||
try:
|
||||
indices = []
|
||||
for punct in self.punctuation_marks:
|
||||
try:
|
||||
index = text.find(punct)
|
||||
if index != -1:
|
||||
indices.append(index)
|
||||
except Exception as e:
|
||||
util.log(1, f"查找标点符号 '{punct}' 时出错: {str(e)}")
|
||||
url_spans = self._find_url_spans(text)
|
||||
for index, ch in enumerate(text):
|
||||
if ch not in self.punctuation_mark_set:
|
||||
continue
|
||||
|
||||
return sorted([i for i in indices if i != -1])
|
||||
if self._is_in_spans(index, url_spans):
|
||||
continue
|
||||
if ch == "." and self._is_protected_dot(text, index):
|
||||
continue
|
||||
indices.append(index)
|
||||
return indices
|
||||
except Exception as e:
|
||||
util.log(1, f"查找标点符号时出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def _find_url_spans(self, text):
|
||||
spans = []
|
||||
for regex in self.url_regex_list:
|
||||
for match in regex.finditer(text):
|
||||
start, end = match.span()
|
||||
if end > start:
|
||||
spans.append((start, end))
|
||||
if not spans:
|
||||
return []
|
||||
spans.sort(key=lambda item: item[0])
|
||||
merged = [spans[0]]
|
||||
for start, end in spans[1:]:
|
||||
last_start, last_end = merged[-1]
|
||||
if start <= last_end:
|
||||
merged[-1] = (last_start, max(last_end, end))
|
||||
else:
|
||||
merged.append((start, end))
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _is_in_spans(index, spans):
|
||||
for start, end in spans:
|
||||
if start <= index < end:
|
||||
return True
|
||||
if index < start:
|
||||
break
|
||||
return False
|
||||
|
||||
def _is_protected_dot(self, text, index):
|
||||
"""
|
||||
判断英文句点是否位于不应截断的 token 内(如版本号、URL、域名等)。
|
||||
"""
|
||||
try:
|
||||
prev_char = text[index - 1] if index > 0 else ""
|
||||
next_char = text[index + 1] if (index + 1) < len(text) else ""
|
||||
|
||||
# 数字点位:1.2 / 3.14 / 192.168.1.1
|
||||
if prev_char.isdigit() and next_char.isdigit():
|
||||
return True
|
||||
|
||||
token = self._extract_token_around(text, index).lower()
|
||||
if not token:
|
||||
return False
|
||||
|
||||
# URL 写法
|
||||
if "://" in token or token.startswith("www."):
|
||||
return True
|
||||
|
||||
# 版本号写法:v1.2.3 / 1.2.3
|
||||
if re.match(r"^[a-z]*\d+\.\d+(\.\d+)*[a-z]*$", token):
|
||||
return True
|
||||
|
||||
# 域名/主机名(可带路径)
|
||||
if re.match(r"^[a-z0-9-]+(\.[a-z0-9-]+)+(/\S*)?$", token):
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _extract_token_around(self, text, index):
|
||||
"""
|
||||
提取句点所在连续 token(按空白与中英文标点分隔)。
|
||||
"""
|
||||
separators = set(" \t\r\n\"'()[]{}<>,!?;:" + "\uFF0C\u3002\uFF01\uFF1F\uFF1B\uFF1A\u3001")
|
||||
left = index
|
||||
right = index + 1
|
||||
|
||||
while left > 0 and text[left - 1] not in separators:
|
||||
left -= 1
|
||||
while right < len(text) and text[right] not in separators:
|
||||
right += 1
|
||||
|
||||
return text[left:right]
|
||||
|
||||
def _send_fallback_text(self, text, username, state_manager, conversation_id):
|
||||
"""
|
||||
备用发送方案:直接发送完整文本(含首尾标记)
|
||||
|
||||
Reference in New Issue
Block a user