mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
fay支持mcp客户端连接,对外暴广播工具
This commit is contained in:
BIN
cache_data/tmpl6qo3gps.wav
Normal file
BIN
cache_data/tmpl6qo3gps.wav
Normal file
Binary file not shown.
@@ -1,42 +1,42 @@
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
import websockets
|
||||
import asyncio
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from websockets.legacy.server import Serve
|
||||
|
||||
from utils import util
|
||||
from scheduler.thread_manager import MyThread
|
||||
|
||||
class MyServer:
|
||||
def __init__(self, host='0.0.0.0', port=10000):
|
||||
self.lock = asyncio.Lock()
|
||||
self.__host = host # ip
|
||||
self.__port = port # 端口号
|
||||
self.__listCmd = [] # 要发送的信息的列表
|
||||
self.__clients = list()
|
||||
self.__server: Serve = None
|
||||
self.__event_loop: AbstractEventLoop = None
|
||||
self.__running = True
|
||||
self.__pending = None
|
||||
self.isConnect = False
|
||||
self.TIMEOUT = 3 # 设置任何超时时间为 3 秒
|
||||
self.__tasks = {} # 记录任务和开始时间的字典
|
||||
|
||||
# 接收处理
|
||||
async def __consumer_handler(self, websocket, path):
|
||||
username = None
|
||||
output_setting = None
|
||||
try:
|
||||
async for message in websocket:
|
||||
await asyncio.sleep(0.01)
|
||||
try:
|
||||
data = json.loads(message)
|
||||
username = data.get("Username")
|
||||
output_setting = data.get("Output")
|
||||
except json.JSONDecodeError:
|
||||
pass # Ignore invalid JSON messages
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
import websockets
|
||||
import asyncio
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from websockets.legacy.server import Serve
|
||||
|
||||
from utils import util
|
||||
from scheduler.thread_manager import MyThread
|
||||
|
||||
class MyServer:
|
||||
def __init__(self, host='0.0.0.0', port=10000):
|
||||
self.lock = asyncio.Lock()
|
||||
self.__host = host # ip
|
||||
self.__port = port # 端口号
|
||||
self.__listCmd = [] # 要发送的信息的列表
|
||||
self.__clients = list()
|
||||
self.__server: Serve = None
|
||||
self.__event_loop: AbstractEventLoop = None
|
||||
self.__running = True
|
||||
self.__pending = None
|
||||
self.isConnect = False
|
||||
self.TIMEOUT = 3 # 设置任何超时时间为 3 秒
|
||||
self.__tasks = {} # 记录任务和开始时间的字典
|
||||
|
||||
# 接收处理
|
||||
async def __consumer_handler(self, websocket, path):
|
||||
username = None
|
||||
output_setting = None
|
||||
try:
|
||||
async for message in websocket:
|
||||
await asyncio.sleep(0.01)
|
||||
try:
|
||||
data = json.loads(message)
|
||||
username = data.get("Username")
|
||||
output_setting = data.get("Output")
|
||||
except json.JSONDecodeError:
|
||||
pass # Ignore invalid JSON messages
|
||||
if username is not None or output_setting is not None:
|
||||
remote_address = websocket.remote_address
|
||||
unique_id = f"{remote_address[0]}:{remote_address[1]}"
|
||||
@@ -47,270 +47,271 @@ class MyServer:
|
||||
self.__clients[i]["username"] = username
|
||||
if output_setting is not None:
|
||||
self.__clients[i]["output"] = output_setting
|
||||
await self.__consumer(message)
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(websocket)
|
||||
util.printInfo(1, "User" if username is None else username, f"WebSocket 连接关闭: {e}")
|
||||
|
||||
def get_client_output(self, username):
|
||||
clients_with_username = [c for c in self.__clients if c.get("username") == username]
|
||||
if not clients_with_username:
|
||||
return False
|
||||
for client in clients_with_username:
|
||||
# 获取output设置,支持布尔值、字符串布尔值、数字等多种格式
|
||||
output = client.get("output", True) # 默认为True,表示需要音频
|
||||
|
||||
# 处理不同类型的输入
|
||||
if isinstance(output, bool):
|
||||
if output: # 如果是True
|
||||
return True
|
||||
elif isinstance(output, str):
|
||||
if output.lower() == 'true': # 字符串"true"
|
||||
return True
|
||||
elif isinstance(output, (int, float)):
|
||||
if output != 0 and output != '0': # 0以外的数字
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# 发送处理
|
||||
async def __producer_handler(self, websocket, path):
|
||||
while self.__running:
|
||||
await asyncio.sleep(0.01)
|
||||
if len(self.__listCmd) > 0:
|
||||
message = await self.__producer()
|
||||
if message:
|
||||
username = json.loads(message).get("Username")
|
||||
if username is None:
|
||||
# 群发消息
|
||||
async with self.lock:
|
||||
wsclients = [c["websocket"] for c in self.__clients]
|
||||
tasks = [self.send_message_with_timeout(client, message, username, timeout=3) for client in wsclients]
|
||||
await asyncio.gather(*tasks)
|
||||
else:
|
||||
# 向指定用户发送消息
|
||||
async with self.lock:
|
||||
target_clients = [c["websocket"] for c in self.__clients if c.get("username") == username]
|
||||
tasks = [self.send_message_with_timeout(client, message, username, timeout=3) for client in target_clients]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# 发送消息(设置超时)
|
||||
async def send_message_with_timeout(self, client, message, username, timeout=3):
|
||||
try:
|
||||
await asyncio.wait_for(self.send_message(client, message, username), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
util.printInfo(1, "User" if username is None else username, f"发送消息超时: 用户名 {username}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(client)
|
||||
util.printInfo(1, "User" if username is None else username, f"WebSocket 连接关闭: {e}")
|
||||
|
||||
# 发送消息
|
||||
async def send_message(self, client, message, username):
|
||||
try:
|
||||
await client.send(message)
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(client)
|
||||
util.printInfo(1, "User" if username is None else username, f"WebSocket 连接关闭: {e}")
|
||||
|
||||
|
||||
async def __handler(self, websocket, path):
|
||||
self.isConnect = True
|
||||
util.log(1,"websocket连接上:{}".format(self.__port))
|
||||
self.on_connect_handler()
|
||||
remote_address = websocket.remote_address
|
||||
unique_id = f"{remote_address[0]}:{remote_address[1]}"
|
||||
async with self.lock:
|
||||
self.__clients.append({"id" : unique_id, "websocket" : websocket, "username" : "User"})
|
||||
consumer_task = asyncio.create_task(self.__consumer_handler(websocket, path))#接收
|
||||
producer_task = asyncio.create_task(self.__producer_handler(websocket, path))#发送
|
||||
done, self.__pending = await asyncio.wait([consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
for task in self.__pending:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(websocket)
|
||||
util.log(1, "websocket连接断开:{}".format(unique_id))
|
||||
|
||||
async def __consumer(self, message):
|
||||
self.on_revice_handler(message)
|
||||
|
||||
async def __producer(self):
|
||||
if len(self.__listCmd) > 0:
|
||||
message = self.on_send_handler(self.__listCmd.pop(0))
|
||||
return message
|
||||
else:
|
||||
return None
|
||||
|
||||
async def remove_client(self, websocket):
|
||||
async with self.lock:
|
||||
self.__clients = [c for c in self.__clients if c["websocket"] != websocket]
|
||||
if len(self.__clients) == 0:
|
||||
self.isConnect = False
|
||||
self.on_close_handler()
|
||||
|
||||
def is_connected(self, username):
|
||||
if username is None:
|
||||
username = "User"
|
||||
if len(self.__clients) == 0:
|
||||
return False
|
||||
clients = [c for c in self.__clients if c["username"] == username]
|
||||
if len(clients) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
#Edit by xszyou on 20230113:通过继承此类来实现服务端的接收后处理逻辑
|
||||
@abstractmethod
|
||||
def on_revice_handler(self, message):
|
||||
pass
|
||||
|
||||
#Edit by xszyou on 20230114:通过继承此类来实现服务端的连接处理逻辑
|
||||
@abstractmethod
|
||||
def on_connect_handler(self):
|
||||
pass
|
||||
|
||||
#Edit by xszyou on 20230804:通过继承此类来实现服务端的发送前的处理逻辑
|
||||
@abstractmethod
|
||||
def on_send_handler(self, message):
|
||||
return message
|
||||
|
||||
#Edit by xszyou on 20230816:通过继承此类来实现服务端的断开后的处理逻辑
|
||||
@abstractmethod
|
||||
def on_close_handler(self):
|
||||
pass
|
||||
|
||||
# 创建server
|
||||
def __connect(self):
|
||||
self.__event_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.__event_loop)
|
||||
self.__isExecute = True
|
||||
if self.__server:
|
||||
util.log(1, 'server already exist')
|
||||
return
|
||||
self.__server = websockets.serve(self.__handler, self.__host, self.__port)
|
||||
asyncio.get_event_loop().run_until_complete(self.__server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
||||
# 往要发送的命令列表中,添加命令
|
||||
await self.__consumer(message)
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(websocket)
|
||||
util.printInfo(1, "User" if username is None else username, f"WebSocket 连接关闭: {e}")
|
||||
|
||||
def get_client_output(self, username):
|
||||
clients_with_username = [c for c in self.__clients if c.get("username") == username]
|
||||
if not clients_with_username:
|
||||
return False
|
||||
for client in clients_with_username:
|
||||
# 获取output设置,支持布尔值、字符串布尔值、数字等多种格式
|
||||
output = client.get("output", True) # 默认为True,表示需要音频
|
||||
|
||||
# 处理不同类型的输入
|
||||
if isinstance(output, bool):
|
||||
if output: # 如果是True
|
||||
return True
|
||||
elif isinstance(output, str):
|
||||
if output.lower() == 'true': # 字符串"true"
|
||||
return True
|
||||
elif isinstance(output, (int, float)):
|
||||
if output != 0 and output != '0': # 0以外的数字
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# 发送处理
|
||||
async def __producer_handler(self, websocket, path):
|
||||
while self.__running:
|
||||
await asyncio.sleep(0.01)
|
||||
if len(self.__listCmd) > 0:
|
||||
message = await self.__producer()
|
||||
if message:
|
||||
username = json.loads(message).get("Username")
|
||||
if username is None:
|
||||
# 群发消息
|
||||
async with self.lock:
|
||||
wsclients = [c["websocket"] for c in self.__clients]
|
||||
tasks = [self.send_message_with_timeout(client, message, username, timeout=3) for client in wsclients]
|
||||
await asyncio.gather(*tasks)
|
||||
else:
|
||||
# 向指定用户发送消息
|
||||
async with self.lock:
|
||||
target_clients = [c["websocket"] for c in self.__clients if c.get("username") == username]
|
||||
tasks = [self.send_message_with_timeout(client, message, username, timeout=3) for client in target_clients]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# 发送消息(设置超时)
|
||||
async def send_message_with_timeout(self, client, message, username, timeout=3):
|
||||
try:
|
||||
await asyncio.wait_for(self.send_message(client, message, username), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
util.printInfo(1, "User" if username is None else username, f"发送消息超时: 用户名 {username}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(client)
|
||||
util.printInfo(1, "User" if username is None else username, f"WebSocket 连接关闭: {e}")
|
||||
|
||||
# 发送消息
|
||||
async def send_message(self, client, message, username):
|
||||
try:
|
||||
await client.send(message)
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(client)
|
||||
util.printInfo(1, "User" if username is None else username, f"WebSocket 连接关闭: {e}")
|
||||
|
||||
|
||||
async def __handler(self, websocket, path):
|
||||
self.isConnect = True
|
||||
util.log(1,"websocket连接上:{}".format(self.__port))
|
||||
self.on_connect_handler()
|
||||
remote_address = websocket.remote_address
|
||||
unique_id = f"{remote_address[0]}:{remote_address[1]}"
|
||||
async with self.lock:
|
||||
self.__clients.append({"id" : unique_id, "websocket" : websocket, "username" : "User"})
|
||||
consumer_task = asyncio.create_task(self.__consumer_handler(websocket, path))#接收
|
||||
producer_task = asyncio.create_task(self.__producer_handler(websocket, path))#发送
|
||||
done, self.__pending = await asyncio.wait([consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
for task in self.__pending:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 从客户端列表中移除已断开的连接
|
||||
await self.remove_client(websocket)
|
||||
util.log(1, "websocket连接断开:{}".format(unique_id))
|
||||
|
||||
async def __consumer(self, message):
|
||||
self.on_revice_handler(message)
|
||||
|
||||
async def __producer(self):
|
||||
if len(self.__listCmd) > 0:
|
||||
message = self.on_send_handler(self.__listCmd.pop(0))
|
||||
return message
|
||||
else:
|
||||
return None
|
||||
|
||||
async def remove_client(self, websocket):
|
||||
async with self.lock:
|
||||
self.__clients = [c for c in self.__clients if c["websocket"] != websocket]
|
||||
if len(self.__clients) == 0:
|
||||
self.isConnect = False
|
||||
self.on_close_handler()
|
||||
|
||||
def is_connected(self, username):
|
||||
if username is None:
|
||||
username = "User"
|
||||
if len(self.__clients) == 0:
|
||||
return False
|
||||
clients = [c for c in self.__clients if c["username"] == username]
|
||||
if len(clients) > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
#Edit by xszyou on 20230113:通过继承此类来实现服务端的接收后处理逻辑
|
||||
@abstractmethod
|
||||
def on_revice_handler(self, message):
|
||||
pass
|
||||
|
||||
#Edit by xszyou on 20230114:通过继承此类来实现服务端的连接处理逻辑
|
||||
@abstractmethod
|
||||
def on_connect_handler(self):
|
||||
pass
|
||||
|
||||
#Edit by xszyou on 20230804:通过继承此类来实现服务端的发送前的处理逻辑
|
||||
@abstractmethod
|
||||
def on_send_handler(self, message):
|
||||
return message
|
||||
|
||||
#Edit by xszyou on 20230816:通过继承此类来实现服务端的断开后的处理逻辑
|
||||
@abstractmethod
|
||||
def on_close_handler(self):
|
||||
pass
|
||||
|
||||
# 创建server
|
||||
def __connect(self):
|
||||
self.__event_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.__event_loop)
|
||||
self.__isExecute = True
|
||||
if self.__server:
|
||||
util.log(1, 'server already exist')
|
||||
return
|
||||
self.__server = websockets.serve(self.__handler, self.__host, self.__port)
|
||||
asyncio.get_event_loop().run_until_complete(self.__server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
||||
# 往要发送的命令列表中,添加命令
|
||||
def add_cmd(self, content):
|
||||
if not self.__running:
|
||||
return
|
||||
jsonStr = json.dumps(content)
|
||||
# keep unicode (emoji/中文) intact for websocket consumers
|
||||
jsonStr = json.dumps(content, ensure_ascii=False)
|
||||
self.__listCmd.append(jsonStr)
|
||||
# util.log('命令 {}'.format(content))
|
||||
|
||||
# 开启服务
|
||||
def start_server(self):
|
||||
MyThread(target=self.__connect).start()
|
||||
|
||||
# 关闭服务
|
||||
def stop_server(self):
|
||||
self.__running = False
|
||||
self.isConnect = False
|
||||
if self.__server is None:
|
||||
return
|
||||
self.__server.close()
|
||||
self.__server = None
|
||||
self.__clients = []
|
||||
util.log(1, "WebSocket server stopped.")
|
||||
|
||||
|
||||
#ui端server
|
||||
class WebServer(MyServer):
|
||||
def __init__(self, host='0.0.0.0', port=10003):
|
||||
super().__init__(host, port)
|
||||
|
||||
def on_revice_handler(self, message):
|
||||
pass
|
||||
|
||||
def on_connect_handler(self):
|
||||
self.add_cmd({"panelMsg": "使用提示:Fay可以独立使用,启动数字人将自动对接。"})
|
||||
|
||||
def on_send_handler(self, message):
|
||||
return message
|
||||
|
||||
def on_close_handler(self):
|
||||
pass
|
||||
|
||||
#数字人端server
|
||||
class HumanServer(MyServer):
|
||||
def __init__(self, host='0.0.0.0', port=10002):
|
||||
super().__init__(host, port)
|
||||
|
||||
def on_revice_handler(self, message):
|
||||
pass
|
||||
|
||||
def on_connect_handler(self):
|
||||
web_server_instance = get_web_instance()
|
||||
web_server_instance.add_cmd({"is_connect": self.isConnect})
|
||||
|
||||
|
||||
def on_send_handler(self, message):
|
||||
# util.log(1, '向human发送 {}'.format(message))
|
||||
if not self.isConnect:
|
||||
return None
|
||||
return message
|
||||
|
||||
def on_close_handler(self):
|
||||
web_server_instance = get_web_instance()
|
||||
web_server_instance.add_cmd({"is_connect": self.isConnect})
|
||||
|
||||
|
||||
|
||||
#测试
|
||||
class TestServer(MyServer):
|
||||
def __init__(self, host='0.0.0.0', port=10000):
|
||||
super().__init__(host, port)
|
||||
|
||||
def on_revice_handler(self, message):
|
||||
print(message)
|
||||
|
||||
def on_connect_handler(self):
|
||||
print("连接上了")
|
||||
|
||||
def on_send_handler(self, message):
|
||||
return message
|
||||
|
||||
def on_close_handler(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
#单例
|
||||
|
||||
__instance: MyServer = None
|
||||
__web_instance: MyServer = None
|
||||
|
||||
|
||||
def new_instance(host='0.0.0.0', port=10002) -> MyServer:
|
||||
global __instance
|
||||
if __instance is None:
|
||||
__instance = HumanServer(host, port)
|
||||
return __instance
|
||||
|
||||
|
||||
def new_web_instance(host='0.0.0.0', port=10003) -> MyServer:
|
||||
global __web_instance
|
||||
if __web_instance is None:
|
||||
__web_instance = WebServer(host, port)
|
||||
return __web_instance
|
||||
|
||||
|
||||
def get_instance() -> MyServer:
|
||||
return __instance
|
||||
|
||||
|
||||
def get_web_instance() -> MyServer:
|
||||
return __web_instance
|
||||
|
||||
if __name__ == '__main__':
|
||||
testServer = TestServer(host='0.0.0.0', port=10000)
|
||||
# util.log('命令 {}'.format(content))
|
||||
|
||||
# 开启服务
|
||||
def start_server(self):
|
||||
MyThread(target=self.__connect).start()
|
||||
|
||||
# 关闭服务
|
||||
def stop_server(self):
|
||||
self.__running = False
|
||||
self.isConnect = False
|
||||
if self.__server is None:
|
||||
return
|
||||
self.__server.close()
|
||||
self.__server = None
|
||||
self.__clients = []
|
||||
util.log(1, "WebSocket server stopped.")
|
||||
|
||||
|
||||
#ui端server
|
||||
class WebServer(MyServer):
|
||||
def __init__(self, host='0.0.0.0', port=10003):
|
||||
super().__init__(host, port)
|
||||
|
||||
def on_revice_handler(self, message):
|
||||
pass
|
||||
|
||||
def on_connect_handler(self):
|
||||
self.add_cmd({"panelMsg": "使用提示:Fay可以独立使用,启动数字人将自动对接。"})
|
||||
|
||||
def on_send_handler(self, message):
|
||||
return message
|
||||
|
||||
def on_close_handler(self):
|
||||
pass
|
||||
|
||||
#数字人端server
|
||||
class HumanServer(MyServer):
|
||||
def __init__(self, host='0.0.0.0', port=10002):
|
||||
super().__init__(host, port)
|
||||
|
||||
def on_revice_handler(self, message):
|
||||
pass
|
||||
|
||||
def on_connect_handler(self):
|
||||
web_server_instance = get_web_instance()
|
||||
web_server_instance.add_cmd({"is_connect": self.isConnect})
|
||||
|
||||
|
||||
def on_send_handler(self, message):
|
||||
# util.log(1, '向human发送 {}'.format(message))
|
||||
if not self.isConnect:
|
||||
return None
|
||||
return message
|
||||
|
||||
def on_close_handler(self):
|
||||
web_server_instance = get_web_instance()
|
||||
web_server_instance.add_cmd({"is_connect": self.isConnect})
|
||||
|
||||
|
||||
|
||||
#测试
|
||||
class TestServer(MyServer):
|
||||
def __init__(self, host='0.0.0.0', port=10000):
|
||||
super().__init__(host, port)
|
||||
|
||||
def on_revice_handler(self, message):
|
||||
print(message)
|
||||
|
||||
def on_connect_handler(self):
|
||||
print("连接上了")
|
||||
|
||||
def on_send_handler(self, message):
|
||||
return message
|
||||
|
||||
def on_close_handler(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
#单例
|
||||
|
||||
__instance: MyServer = None
|
||||
__web_instance: MyServer = None
|
||||
|
||||
|
||||
def new_instance(host='0.0.0.0', port=10002) -> MyServer:
|
||||
global __instance
|
||||
if __instance is None:
|
||||
__instance = HumanServer(host, port)
|
||||
return __instance
|
||||
|
||||
|
||||
def new_web_instance(host='0.0.0.0', port=10003) -> MyServer:
|
||||
global __web_instance
|
||||
if __web_instance is None:
|
||||
__web_instance = WebServer(host, port)
|
||||
return __web_instance
|
||||
|
||||
|
||||
def get_instance() -> MyServer:
|
||||
return __instance
|
||||
|
||||
|
||||
def get_web_instance() -> MyServer:
|
||||
return __web_instance
|
||||
|
||||
if __name__ == '__main__':
|
||||
testServer = TestServer(host='0.0.0.0', port=10000)
|
||||
testServer.start_server()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#核心启动模块
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
import pyaudio
|
||||
import socket
|
||||
@@ -21,6 +22,10 @@ deviceSocketServer = None
|
||||
DeviceInputListenerDict = {}
|
||||
ngrok = None
|
||||
socket_service_instance = None
|
||||
mcp_sse_server = None
|
||||
mcp_sse_thread = None
|
||||
# 是否启用内置 MCP SSE 服务器(默认关闭,需显式开启以避免端口/代理问题)
|
||||
mcp_sse_enabled = True
|
||||
|
||||
# 延迟导入fay_core
|
||||
def get_fay_core():
|
||||
@@ -287,9 +292,25 @@ def stop():
|
||||
global ngrok
|
||||
global socket_service_instance
|
||||
global deviceSocketServer
|
||||
global mcp_sse_server
|
||||
global mcp_sse_thread
|
||||
|
||||
util.log(1, '正在关闭服务...')
|
||||
__running = False
|
||||
|
||||
# 关闭 MCP SSE 服务
|
||||
try:
|
||||
if mcp_sse_server is not None:
|
||||
util.log(1, '正在关闭MCP SSE服务器...')
|
||||
try:
|
||||
mcp_sse_server.should_exit = True
|
||||
except Exception:
|
||||
pass
|
||||
if mcp_sse_thread is not None and mcp_sse_thread.is_alive():
|
||||
mcp_sse_thread.join(timeout=2)
|
||||
util.log(1, 'MCP SSE服务器已关闭')
|
||||
except Exception as e:
|
||||
util.log(1, f'MCP SSE服务器关闭异常: {e}')
|
||||
|
||||
# 断开所有MCP服务连接
|
||||
util.log(1, '正在断开所有MCP服务连接...')
|
||||
@@ -338,6 +359,8 @@ def start():
|
||||
global recorderListener
|
||||
global __running
|
||||
global socket_service_instance
|
||||
global mcp_sse_server
|
||||
global mcp_sse_thread
|
||||
|
||||
util.log(1, '开启服务...')
|
||||
__running = True
|
||||
@@ -375,6 +398,26 @@ def start():
|
||||
#启动自动播报服务
|
||||
util.log(1,'启动自动播报服务...')
|
||||
MyThread(target=start_auto_play_service).start()
|
||||
|
||||
# 启动 MCP SSE 服务(需显式开启)
|
||||
if mcp_sse_enabled:
|
||||
try:
|
||||
from faymcp import mcp_server as fay_mcp_server
|
||||
import uvicorn
|
||||
util.log(1, f"MCP SSE服务器启动中: http://{fay_mcp_server.HOST}:{fay_mcp_server.PORT}{fay_mcp_server.SSE_PATH}")
|
||||
config = uvicorn.Config(
|
||||
app=fay_mcp_server.app,
|
||||
host=fay_mcp_server.HOST,
|
||||
port=fay_mcp_server.PORT,
|
||||
log_level="info"
|
||||
)
|
||||
mcp_sse_server = uvicorn.Server(config)
|
||||
mcp_sse_thread = MyThread(target=mcp_sse_server.run, daemon=True)
|
||||
mcp_sse_thread.start()
|
||||
except Exception as e:
|
||||
util.log(1, f"MCP SSE服务器启动异常: {e}")
|
||||
else:
|
||||
util.log(1, 'MCP SSE服务器默认未开启,设 FAY_MCP_SSE_ENABLE=1 可启用')
|
||||
|
||||
util.log(1, '服务启动完成!')
|
||||
|
||||
|
||||
189
faymcp/mcp_server.py
Normal file
189
faymcp/mcp_server.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Fay broadcast MCP server (SSE transport).
|
||||
|
||||
暴露 `broadcast_message` 工具,将文本/音频透传到 Fay 的 `/transparent-pass`。
|
||||
|
||||
环境变量:
|
||||
- FAY_BROADCAST_API 默认 http://127.0.0.1:5000/transparent-pass
|
||||
- FAY_BROADCAST_USER 默认 User
|
||||
- FAY_BROADCAST_TIMEOUT 默认 10
|
||||
- FAY_MCP_SSE_HOST 默认 0.0.0.0
|
||||
- FAY_MCP_SSE_PORT 默认 8765
|
||||
- FAY_MCP_SSE_PATH SSE 路径(默认 /sse)
|
||||
- FAY_MCP_MSG_PATH 消息 POST 路径(默认 /messages)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
try:
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
from mcp.server.sse import SseServerTransport
|
||||
except ImportError:
|
||||
print("缺少 mcp 库,请先安装:pip install mcp", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
except ImportError:
|
||||
print("缺少 starlette,请先安装:pip install starlette sse-starlette", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import uvicorn
|
||||
except ImportError:
|
||||
print("缺少 uvicorn,请先安装:pip install uvicorn", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
print("缺少 requests,请先安装:pip install requests", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
log = logging.getLogger("fay_mcp_server")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
||||
|
||||
SERVER_NAME = "fay_broadcast"
|
||||
|
||||
DEFAULT_API_URL = os.environ.get("FAY_BROADCAST_API", "http://127.0.0.1:5000/transparent-pass")
|
||||
DEFAULT_USER = os.environ.get("FAY_BROADCAST_USER", "User")
|
||||
REQUEST_TIMEOUT = float(os.environ.get("FAY_BROADCAST_TIMEOUT", "10"))
|
||||
|
||||
HOST = os.environ.get("FAY_MCP_SSE_HOST", "0.0.0.0")
|
||||
PORT = int(os.environ.get("FAY_MCP_SSE_PORT", "8765"))
|
||||
SSE_PATH = os.environ.get("FAY_MCP_SSE_PATH", "/sse")
|
||||
MSG_PATH = os.environ.get("FAY_MCP_MSG_PATH", "/messages")
|
||||
|
||||
server = Server(SERVER_NAME)
|
||||
sse_transport = SseServerTransport(MSG_PATH)
|
||||
|
||||
|
||||
def _text_content(text: str) -> TextContent:
|
||||
try:
|
||||
return TextContent(type="text", text=text)
|
||||
except Exception:
|
||||
return {"type": "text", "text": text} # type: ignore[return-value]
|
||||
|
||||
|
||||
TOOLS: list[Tool] = [
|
||||
Tool(
|
||||
name="broadcast_message",
|
||||
description="通过 Fay 的 /transparent-pass 广播文本/音频(SSE 服务器)。",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要广播的文本(audio_url为空时必填)"},
|
||||
"audio_url": {"type": "string", "description": "可选音频 URL"},
|
||||
"user": {"type": "string", "description": "目标用户名,默认 FAY_BROADCAST_USER 或 User"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
return TOOLS
|
||||
|
||||
|
||||
def _parse_arguments(arguments: Dict[str, Any]) -> Tuple[str, str, str]:
|
||||
text = str(arguments.get("text", "") or "").strip()
|
||||
audio_url = str(arguments.get("audio_url", "") or "").strip()
|
||||
user = str(arguments.get("user", "") or "").strip() or DEFAULT_USER
|
||||
return text, audio_url, user
|
||||
|
||||
|
||||
async def _send_broadcast(payload: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
def _post() -> Tuple[bool, str]:
|
||||
body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
resp = requests.post(
|
||||
DEFAULT_API_URL,
|
||||
data=body,
|
||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
data = None
|
||||
|
||||
if resp.ok:
|
||||
if isinstance(data, dict):
|
||||
msg = data.get("message") or data.get("msg") or ""
|
||||
code = data.get("code")
|
||||
if isinstance(code, int) and code >= 400:
|
||||
return False, msg or f"Broadcast failed with code {code}"
|
||||
return True, msg or "Broadcast sent via Fay."
|
||||
return True, "Broadcast sent via Fay."
|
||||
|
||||
err_detail = ""
|
||||
if isinstance(data, dict):
|
||||
err_detail = data.get("message") or data.get("error") or data.get("msg") or ""
|
||||
if not err_detail:
|
||||
err_detail = resp.text
|
||||
return False, f"HTTP {resp.status_code}: {err_detail}"
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_post)
|
||||
except Exception as e:
|
||||
return False, f"{type(e).__name__}: {e}"
|
||||
|
||||
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: Dict[str, Any]) -> list[TextContent]:
|
||||
if name != "broadcast_message":
|
||||
return [_text_content(f"Unknown tool: {name}")]
|
||||
|
||||
text, audio_url, user = _parse_arguments(arguments or {})
|
||||
if not text and not audio_url:
|
||||
return [_text_content("Either 'text' or 'audio_url' must be provided.")]
|
||||
|
||||
payload: Dict[str, Any] = {"user": user}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if audio_url:
|
||||
payload["audio"] = audio_url
|
||||
|
||||
ok, message = await _send_broadcast(payload)
|
||||
prefix = "success" if ok else "error"
|
||||
return [_text_content(f"{prefix}: {message}")]
|
||||
|
||||
|
||||
async def sse_endpoint(request):
|
||||
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream):
|
||||
await server.run(read_stream, write_stream, server.create_initialization_options())
|
||||
# 客户端断开时返回空响应,避免 NoneType 问题
|
||||
return Response()
|
||||
|
||||
|
||||
routes = [
|
||||
Route(SSE_PATH, sse_endpoint, methods=["GET"]),
|
||||
Mount(MSG_PATH, app=sse_transport.handle_post_message),
|
||||
]
|
||||
|
||||
app = Starlette(routes=routes)
|
||||
|
||||
|
||||
def main():
|
||||
log.info(f"SSE MCP server started at http://{HOST}:{PORT}{SSE_PATH}")
|
||||
log.info(f"Message endpoint mounted at {MSG_PATH}")
|
||||
uvicorn.run(app, host=HOST, port=PORT, log_level="info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
Reference in New Issue
Block a user