mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
自然进化
1. fay启动命令增加参数config_center; 2. 修复多个think标签处理逻辑问题; 3. 修复llm透传模式编码问题;
This commit is contained in:
@@ -1,218 +1,228 @@
|
||||
import logging
|
||||
import requests
|
||||
from typing import List, Optional
|
||||
import threading
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
try:
|
||||
import utils.config_util as cfg
|
||||
CONFIG_UTIL_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
CONFIG_UTIL_AVAILABLE = False
|
||||
cfg = None
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
import re
|
||||
import requests
|
||||
from typing import List, Optional
|
||||
import threading
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
try:
|
||||
import utils.config_util as cfg
|
||||
CONFIG_UTIL_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
CONFIG_UTIL_AVAILABLE = False
|
||||
cfg = None
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if not CONFIG_UTIL_AVAILABLE:
|
||||
logger.warning("无法导入 config_util,将使用环境变量配置")
|
||||
|
||||
def _sanitize_text(text: str) -> str:
|
||||
if not isinstance(text, str) or not text:
|
||||
return text
|
||||
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
|
||||
cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
|
||||
return cleaned
|
||||
|
||||
class ApiEmbeddingService:
|
||||
"""API Embedding服务 - 单例模式,调用 OpenAI 兼容的 API"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
with self._lock:
|
||||
if not self._initialized:
|
||||
self._initialize_config()
|
||||
ApiEmbeddingService._initialized = True
|
||||
|
||||
def _initialize_config(self):
|
||||
"""初始化配置,只执行一次"""
|
||||
try:
|
||||
# 优先从 system.conf 读取配置
|
||||
api_base_url = None
|
||||
api_key = None
|
||||
model_name = None
|
||||
|
||||
if CONFIG_UTIL_AVAILABLE and cfg:
|
||||
try:
|
||||
# 确保配置已加载
|
||||
if cfg.config is None:
|
||||
cfg.load_config()
|
||||
|
||||
# 从 config_util 获取配置(自动复用 LLM 配置)
|
||||
api_base_url = cfg.embedding_api_base_url
|
||||
api_key = cfg.embedding_api_key
|
||||
model_name = cfg.embedding_api_model
|
||||
|
||||
logger.info(f"从 system.conf 读取配置:")
|
||||
logger.info(f" - embedding_api_model: {model_name}")
|
||||
logger.info(f" - embedding_api_base_url: {api_base_url}")
|
||||
logger.info(f" - embedding_api_key: {'已配置' if api_key else '未配置'}")
|
||||
except Exception as e:
|
||||
logger.warning(f"从 system.conf 读取配置失败: {e}")
|
||||
|
||||
# 验证必需配置并提供更好的错误提示
|
||||
if not api_base_url:
|
||||
api_base_url = os.getenv('EMBEDDING_API_BASE_URL')
|
||||
if not api_base_url:
|
||||
error_msg = ("未配置 embedding_api_base_url!\n"
|
||||
"请确保 system.conf 中配置了 gpt_base_url,"
|
||||
"或设置环境变量 EMBEDDING_API_BASE_URL")
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
logger.warning(f"使用环境变量配置: base_url={api_base_url}")
|
||||
|
||||
if not api_key:
|
||||
api_key = os.getenv('EMBEDDING_API_KEY')
|
||||
if not api_key:
|
||||
error_msg = ("未配置 embedding_api_key!\n"
|
||||
"请确保 system.conf 中配置了 gpt_api_key,"
|
||||
"或设置环境变量 EMBEDDING_API_KEY")
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
logger.warning("使用环境变量配置: api_key")
|
||||
|
||||
if not model_name:
|
||||
model_name = os.getenv('EMBEDDING_API_MODEL', 'text-embedding-ada-002')
|
||||
logger.warning(f"未配置 embedding_api_model,使用默认值: {model_name}")
|
||||
|
||||
# 保存配置信息
|
||||
self.api_base_url = api_base_url.rstrip('/') # 移除末尾的斜杠
|
||||
self.api_key = api_key
|
||||
self.model_name = model_name
|
||||
self.embedding_dim = None # 将在首次调用时动态获取
|
||||
self.timeout = 60 # API 请求超时时间(秒),默认 60 秒
|
||||
self.max_retries = 2 # 最大重试次数
|
||||
|
||||
logger.info(f"API Embedding 服务初始化完成")
|
||||
logger.info(f"模型: {self.model_name}")
|
||||
logger.info(f"API 地址: {self.api_base_url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API Embedding 服务初始化失败: {e}")
|
||||
raise
|
||||
|
||||
"""API Embedding服务 - 单例模式,调用 OpenAI 兼容的 API"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
with self._lock:
|
||||
if not self._initialized:
|
||||
self._initialize_config()
|
||||
ApiEmbeddingService._initialized = True
|
||||
|
||||
def _initialize_config(self):
|
||||
"""初始化配置,只执行一次"""
|
||||
try:
|
||||
# 优先从 system.conf 读取配置
|
||||
api_base_url = None
|
||||
api_key = None
|
||||
model_name = None
|
||||
|
||||
if CONFIG_UTIL_AVAILABLE and cfg:
|
||||
try:
|
||||
# 确保配置已加载
|
||||
if cfg.config is None:
|
||||
cfg.load_config()
|
||||
|
||||
# 从 config_util 获取配置(自动复用 LLM 配置)
|
||||
api_base_url = cfg.embedding_api_base_url
|
||||
api_key = cfg.embedding_api_key
|
||||
model_name = cfg.embedding_api_model
|
||||
|
||||
logger.info(f"从 system.conf 读取配置:")
|
||||
logger.info(f" - embedding_api_model: {model_name}")
|
||||
logger.info(f" - embedding_api_base_url: {api_base_url}")
|
||||
logger.info(f" - embedding_api_key: {'已配置' if api_key else '未配置'}")
|
||||
except Exception as e:
|
||||
logger.warning(f"从 system.conf 读取配置失败: {e}")
|
||||
|
||||
# 验证必需配置并提供更好的错误提示
|
||||
if not api_base_url:
|
||||
api_base_url = os.getenv('EMBEDDING_API_BASE_URL')
|
||||
if not api_base_url:
|
||||
error_msg = ("未配置 embedding_api_base_url!\n"
|
||||
"请确保 system.conf 中配置了 gpt_base_url,"
|
||||
"或设置环境变量 EMBEDDING_API_BASE_URL")
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
logger.warning(f"使用环境变量配置: base_url={api_base_url}")
|
||||
|
||||
if not api_key:
|
||||
api_key = os.getenv('EMBEDDING_API_KEY')
|
||||
if not api_key:
|
||||
error_msg = ("未配置 embedding_api_key!\n"
|
||||
"请确保 system.conf 中配置了 gpt_api_key,"
|
||||
"或设置环境变量 EMBEDDING_API_KEY")
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
logger.warning("使用环境变量配置: api_key")
|
||||
|
||||
if not model_name:
|
||||
model_name = os.getenv('EMBEDDING_API_MODEL', 'text-embedding-ada-002')
|
||||
logger.warning(f"未配置 embedding_api_model,使用默认值: {model_name}")
|
||||
|
||||
# 保存配置信息
|
||||
self.api_base_url = api_base_url.rstrip('/') # 移除末尾的斜杠
|
||||
self.api_key = api_key
|
||||
self.model_name = model_name
|
||||
self.embedding_dim = None # 将在首次调用时动态获取
|
||||
self.timeout = 60 # API 请求超时时间(秒),默认 60 秒
|
||||
self.max_retries = 2 # 最大重试次数
|
||||
|
||||
logger.info(f"API Embedding 服务初始化完成")
|
||||
logger.info(f"模型: {self.model_name}")
|
||||
logger.info(f"API 地址: {self.api_base_url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API Embedding 服务初始化失败: {e}")
|
||||
raise
|
||||
|
||||
def encode_text(self, text: str) -> List[float]:
|
||||
"""编码单个文本(带重试机制)"""
|
||||
import time
|
||||
|
||||
text = _sanitize_text(text)
|
||||
last_error = None
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
# 调用 API 进行编码
|
||||
url = f"{self.api_base_url}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text
|
||||
}
|
||||
|
||||
# 记录请求信息
|
||||
text_preview = text[:50] + "..." if len(text) > 50 else text
|
||||
logger.info(f"发送 embedding 请求 (尝试 {attempt + 1}/{self.max_retries + 1}): 文本长度={len(text)}, 预览='{text_preview}'")
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embedding = result['data'][0]['embedding']
|
||||
|
||||
# 首次调用时获取实际维度
|
||||
if self.embedding_dim is None:
|
||||
self.embedding_dim = len(embedding)
|
||||
logger.info(f"动态获取 embedding 维度: {self.embedding_dim}")
|
||||
|
||||
logger.info(f"embedding 生成成功")
|
||||
return embedding
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
last_error = e
|
||||
logger.warning(f"请求超时 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
if attempt < self.max_retries:
|
||||
wait_time = 2 ** attempt # 指数退避: 1s, 2s, 4s
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"所有重试均失败,文本长度: {len(text)}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.error(f"文本编码失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
if attempt < self.max_retries:
|
||||
wait_time = 2 ** attempt
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
|
||||
# 调用 API 进行编码
|
||||
url = f"{self.api_base_url}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text
|
||||
}
|
||||
|
||||
# 记录请求信息
|
||||
text_preview = text[:50] + "..." if len(text) > 50 else text
|
||||
logger.info(f"发送 embedding 请求 (尝试 {attempt + 1}/{self.max_retries + 1}): 文本长度={len(text)}, 预览='{text_preview}'")
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embedding = result['data'][0]['embedding']
|
||||
|
||||
# 首次调用时获取实际维度
|
||||
if self.embedding_dim is None:
|
||||
self.embedding_dim = len(embedding)
|
||||
logger.info(f"动态获取 embedding 维度: {self.embedding_dim}")
|
||||
|
||||
logger.info(f"embedding 生成成功")
|
||||
return embedding
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
last_error = e
|
||||
logger.warning(f"请求超时 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
if attempt < self.max_retries:
|
||||
wait_time = 2 ** attempt # 指数退避: 1s, 2s, 4s
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"所有重试均失败,文本长度: {len(text)}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.error(f"文本编码失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
if attempt < self.max_retries:
|
||||
wait_time = 2 ** attempt
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
|
||||
def encode_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""批量编码文本"""
|
||||
try:
|
||||
texts = [_sanitize_text(text) for text in texts]
|
||||
# 调用 API 进行批量编码
|
||||
url = f"{self.api_base_url}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": texts
|
||||
}
|
||||
|
||||
# 批量请求使用更长的超时时间
|
||||
batch_timeout = self.timeout * 2 # 批量请求超时时间加倍
|
||||
logger.info(f"发送批量 embedding 请求: 文本数={len(texts)}, 超时={batch_timeout}秒")
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=batch_timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embeddings = [item['embedding'] for item in result['data']]
|
||||
|
||||
return embeddings
|
||||
except Exception as e:
|
||||
logger.error(f"批量文本编码失败: {e}")
|
||||
raise
|
||||
|
||||
def get_model_info(self) -> dict:
|
||||
"""获取模型信息"""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"embedding_dim": self.embedding_dim,
|
||||
"api_base_url": self.api_base_url,
|
||||
"initialized": self._initialized,
|
||||
"service_type": "api"
|
||||
}
|
||||
|
||||
# 全局实例
|
||||
_global_embedding_service = None
|
||||
|
||||
def get_embedding_service() -> ApiEmbeddingService:
|
||||
"""获取全局embedding服务实例"""
|
||||
global _global_embedding_service
|
||||
if _global_embedding_service is None:
|
||||
_global_embedding_service = ApiEmbeddingService()
|
||||
return _global_embedding_service
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": texts
|
||||
}
|
||||
|
||||
# 批量请求使用更长的超时时间
|
||||
batch_timeout = self.timeout * 2 # 批量请求超时时间加倍
|
||||
logger.info(f"发送批量 embedding 请求: 文本数={len(texts)}, 超时={batch_timeout}秒")
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=batch_timeout)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embeddings = [item['embedding'] for item in result['data']]
|
||||
|
||||
return embeddings
|
||||
except Exception as e:
|
||||
logger.error(f"批量文本编码失败: {e}")
|
||||
raise
|
||||
|
||||
def get_model_info(self) -> dict:
|
||||
"""获取模型信息"""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"embedding_dim": self.embedding_dim,
|
||||
"api_base_url": self.api_base_url,
|
||||
"initialized": self._initialized,
|
||||
"service_type": "api"
|
||||
}
|
||||
|
||||
# 全局实例
|
||||
_global_embedding_service = None
|
||||
|
||||
def get_embedding_service() -> ApiEmbeddingService:
|
||||
"""获取全局embedding服务实例"""
|
||||
global _global_embedding_service
|
||||
if _global_embedding_service is None:
|
||||
_global_embedding_service = ApiEmbeddingService()
|
||||
return _global_embedding_service
|
||||
|
||||
186
core/fay_core.py
186
core/fay_core.py
@@ -76,7 +76,9 @@ class FeiFei:
|
||||
from ai_module import nlp_cemotion
|
||||
|
||||
|
||||
from core import stream_manager
|
||||
from core import stream_manager
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -331,7 +333,9 @@ class FeiFei:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 执行emoji过滤
|
||||
|
||||
|
||||
filtered_text = emoji_pattern.sub('', text)
|
||||
@@ -748,6 +752,177 @@ class FeiFei:
|
||||
|
||||
|
||||
full_path = os.path.join(path, filename)
|
||||
|
||||
|
||||
with open(full_path, 'w', encoding='utf-8') as file:
|
||||
|
||||
|
||||
file.write(content)
|
||||
|
||||
|
||||
file.flush()
|
||||
|
||||
|
||||
os.fsync(file.fileno())
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#触发交互
|
||||
|
||||
|
||||
def on_interact(self, interact: Interact):
|
||||
|
||||
|
||||
#创建用户
|
||||
|
||||
|
||||
username = interact.data.get("user", "User")
|
||||
|
||||
|
||||
if member_db.new_instance().is_username_exist(username) == "notexists":
|
||||
|
||||
|
||||
member_db.new_instance().add_user(username)
|
||||
|
||||
|
||||
no_reply = interact.data.get("no_reply", False)
|
||||
|
||||
if isinstance(no_reply, str):
|
||||
|
||||
no_reply = no_reply.strip().lower() in ("1", "true", "yes", "y", "on")
|
||||
|
||||
else:
|
||||
|
||||
no_reply = bool(no_reply)
|
||||
|
||||
|
||||
|
||||
if not no_reply:
|
||||
|
||||
try:
|
||||
|
||||
|
||||
from utils.stream_state_manager import get_state_manager
|
||||
|
||||
|
||||
import uuid
|
||||
|
||||
|
||||
if get_state_manager().is_session_active(username):
|
||||
|
||||
|
||||
stream_manager.new_instance().clear_Stream_with_audio(username)
|
||||
|
||||
|
||||
conv_id = "conv_" + str(uuid.uuid4())
|
||||
|
||||
|
||||
stream_manager.new_instance().set_current_conversation(username, conv_id)
|
||||
|
||||
|
||||
# 将当前会话ID附加到交互数据
|
||||
|
||||
|
||||
interact.data["conversation_id"] = conv_id
|
||||
|
||||
|
||||
# 允许新的生成
|
||||
|
||||
|
||||
stream_manager.new_instance().set_stop_generation(username, stop=False)
|
||||
|
||||
|
||||
except Exception:
|
||||
|
||||
|
||||
util.log(3, "开启新会话失败")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if interact.interact_type == 1:
|
||||
|
||||
|
||||
MyThread(target=self.__process_interact, args=[interact]).start()
|
||||
|
||||
|
||||
else:
|
||||
|
||||
|
||||
return self.__process_interact(interact)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#获取不同情绪声音
|
||||
|
||||
|
||||
def __get_mood_voice(self):
|
||||
|
||||
|
||||
voice = tts_voice.get_voice_of(config_util.config["attribute"]["voice"])
|
||||
|
||||
|
||||
if voice is None:
|
||||
|
||||
|
||||
voice = EnumVoice.XIAO_XIAO
|
||||
|
||||
|
||||
styleList = voice.value["styleList"]
|
||||
|
||||
|
||||
sayType = styleList["calm"]
|
||||
|
||||
|
||||
return sayType
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 合成声音
|
||||
|
||||
|
||||
def say(self, interact, text, type = ""):
|
||||
|
||||
|
||||
try:
|
||||
|
||||
|
||||
uid = member_db.new_instance().find_user(interact.data.get("user"))
|
||||
|
||||
|
||||
is_end = interact.data.get("isend", False)
|
||||
|
||||
|
||||
is_first = interact.data.get("isfirst", False)
|
||||
|
||||
|
||||
username = interact.data.get("user", "User")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 提前进行会话有效性与中断检查,避免产生多余面板/数字人输出
|
||||
|
||||
|
||||
try:
|
||||
|
||||
|
||||
user_for_stop = interact.data.get("user", "User")
|
||||
|
||||
|
||||
conv_id_for_stop = interact.data.get("conversation_id")
|
||||
|
||||
|
||||
if not is_end and stream_manager.new_instance().should_stop_generation(user_for_stop, conversation_id=conv_id_for_stop):
|
||||
|
||||
|
||||
return None
|
||||
|
||||
@@ -762,12 +937,13 @@ class FeiFei:
|
||||
|
||||
|
||||
#无效流式文本提前结束
|
||||
|
||||
|
||||
|
||||
if not is_first and not is_end and (text is None or text.strip() == ""):
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -776,7 +952,7 @@ class FeiFei:
|
||||
|
||||
|
||||
is_prestart_content = self.__has_prestart(text)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,270 +1,270 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import threading
|
||||
import time
|
||||
from utils import stream_sentence
|
||||
from scheduler.thread_manager import MyThread
|
||||
import fay_booter
|
||||
from core import member_db
|
||||
from core.interact import Interact
|
||||
|
||||
# 全局变量,用于存储StreamManager的单例实例
|
||||
__streams = None
|
||||
# 线程锁,用于保护全局变量的访问
|
||||
__streams_lock = threading.Lock()
|
||||
|
||||
def new_instance(max_sentences=1024):
|
||||
"""
|
||||
创建并返回StreamManager的单例实例
|
||||
:param max_sentences: 最大句子缓存数量
|
||||
:return: StreamManager实例
|
||||
"""
|
||||
global __streams
|
||||
with __streams_lock:
|
||||
if __streams is None:
|
||||
__streams = StreamManager(max_sentences)
|
||||
return __streams
|
||||
|
||||
class StreamManager:
|
||||
"""
|
||||
流管理器类,用于管理和处理文本流数据
|
||||
"""
|
||||
def __init__(self, max_sentences=3):
|
||||
"""
|
||||
初始化StreamManager
|
||||
:param max_sentences: 每个流的最大句子缓存数量
|
||||
"""
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
# 使用两个独立的锁,避免死锁
|
||||
self.stream_lock = threading.RLock() # 流读写操作锁(可重入锁,允许同一线程多次获取)
|
||||
self.control_lock = threading.Lock() # 控制标志锁(用于停止生成标志)
|
||||
self.streams = {} # 存储用户ID到句子缓存的映射
|
||||
self.nlp_streams = {} # 存储用户ID到句子缓存的映射
|
||||
self.max_sentences = max_sentences # 最大句子缓存数量
|
||||
self.listener_threads = {} # 存储用户ID到监听线程的映射
|
||||
self.running = True # 控制监听线程的运行状态
|
||||
self._initialized = True # 标记是否已初始化
|
||||
self.msgid = "" # 消息ID
|
||||
self.stop_generation_flags = {} # 存储用户的停止生成标志
|
||||
self.conversation_ids = {} # 存储每个用户的会话ID(conv_前缀)
|
||||
|
||||
|
||||
def set_current_conversation(self, username, conversation_id, session_type=None):
|
||||
"""设置当前会话ID(conv_*)并对齐状态管理器的会话。
|
||||
session_type 可选;未提供则沿用已存在状态的类型或默认 'stream'。
|
||||
"""
|
||||
with self.control_lock:
|
||||
self.conversation_ids[username] = conversation_id
|
||||
|
||||
# 对齐 StreamStateManager 的会话,以防用户名级状态跨会话串线
|
||||
try:
|
||||
from utils.stream_state_manager import get_state_manager # 延迟导入避免循环依赖
|
||||
smgr = get_state_manager()
|
||||
info = smgr.get_session_info(username)
|
||||
if (not info) or (info.get('conversation_id') != conversation_id):
|
||||
smgr.start_new_session(
|
||||
username,
|
||||
session_type if session_type else (info.get('session_type') if info else 'stream'),
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
except Exception:
|
||||
# 状态对齐失败不阻断主流程
|
||||
pass
|
||||
|
||||
def get_conversation_id(self, username):
|
||||
"""获取当前会话ID(可能为空字符串)"""
|
||||
with self.control_lock:
|
||||
return self.conversation_ids.get(username, "")
|
||||
|
||||
def _get_Stream_internal(self, username):
|
||||
"""
|
||||
内部方法:获取指定用户ID的文本流(不加锁,调用者必须已持有stream_lock)
|
||||
:param username: 用户名
|
||||
:return: 对应的句子缓存对象
|
||||
"""
|
||||
if username not in self.streams or username not in self.nlp_streams:
|
||||
# 创建新的流缓存
|
||||
self.streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
self.nlp_streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
|
||||
# 启动监听线程(如果还没有)
|
||||
if username not in self.listener_threads:
|
||||
stream = self.streams[username]
|
||||
nlp_stream = self.nlp_streams[username]
|
||||
thread = MyThread(target=self.listen, args=(username, stream, nlp_stream), daemon=True)
|
||||
self.listener_threads[username] = thread
|
||||
thread.start()
|
||||
|
||||
return self.streams[username], self.nlp_streams[username]
|
||||
|
||||
def get_Stream(self, username):
|
||||
"""
|
||||
获取指定用户ID的文本流,如果不存在则创建新的(线程安全)
|
||||
:param username: 用户名
|
||||
:return: 对应的句子缓存对象
|
||||
"""
|
||||
# 使用stream_lock保护流的读写操作
|
||||
with self.stream_lock:
|
||||
return self._get_Stream_internal(username)
|
||||
|
||||
def write_sentence(self, username, sentence, conversation_id=None, session_version=None):
|
||||
"""
|
||||
写入句子到指定用户的文本流(线程安全)
|
||||
:param username: 用户名
|
||||
:param sentence: 要写入的句子
|
||||
:param conversation_id: 句子产生时的会话ID(可选,优先于版本判断)
|
||||
:param session_version: 句子产生时的会话版本(可选,兼容旧路径)
|
||||
:return: 写入是否成功
|
||||
"""
|
||||
# 检查句子长度,防止过大的句子导致内存问题
|
||||
if len(sentence) > 10240: # 10KB限制
|
||||
sentence = sentence[:10240]
|
||||
|
||||
# 若当前处于停止状态且这不是新会话的首句,则丢弃写入,避免残余输出
|
||||
with self.control_lock:
|
||||
stop_flag = self.stop_generation_flags.get(username, False)
|
||||
current_cid = self.conversation_ids.get(username, "")
|
||||
if stop_flag and ('_<isfirst>' not in sentence):
|
||||
return False
|
||||
|
||||
# 优先使用会话ID进行校验
|
||||
if conversation_id is not None and conversation_id != current_cid:
|
||||
return False
|
||||
# 兼容旧逻辑:按版本校验
|
||||
|
||||
|
||||
# 检查是否包含_<isfirst>标记(可能在句子中间)
|
||||
if '_<isfirst>' in sentence:
|
||||
# 收到新处理的第一个句子,重置停止标志,允许后续处理
|
||||
with self.control_lock:
|
||||
self.stop_generation_flags[username] = False
|
||||
|
||||
# 使用stream_lock保护写入操作
|
||||
with self.stream_lock:
|
||||
try:
|
||||
# 使用内部方法避免重复加锁
|
||||
Stream, nlp_Stream = self._get_Stream_internal(username)
|
||||
# 将会话ID以隐藏标签形式附在主流句子尾部,便于入口解析
|
||||
tag_cid = conversation_id if conversation_id is not None else current_cid
|
||||
tagged_sentence = f"{sentence}__<cid={tag_cid}>__" if tag_cid else sentence
|
||||
success = Stream.write(tagged_sentence)
|
||||
# 让 NLP 流也携带隐藏的会话ID,便于前端按会话过滤
|
||||
nlp_success = nlp_Stream.write(tagged_sentence)
|
||||
return success and nlp_success
|
||||
except Exception as e:
|
||||
print(f"写入句子时出错: {e}")
|
||||
return False
|
||||
|
||||
def _clear_Stream_internal(self, username):
|
||||
"""
|
||||
内部清除文本流方法,不获取锁(调用者必须已持有锁)
|
||||
:param username: 用户名
|
||||
"""
|
||||
if username in self.streams:
|
||||
self.streams[username].clear()
|
||||
if username in self.nlp_streams:
|
||||
self.nlp_streams[username].clear()
|
||||
|
||||
# 清除后写入一条结束标记,分别通知主流与NLP流结束
|
||||
try:
|
||||
# 确保流存在(监听线程也会在首次创建时启动)
|
||||
stream, nlp_stream = self._get_Stream_internal(username)
|
||||
cid = self.conversation_ids.get(username, "")
|
||||
end_marker = "_<isend>"
|
||||
# 主流带会话ID隐藏标签,供下游按会话拦截
|
||||
tagged = f"{end_marker}__<cid={cid}>__" if cid else end_marker
|
||||
stream.write(tagged)
|
||||
# NLP 流也写入带会话ID的结束标记,前端会按会话过滤
|
||||
nlp_stream.write(tagged)
|
||||
except Exception:
|
||||
# 忽略写入哨兵失败
|
||||
pass
|
||||
|
||||
def set_stop_generation(self, username, stop=True):
|
||||
"""
|
||||
设置指定用户的停止生成标志
|
||||
:param username: 用户名
|
||||
:param stop: 是否停止,默认True
|
||||
"""
|
||||
with self.control_lock:
|
||||
self.stop_generation_flags[username] = stop
|
||||
|
||||
def should_stop_generation(self, username, conversation_id=None, session_version=None):
|
||||
"""
|
||||
检查指定用户是否应该停止生成
|
||||
:param username: 用户名
|
||||
:return: 是否应该停止
|
||||
"""
|
||||
with self.control_lock:
|
||||
flag = self.stop_generation_flags.get(username, False)
|
||||
if flag:
|
||||
return True
|
||||
# 优先按会话ID判断
|
||||
current_cid = self.conversation_ids.get(username, "")
|
||||
if conversation_id is not None and conversation_id != current_cid:
|
||||
return True
|
||||
# 兼容旧逻辑:按版本判断
|
||||
return False
|
||||
|
||||
# 内部方法已移除,直接使用带锁的公共方法
|
||||
|
||||
def _clear_user_specific_audio(self, username, sound_queue):
|
||||
"""
|
||||
清理特定用户的音频队列项,保留其他用户的音频
|
||||
:param username: 要清理的用户名
|
||||
:param sound_queue: 音频队列
|
||||
"""
|
||||
import queue
|
||||
from utils import util
|
||||
temp_items = []
|
||||
|
||||
# 使用非阻塞方式提取所有项,避免死锁
|
||||
try:
|
||||
while True:
|
||||
item = sound_queue.get_nowait() # 非阻塞获取
|
||||
file_url, audio_length, interact = item
|
||||
item_user = interact.data.get('user', '')
|
||||
if item_user != username:
|
||||
temp_items.append(item) # 保留非目标用户的项
|
||||
# 目标用户的项直接丢弃(不添加到 temp_items)
|
||||
except queue.Empty:
|
||||
# 队列空了,正常退出循环
|
||||
pass
|
||||
|
||||
# 将保留的项重新放入队列(使用非阻塞方式)
|
||||
for item in temp_items:
|
||||
try:
|
||||
sound_queue.put_nowait(item) # 非阻塞放入
|
||||
except queue.Full:
|
||||
# 队列满的情况很少见,如果发生则记录日志
|
||||
util.printInfo(1, username, "音频队列已满,跳过部分音频项")
|
||||
break
|
||||
|
||||
|
||||
def _clear_audio_queue(self, username):
|
||||
"""
|
||||
清空指定用户的音频队列
|
||||
:param username: 用户名
|
||||
注意:此方法假设调用者已持有必要的锁
|
||||
"""
|
||||
fay_core = fay_booter.feiFei
|
||||
# 只清理特定用户的音频项,保留其他用户的音频
|
||||
self._clear_user_specific_audio(username, fay_core.sound_query)
|
||||
|
||||
def clear_Stream_with_audio(self, username):
|
||||
"""
|
||||
清除指定用户ID的文本流数据和音频队列(完全清除)
|
||||
注意:分步操作,避免锁嵌套
|
||||
:param username: 用户名
|
||||
"""
|
||||
# 第一步:切换会话版本,令现有读/写循环尽快退出
|
||||
# 不在清理时递增会话版本,由新交互开始时统一递增
|
||||
|
||||
# 第二步:设置停止标志(独立操作)
|
||||
with self.control_lock:
|
||||
self.stop_generation_flags[username] = True
|
||||
|
||||
# 第三步:清除音频队列(Queue线程安全,不需要锁)
|
||||
# -*- coding: utf-8 -*-
|
||||
import threading
|
||||
import time
|
||||
from utils import stream_sentence
|
||||
from scheduler.thread_manager import MyThread
|
||||
import fay_booter
|
||||
from core import member_db
|
||||
from core.interact import Interact
|
||||
|
||||
# 全局变量,用于存储StreamManager的单例实例
|
||||
__streams = None
|
||||
# 线程锁,用于保护全局变量的访问
|
||||
__streams_lock = threading.Lock()
|
||||
|
||||
def new_instance(max_sentences=1024):
|
||||
"""
|
||||
创建并返回StreamManager的单例实例
|
||||
:param max_sentences: 最大句子缓存数量
|
||||
:return: StreamManager实例
|
||||
"""
|
||||
global __streams
|
||||
with __streams_lock:
|
||||
if __streams is None:
|
||||
__streams = StreamManager(max_sentences)
|
||||
return __streams
|
||||
|
||||
class StreamManager:
|
||||
"""
|
||||
流管理器类,用于管理和处理文本流数据
|
||||
"""
|
||||
def __init__(self, max_sentences=3):
|
||||
"""
|
||||
初始化StreamManager
|
||||
:param max_sentences: 每个流的最大句子缓存数量
|
||||
"""
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
return
|
||||
# 使用两个独立的锁,避免死锁
|
||||
self.stream_lock = threading.RLock() # 流读写操作锁(可重入锁,允许同一线程多次获取)
|
||||
self.control_lock = threading.Lock() # 控制标志锁(用于停止生成标志)
|
||||
self.streams = {} # 存储用户ID到句子缓存的映射
|
||||
self.nlp_streams = {} # 存储用户ID到句子缓存的映射
|
||||
self.max_sentences = max_sentences # 最大句子缓存数量
|
||||
self.listener_threads = {} # 存储用户ID到监听线程的映射
|
||||
self.running = True # 控制监听线程的运行状态
|
||||
self._initialized = True # 标记是否已初始化
|
||||
self.msgid = "" # 消息ID
|
||||
self.stop_generation_flags = {} # 存储用户的停止生成标志
|
||||
self.conversation_ids = {} # 存储每个用户的会话ID(conv_前缀)
|
||||
|
||||
|
||||
def set_current_conversation(self, username, conversation_id, session_type=None):
|
||||
"""设置当前会话ID(conv_*)并对齐状态管理器的会话。
|
||||
session_type 可选;未提供则沿用已存在状态的类型或默认 'stream'。
|
||||
"""
|
||||
with self.control_lock:
|
||||
self.conversation_ids[username] = conversation_id
|
||||
|
||||
# 对齐 StreamStateManager 的会话,以防用户名级状态跨会话串线
|
||||
try:
|
||||
from utils.stream_state_manager import get_state_manager # 延迟导入避免循环依赖
|
||||
smgr = get_state_manager()
|
||||
info = smgr.get_session_info(username)
|
||||
if (not info) or (info.get('conversation_id') != conversation_id):
|
||||
smgr.start_new_session(
|
||||
username,
|
||||
session_type if session_type else (info.get('session_type') if info else 'stream'),
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
except Exception:
|
||||
# 状态对齐失败不阻断主流程
|
||||
pass
|
||||
|
||||
def get_conversation_id(self, username):
|
||||
"""获取当前会话ID(可能为空字符串)"""
|
||||
with self.control_lock:
|
||||
return self.conversation_ids.get(username, "")
|
||||
|
||||
def _get_Stream_internal(self, username):
|
||||
"""
|
||||
内部方法:获取指定用户ID的文本流(不加锁,调用者必须已持有stream_lock)
|
||||
:param username: 用户名
|
||||
:return: 对应的句子缓存对象
|
||||
"""
|
||||
if username not in self.streams or username not in self.nlp_streams:
|
||||
# 创建新的流缓存
|
||||
self.streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
self.nlp_streams[username] = stream_sentence.SentenceCache(self.max_sentences)
|
||||
|
||||
# 启动监听线程(如果还没有)
|
||||
if username not in self.listener_threads:
|
||||
stream = self.streams[username]
|
||||
nlp_stream = self.nlp_streams[username]
|
||||
thread = MyThread(target=self.listen, args=(username, stream, nlp_stream), daemon=True)
|
||||
self.listener_threads[username] = thread
|
||||
thread.start()
|
||||
|
||||
return self.streams[username], self.nlp_streams[username]
|
||||
|
||||
def get_Stream(self, username):
|
||||
"""
|
||||
获取指定用户ID的文本流,如果不存在则创建新的(线程安全)
|
||||
:param username: 用户名
|
||||
:return: 对应的句子缓存对象
|
||||
"""
|
||||
# 使用stream_lock保护流的读写操作
|
||||
with self.stream_lock:
|
||||
return self._get_Stream_internal(username)
|
||||
|
||||
def write_sentence(self, username, sentence, conversation_id=None, session_version=None):
|
||||
"""
|
||||
写入句子到指定用户的文本流(线程安全)
|
||||
:param username: 用户名
|
||||
:param sentence: 要写入的句子
|
||||
:param conversation_id: 句子产生时的会话ID(可选,优先于版本判断)
|
||||
:param session_version: 句子产生时的会话版本(可选,兼容旧路径)
|
||||
:return: 写入是否成功
|
||||
"""
|
||||
# 检查句子长度,防止过大的句子导致内存问题
|
||||
if len(sentence) > 10240: # 10KB限制
|
||||
sentence = sentence[:10240]
|
||||
|
||||
# 若当前处于停止状态且这不是新会话的首句,则丢弃写入,避免残余输出
|
||||
with self.control_lock:
|
||||
stop_flag = self.stop_generation_flags.get(username, False)
|
||||
current_cid = self.conversation_ids.get(username, "")
|
||||
if stop_flag and ('_<isfirst>' not in sentence):
|
||||
return False
|
||||
|
||||
# 优先使用会话ID进行校验
|
||||
if conversation_id is not None and conversation_id != current_cid:
|
||||
return False
|
||||
# 兼容旧逻辑:按版本校验
|
||||
|
||||
|
||||
# 检查是否包含_<isfirst>标记(可能在句子中间)
|
||||
if '_<isfirst>' in sentence:
|
||||
# 收到新处理的第一个句子,重置停止标志,允许后续处理
|
||||
with self.control_lock:
|
||||
self.stop_generation_flags[username] = False
|
||||
|
||||
# 使用stream_lock保护写入操作
|
||||
with self.stream_lock:
|
||||
try:
|
||||
# 使用内部方法避免重复加锁
|
||||
Stream, nlp_Stream = self._get_Stream_internal(username)
|
||||
# 将会话ID以隐藏标签形式附在主流句子尾部,便于入口解析
|
||||
tag_cid = conversation_id if conversation_id is not None else current_cid
|
||||
tagged_sentence = f"{sentence}__<cid={tag_cid}>__" if tag_cid else sentence
|
||||
success = Stream.write(tagged_sentence)
|
||||
# 让 NLP 流也携带隐藏的会话ID,便于前端按会话过滤
|
||||
nlp_success = nlp_Stream.write(tagged_sentence)
|
||||
return success and nlp_success
|
||||
except Exception as e:
|
||||
print(f"写入句子时出错: {e}")
|
||||
return False
|
||||
|
||||
def _clear_Stream_internal(self, username):
|
||||
"""
|
||||
内部清除文本流方法,不获取锁(调用者必须已持有锁)
|
||||
:param username: 用户名
|
||||
"""
|
||||
if username in self.streams:
|
||||
self.streams[username].clear()
|
||||
if username in self.nlp_streams:
|
||||
self.nlp_streams[username].clear()
|
||||
|
||||
# 清除后写入一条结束标记,分别通知主流与NLP流结束
|
||||
try:
|
||||
# 确保流存在(监听线程也会在首次创建时启动)
|
||||
stream, nlp_stream = self._get_Stream_internal(username)
|
||||
cid = self.conversation_ids.get(username, "")
|
||||
end_marker = "_<isend>"
|
||||
# 主流带会话ID隐藏标签,供下游按会话拦截
|
||||
tagged = f"{end_marker}__<cid={cid}>__" if cid else end_marker
|
||||
stream.write(tagged)
|
||||
# NLP 流也写入带会话ID的结束标记,前端会按会话过滤
|
||||
nlp_stream.write(tagged)
|
||||
except Exception:
|
||||
# 忽略写入哨兵失败
|
||||
pass
|
||||
|
||||
def set_stop_generation(self, username, stop=True):
|
||||
"""
|
||||
设置指定用户的停止生成标志
|
||||
:param username: 用户名
|
||||
:param stop: 是否停止,默认True
|
||||
"""
|
||||
with self.control_lock:
|
||||
self.stop_generation_flags[username] = stop
|
||||
|
||||
def should_stop_generation(self, username, conversation_id=None, session_version=None):
|
||||
"""
|
||||
检查指定用户是否应该停止生成
|
||||
:param username: 用户名
|
||||
:return: 是否应该停止
|
||||
"""
|
||||
with self.control_lock:
|
||||
flag = self.stop_generation_flags.get(username, False)
|
||||
if flag:
|
||||
return True
|
||||
# 优先按会话ID判断
|
||||
current_cid = self.conversation_ids.get(username, "")
|
||||
if conversation_id is not None and conversation_id != current_cid:
|
||||
return True
|
||||
# 兼容旧逻辑:按版本判断
|
||||
return False
|
||||
|
||||
# 内部方法已移除,直接使用带锁的公共方法
|
||||
|
||||
def _clear_user_specific_audio(self, username, sound_queue):
|
||||
"""
|
||||
清理特定用户的音频队列项,保留其他用户的音频
|
||||
:param username: 要清理的用户名
|
||||
:param sound_queue: 音频队列
|
||||
"""
|
||||
import queue
|
||||
from utils import util
|
||||
temp_items = []
|
||||
|
||||
# 使用非阻塞方式提取所有项,避免死锁
|
||||
try:
|
||||
while True:
|
||||
item = sound_queue.get_nowait() # 非阻塞获取
|
||||
file_url, audio_length, interact = item
|
||||
item_user = interact.data.get('user', '')
|
||||
if item_user != username:
|
||||
temp_items.append(item) # 保留非目标用户的项
|
||||
# 目标用户的项直接丢弃(不添加到 temp_items)
|
||||
except queue.Empty:
|
||||
# 队列空了,正常退出循环
|
||||
pass
|
||||
|
||||
# 将保留的项重新放入队列(使用非阻塞方式)
|
||||
for item in temp_items:
|
||||
try:
|
||||
sound_queue.put_nowait(item) # 非阻塞放入
|
||||
except queue.Full:
|
||||
# 队列满的情况很少见,如果发生则记录日志
|
||||
util.printInfo(1, username, "音频队列已满,跳过部分音频项")
|
||||
break
|
||||
|
||||
|
||||
def _clear_audio_queue(self, username):
|
||||
"""
|
||||
清空指定用户的音频队列
|
||||
:param username: 用户名
|
||||
注意:此方法假设调用者已持有必要的锁
|
||||
"""
|
||||
fay_core = fay_booter.feiFei
|
||||
# 只清理特定用户的音频项,保留其他用户的音频
|
||||
self._clear_user_specific_audio(username, fay_core.sound_query)
|
||||
|
||||
def clear_Stream_with_audio(self, username):
|
||||
"""
|
||||
清除指定用户ID的文本流数据和音频队列(完全清除)
|
||||
注意:分步操作,避免锁嵌套
|
||||
:param username: 用户名
|
||||
"""
|
||||
# 第一步:切换会话版本,令现有读/写循环尽快退出
|
||||
# 不在清理时递增会话版本,由新交互开始时统一递增
|
||||
|
||||
# 第二步:设置停止标志(独立操作)
|
||||
with self.control_lock:
|
||||
self.stop_generation_flags[username] = True
|
||||
|
||||
# 第三步:清除音频队列(Queue线程安全,不需要锁)
|
||||
self._clear_audio_queue(username)
|
||||
|
||||
# reset think state for username on force stop
|
||||
@@ -276,71 +276,73 @@ class StreamManager:
|
||||
fei.think_mode_users[uid_tmp] = False
|
||||
if uid_tmp in getattr(fei, 'think_time_users', {}):
|
||||
del fei.think_time_users[uid_tmp]
|
||||
if uid_tmp in getattr(fei, 'think_display_state', {}):
|
||||
del fei.think_display_state[uid_tmp]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 第四步:清除文本流(独立操作)
|
||||
with self.stream_lock:
|
||||
self._clear_Stream_internal(username)
|
||||
|
||||
|
||||
|
||||
def listen(self, username, stream, nlp_stream):
|
||||
while self.running:
|
||||
sentence = stream.read()
|
||||
if sentence:
|
||||
self.execute(username, sentence)
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
||||
def execute(self, username, sentence):
|
||||
"""
|
||||
执行句子处理逻辑
|
||||
:param username: 用户名
|
||||
:param sentence: 要处理的句子
|
||||
"""
|
||||
# 从句子尾部解析隐藏的会话ID标签
|
||||
producer_cid = None
|
||||
try:
|
||||
import re as _re
|
||||
m = _re.search(r"__<cid=([^>]+)>__", sentence)
|
||||
if m:
|
||||
producer_cid = m.group(1)
|
||||
sentence = sentence.replace(m.group(0), "")
|
||||
except Exception:
|
||||
producer_cid = None
|
||||
|
||||
# 检查停止标志(使用control_lock)
|
||||
with self.control_lock:
|
||||
should_stop = self.stop_generation_flags.get(username, False)
|
||||
|
||||
if should_stop:
|
||||
return
|
||||
|
||||
# 进一步进行基于会话ID/版本的快速拦截(避免进入下游 say)
|
||||
try:
|
||||
current_cid = getattr(self, 'conversation_ids', {}).get(username, "")
|
||||
check_cid = producer_cid if producer_cid is not None else current_cid
|
||||
if self.should_stop_generation(username, conversation_id=check_cid):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 处理句子标记(无锁,避免长时间持有锁)
|
||||
is_first = "_<isfirst>" in sentence
|
||||
is_end = "_<isend>" in sentence
|
||||
is_qa = "_<isqa>" in sentence
|
||||
sentence = sentence.replace("_<isfirst>", "").replace("_<isend>", "").replace("_<isqa>", "")
|
||||
|
||||
# 执行实际处理(无锁,避免死锁)
|
||||
if sentence or is_first or is_end or is_qa:
|
||||
fay_core = fay_booter.feiFei
|
||||
# 附带当前会话ID,方便下游按会话控制输出
|
||||
effective_cid = producer_cid if producer_cid is not None else getattr(self, 'conversation_ids', {}).get(username, "")
|
||||
interact = Interact("stream", 1, {"user": username, "msg": sentence, "isfirst": is_first, "isend": is_end, "conversation_id": effective_cid})
|
||||
fay_core.say(interact, sentence, type="qa" if is_qa else "") # 调用核心处理模块进行响应
|
||||
time.sleep(0.01) # 短暂休眠以控制处理频率
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 第四步:清除文本流(独立操作)
|
||||
with self.stream_lock:
|
||||
self._clear_Stream_internal(username)
|
||||
|
||||
|
||||
|
||||
def listen(self, username, stream, nlp_stream):
|
||||
while self.running:
|
||||
sentence = stream.read()
|
||||
if sentence:
|
||||
self.execute(username, sentence)
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
||||
def execute(self, username, sentence):
|
||||
"""
|
||||
执行句子处理逻辑
|
||||
:param username: 用户名
|
||||
:param sentence: 要处理的句子
|
||||
"""
|
||||
# 从句子尾部解析隐藏的会话ID标签
|
||||
producer_cid = None
|
||||
try:
|
||||
import re as _re
|
||||
m = _re.search(r"__<cid=([^>]+)>__", sentence)
|
||||
if m:
|
||||
producer_cid = m.group(1)
|
||||
sentence = sentence.replace(m.group(0), "")
|
||||
except Exception:
|
||||
producer_cid = None
|
||||
|
||||
# 检查停止标志(使用control_lock)
|
||||
with self.control_lock:
|
||||
should_stop = self.stop_generation_flags.get(username, False)
|
||||
|
||||
if should_stop:
|
||||
return
|
||||
|
||||
# 进一步进行基于会话ID/版本的快速拦截(避免进入下游 say)
|
||||
try:
|
||||
current_cid = getattr(self, 'conversation_ids', {}).get(username, "")
|
||||
check_cid = producer_cid if producer_cid is not None else current_cid
|
||||
if self.should_stop_generation(username, conversation_id=check_cid):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 处理句子标记(无锁,避免长时间持有锁)
|
||||
is_first = "_<isfirst>" in sentence
|
||||
is_end = "_<isend>" in sentence
|
||||
is_qa = "_<isqa>" in sentence
|
||||
sentence = sentence.replace("_<isfirst>", "").replace("_<isend>", "").replace("_<isqa>", "")
|
||||
|
||||
# 执行实际处理(无锁,避免死锁)
|
||||
if sentence or is_first or is_end or is_qa:
|
||||
fay_core = fay_booter.feiFei
|
||||
# 附带当前会话ID,方便下游按会话控制输出
|
||||
effective_cid = producer_cid if producer_cid is not None else getattr(self, 'conversation_ids', {}).get(username, "")
|
||||
interact = Interact("stream", 1, {"user": username, "msg": sentence, "isfirst": is_first, "isend": is_end, "conversation_id": effective_cid})
|
||||
fay_core.say(interact, sentence, type="qa" if is_qa else "") # 调用核心处理模块进行响应
|
||||
time.sleep(0.01) # 短暂休眠以控制处理频率
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -383,24 +383,30 @@ def api_send_v1_chat_completions():
|
||||
|
||||
def generate():
|
||||
try:
|
||||
for line in resp.iter_lines(decode_unicode=True):
|
||||
if line is None:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
if not chunk:
|
||||
continue
|
||||
yield f"{line}\n"
|
||||
yield chunk
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
content_type = resp.headers.get("Content-Type", "text/event-stream")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
generate(),
|
||||
status=resp.status_code,
|
||||
mimetype=resp.headers.get("Content-Type", "text/event-stream"),
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
resp = requests.post(llm_url, headers=headers, json=payload, timeout=60)
|
||||
content_type = resp.headers.get("Content-Type", "application/json")
|
||||
if "charset=" not in content_type.lower():
|
||||
content_type = f"{content_type}; charset=utf-8"
|
||||
return Response(
|
||||
resp.content,
|
||||
status=resp.status_code,
|
||||
content_type=resp.headers.get("Content-Type", "application/json"),
|
||||
content_type=content_type,
|
||||
)
|
||||
except Exception as exc:
|
||||
return jsonify({'error': f'LLM request failed: {exc}'}), 500
|
||||
|
||||
@@ -314,12 +314,14 @@ def _remove_prestart_from_text(text: str, keep_marked: bool = True) -> str:
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _remove_think_from_text(text: str) -> str:
|
||||
"""从文本中移除 think 标签及其内容"""
|
||||
if not text:
|
||||
return text
|
||||
import re
|
||||
return re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE).strip()
|
||||
def _remove_think_from_text(text: str) -> str:
|
||||
"""从文本中移除 think 标签及其内容"""
|
||||
if not text:
|
||||
return text
|
||||
import re
|
||||
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
|
||||
cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _format_conversation_block(conversation: List[Dict], username: str = "User") -> str:
|
||||
@@ -438,22 +440,35 @@ def _run_prestart_tools(user_question: str) -> List[Dict[str, Any]]:
|
||||
return results
|
||||
|
||||
|
||||
def _truncate_history(history: List[ToolResult], limit: int = 6) -> str:
|
||||
if not history:
|
||||
return "(暂无)"
|
||||
lines: List[str] = []
|
||||
for item in history[-limit:]:
|
||||
call = item.get("call", {})
|
||||
name = call.get("name", "未知工具")
|
||||
attempt = item.get("attempt", 0)
|
||||
success = item.get("success", False)
|
||||
status = "成功" if success else "失败"
|
||||
lines.append(f"- {name} 第 {attempt} 次 → {status}")
|
||||
if item.get("output"):
|
||||
lines.append(" 输出:" + _truncate_text(item["output"], 200))
|
||||
if item.get("error"):
|
||||
lines.append(" 错误:" + _truncate_text(item["error"], 200))
|
||||
return "\n".join(lines)
|
||||
def _truncate_history(
|
||||
history: List[ToolResult],
|
||||
limit: Optional[int] = None,
|
||||
output_limit: Optional[int] = None,
|
||||
) -> str:
|
||||
if not history:
|
||||
return "(暂无)"
|
||||
lines: List[str] = []
|
||||
selected = history if limit is None else history[-limit:]
|
||||
for item in selected:
|
||||
call = item.get("call", {})
|
||||
name = call.get("name", "未知工具")
|
||||
attempt = item.get("attempt", 0)
|
||||
success = item.get("success", False)
|
||||
status = "成功" if success else "失败"
|
||||
lines.append(f"- {name} 第 {attempt} 次 → {status}")
|
||||
output = item.get("output")
|
||||
if output is not None:
|
||||
output_text = str(output)
|
||||
if output_limit is not None:
|
||||
output_text = _truncate_text(output_text, output_limit)
|
||||
lines.append(" 输出:" + output_text)
|
||||
error = item.get("error")
|
||||
if error is not None:
|
||||
error_text = str(error)
|
||||
if output_limit is not None:
|
||||
error_text = _truncate_text(error_text, output_limit)
|
||||
lines.append(" 错误:" + error_text)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_schema_parameters(schema: Dict[str, Any]) -> List[str]:
|
||||
@@ -583,9 +598,9 @@ def _build_planner_messages(state: AgentState) -> List[SystemMessage | HumanMess
|
||||
|
||||
# 生成对话文本,使用代码块包裹每条消息
|
||||
convo_text = _format_conversation_block(conversation, username)
|
||||
history_text = _truncate_history(history)
|
||||
tools_text = _format_tools_for_prompt(tool_specs)
|
||||
preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else ""
|
||||
history_text = _truncate_history(history)
|
||||
tools_text = _format_tools_for_prompt(tool_specs)
|
||||
preview_section = f"\n(规划器预览:{planner_preview})" if planner_preview else ""
|
||||
|
||||
# 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内
|
||||
if prestart_context and prestart_context.strip():
|
||||
@@ -667,9 +682,9 @@ def _build_final_messages(state: AgentState) -> List[SystemMessage | HumanMessag
|
||||
display_username = "主人" if username == "User" else username
|
||||
|
||||
# 生成对话文本,使用代码块包裹每条消息
|
||||
conversation_block = _format_conversation_block(conversation, username)
|
||||
history_text = _truncate_history(state.get("tool_results", []))
|
||||
preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else ""
|
||||
conversation_block = _format_conversation_block(conversation, username)
|
||||
history_text = _truncate_history(state.get("tool_results", []))
|
||||
preview_section = f"\n(规划器建议:{planner_preview})" if planner_preview else ""
|
||||
|
||||
# 只有当有预启动工具结果时才显示,工具名+参数在外,结果在代码块内
|
||||
if prestart_context and prestart_context.strip():
|
||||
@@ -2005,15 +2020,22 @@ def question(content, username, observation=None):
|
||||
# 创建规划器流式回调,用于实时输出 finish+message 响应
|
||||
planner_stream_buffer = {"text": "", "first_chunk": True}
|
||||
|
||||
def planner_stream_callback(chunk_text: str) -> None:
|
||||
"""规划器流式回调,将 message 内容实时输出"""
|
||||
nonlocal accumulated_text, full_response_text, is_first_sentence
|
||||
if not chunk_text:
|
||||
return
|
||||
planner_stream_buffer["text"] += chunk_text
|
||||
# 使用 stream_response_chunks 的逻辑进行分句流式输出
|
||||
accumulated_text += chunk_text
|
||||
full_response_text += chunk_text
|
||||
def planner_stream_callback(chunk_text: str) -> None:
|
||||
"""规划器流式回调,将 message 内容实时输出"""
|
||||
nonlocal accumulated_text, full_response_text, is_first_sentence, is_agent_think_start
|
||||
if not chunk_text:
|
||||
return
|
||||
planner_stream_buffer["text"] += chunk_text
|
||||
if planner_stream_buffer["first_chunk"]:
|
||||
planner_stream_buffer["first_chunk"] = False
|
||||
if is_agent_think_start:
|
||||
closing = "</think>"
|
||||
accumulated_text += closing
|
||||
full_response_text += closing
|
||||
is_agent_think_start = False
|
||||
# 使用 stream_response_chunks 的逻辑进行分句流式输出
|
||||
accumulated_text += chunk_text
|
||||
full_response_text += chunk_text
|
||||
# 检查是否有完整句子可以输出
|
||||
if len(accumulated_text) >= 20:
|
||||
while True:
|
||||
@@ -2485,27 +2507,46 @@ def save_agent_memory():
|
||||
agent.scratch = {}
|
||||
|
||||
# 保存记忆前进行完整性检查
|
||||
try:
|
||||
# 检查seq_nodes中的每个节点是否有效
|
||||
valid_nodes = []
|
||||
for node in agent.memory_stream.seq_nodes:
|
||||
if node is None:
|
||||
util.log(1, "发现无效节点(None),跳过")
|
||||
continue
|
||||
|
||||
if not hasattr(node, 'node_id') or not hasattr(node, 'content'):
|
||||
util.log(1, f"发现无效节点(缺少必要属性),跳过")
|
||||
continue
|
||||
|
||||
valid_nodes.append(node)
|
||||
|
||||
# 更新seq_nodes为有效节点列表
|
||||
agent.memory_stream.seq_nodes = valid_nodes
|
||||
|
||||
# 重建id_to_node字典
|
||||
agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')}
|
||||
except Exception as e:
|
||||
util.log(1, f"检查记忆完整性时出错: {str(e)}")
|
||||
try:
|
||||
# 检查seq_nodes中的每个节点是否有效
|
||||
valid_nodes = []
|
||||
for node in agent.memory_stream.seq_nodes:
|
||||
if node is None:
|
||||
util.log(1, "发现无效节点(None),跳过")
|
||||
continue
|
||||
|
||||
if not hasattr(node, 'node_id') or not hasattr(node, 'content'):
|
||||
util.log(1, f"发现无效节点(缺少必要属性),跳过")
|
||||
continue
|
||||
raw_content = node.content if isinstance(node.content, str) else str(node.content)
|
||||
cleaned_content = _remove_think_from_text(raw_content)
|
||||
if cleaned_content != raw_content:
|
||||
old_content = raw_content
|
||||
node.content = cleaned_content
|
||||
if (
|
||||
agent.memory_stream.embeddings is not None
|
||||
and old_content in agent.memory_stream.embeddings
|
||||
and cleaned_content not in agent.memory_stream.embeddings
|
||||
):
|
||||
agent.memory_stream.embeddings[cleaned_content] = agent.memory_stream.embeddings[old_content]
|
||||
else:
|
||||
node.content = raw_content
|
||||
valid_nodes.append(node)
|
||||
|
||||
# 更新seq_nodes为有效节点列表
|
||||
agent.memory_stream.seq_nodes = valid_nodes
|
||||
|
||||
# 重建id_to_node字典
|
||||
agent.memory_stream.id_to_node = {node.node_id: node for node in valid_nodes if hasattr(node, 'node_id')}
|
||||
if agent.memory_stream.embeddings is not None:
|
||||
kept_contents = {node.content for node in valid_nodes if hasattr(node, 'content')}
|
||||
agent.memory_stream.embeddings = {
|
||||
key: value
|
||||
for key, value in agent.memory_stream.embeddings.items()
|
||||
if key in kept_contents
|
||||
}
|
||||
except Exception as e:
|
||||
util.log(1, f"检查记忆完整性时出错: {str(e)}")
|
||||
|
||||
# 保存记忆
|
||||
try:
|
||||
|
||||
65
main.py
65
main.py
@@ -1,22 +1,33 @@
|
||||
#入口文件main
|
||||
import os
|
||||
os.environ['PATH'] += os.pathsep + os.path.join(os.getcwd(), "test", "ovr_lipsync", "ffmpeg", "bin")
|
||||
import sys
|
||||
import time
|
||||
import psutil
|
||||
import re
|
||||
import argparse
|
||||
import signal
|
||||
import atexit
|
||||
import threading
|
||||
from utils import config_util, util
|
||||
from asr import ali_nls
|
||||
from core import wsa_server
|
||||
from gui import flask_server
|
||||
from core import content_db
|
||||
import fay_booter
|
||||
from scheduler.thread_manager import MyThread
|
||||
from core.interact import Interact
|
||||
#入口文件main
|
||||
import os
|
||||
import sys
|
||||
|
||||
os.environ['PATH'] += os.pathsep + os.path.join(os.getcwd(), "test", "ovr_lipsync", "ffmpeg", "bin")
|
||||
|
||||
def _preload_config_center(argv):
|
||||
for i, arg in enumerate(argv):
|
||||
if arg in ("-config_center", "--config_center"):
|
||||
if i + 1 < len(argv):
|
||||
os.environ["FAY_CONFIG_CENTER_ID"] = argv[i + 1]
|
||||
break
|
||||
|
||||
_preload_config_center(sys.argv[1:])
|
||||
|
||||
import time
|
||||
import psutil
|
||||
import re
|
||||
import argparse
|
||||
import signal
|
||||
import atexit
|
||||
import threading
|
||||
from utils import config_util, util
|
||||
from asr import ali_nls
|
||||
from core import wsa_server
|
||||
from gui import flask_server
|
||||
from core import content_db
|
||||
import fay_booter
|
||||
from scheduler.thread_manager import MyThread
|
||||
from core.interact import Interact
|
||||
|
||||
# import sys, io, traceback
|
||||
# class StdoutInterceptor(io.TextIOBase):
|
||||
@@ -247,12 +258,16 @@ if __name__ == '__main__':
|
||||
if config_util.start_mode == 'web':
|
||||
util.log(1, '请通过浏览器访问 http://127.0.0.1:5000/ 管理您的Fay')
|
||||
|
||||
parser = argparse.ArgumentParser(description="start自启动")
|
||||
parser.add_argument('command', nargs='?', default='', help="start")
|
||||
|
||||
parsed_args = parser.parse_args()
|
||||
if parsed_args.command.lower() == 'start':
|
||||
MyThread(target=fay_booter.start).start()
|
||||
parser = argparse.ArgumentParser(description="start自启动")
|
||||
parser.add_argument('command', nargs='?', default='', help="start")
|
||||
parser.add_argument('-config_center', '--config_center', dest='config_center', default=None, help="配置中心项目ID")
|
||||
|
||||
parsed_args = parser.parse_args()
|
||||
if parsed_args.config_center:
|
||||
os.environ["FAY_CONFIG_CENTER_ID"] = parsed_args.config_center
|
||||
config_util.CONFIG_SERVER['PROJECT_ID'] = parsed_args.config_center
|
||||
if parsed_args.command.lower() == 'start':
|
||||
MyThread(target=fay_booter.start).start()
|
||||
|
||||
|
||||
#普通模式下启动窗口
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
def test_gpt_nonstream(prompt):
|
||||
url = 'http://127.0.0.1:8000/v1/chat/completions' # 替换为您的接口地址
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer YOUR_API_KEY', # 如果您的接口需要身份验证
|
||||
}
|
||||
data = {
|
||||
'model': 'moonshotai/Kimi-K2-Instruct-0905',
|
||||
'messages': [
|
||||
{'role': 'system', 'content': prompt}
|
||||
],
|
||||
'stream': False # 禁用流式传输,使用非流式响应
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"请求失败,状态码:{response.status_code}")
|
||||
print(f"响应内容:{response.text}")
|
||||
return
|
||||
|
||||
# 处理非流式响应
|
||||
try:
|
||||
response_data = response.json()
|
||||
|
||||
# 从响应中提取内容
|
||||
choices = response_data.get('choices', [])
|
||||
if choices:
|
||||
message = choices[0].get('message', {})
|
||||
content = message.get('content', '')
|
||||
print(f"完整响应内容: {content}")
|
||||
|
||||
# 打印一些额外的响应信息
|
||||
print(f"\n请求ID: {response_data.get('id', 'N/A')}")
|
||||
print(f"模型: {response_data.get('model', 'N/A')}")
|
||||
|
||||
# 打印使用量信息
|
||||
usage = response_data.get('usage', {})
|
||||
if usage:
|
||||
print(f"Token 使用情况:")
|
||||
print(f" - 提示词 tokens: {usage.get('prompt_tokens', 0)}")
|
||||
print(f" - 补全 tokens: {usage.get('completion_tokens', 0)}")
|
||||
print(f" - 总计 tokens: {usage.get('total_tokens', 0)}")
|
||||
|
||||
return content
|
||||
except json.JSONDecodeError:
|
||||
print(f"无法解析响应数据为JSON: {response.text}")
|
||||
except Exception as e:
|
||||
print(f"处理响应时出错: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
user_input = "哈哈"
|
||||
print("发送请求到 GPT API (非流式模式)...")
|
||||
print("-" * 50)
|
||||
test_gpt_nonstream(user_input)
|
||||
print("-" * 50)
|
||||
print("请求完成")
|
||||
@@ -8,7 +8,7 @@ def test_gpt(prompt, username="张三", observation="", no_reply=False):
|
||||
'Authorization': f'Bearer YOUR_API_KEY', # 如果您的接口需要身份验证
|
||||
}
|
||||
data = {
|
||||
'model': 'fay-streaming',
|
||||
'model': 'fay-streaming', #model为llm时,会直接透传到上游的llm输出,fay不作任何处理、记录
|
||||
'messages': [
|
||||
{'role': username, 'content': prompt}
|
||||
],
|
||||
@@ -47,7 +47,8 @@ def test_gpt(prompt, username="张三", observation="", no_reply=False):
|
||||
if choices:
|
||||
delta = choices[0].get('delta', {})
|
||||
content = delta.get('content', '')
|
||||
print(content, end='', flush=True)
|
||||
if content:
|
||||
print(content, end='', flush=True)
|
||||
except json.JSONDecodeError:
|
||||
print(f"\n无法解析的 JSON 数据:{line}")
|
||||
else:
|
||||
@@ -90,7 +91,7 @@ if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("示例1:张三的对话(带观察数据)")
|
||||
print("=" * 60)
|
||||
test_gpt("你好,今天天气不错啊", username="张三", observation=OBSERVATION_SAMPLES["张三"])
|
||||
test_gpt("你好,今天天气不错啊", username="user", observation=OBSERVATION_SAMPLES["张三"])
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
@@ -1,58 +1,58 @@
|
||||
import os
|
||||
import json
|
||||
import codecs
|
||||
from langsmith.schemas import Feedback
|
||||
import requests
|
||||
from configparser import ConfigParser
|
||||
import functools
|
||||
from threading import Lock
|
||||
import threading
|
||||
from utils import util
|
||||
|
||||
# 线程本地存储,用于支持多个项目配置
|
||||
_thread_local = threading.local()
|
||||
|
||||
# 全局锁,确保线程安全
|
||||
lock = Lock()
|
||||
def synchronized(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with lock:
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
# 默认配置,用于全局访问
|
||||
config: json = None
|
||||
system_config: ConfigParser = None
|
||||
system_chrome_driver = None
|
||||
key_ali_nls_key_id = None
|
||||
key_ali_nls_key_secret = None
|
||||
key_ali_nls_app_key = None
|
||||
key_ms_tts_key = None
|
||||
Key_ms_tts_region = None
|
||||
baidu_emotion_app_id = None
|
||||
baidu_emotion_api_key = None
|
||||
baidu_emotion_secret_key = None
|
||||
key_gpt_api_key = None
|
||||
gpt_model_engine = None
|
||||
proxy_config = None
|
||||
ASR_mode = None
|
||||
local_asr_ip = None
|
||||
local_asr_port = None
|
||||
ltp_mode = None
|
||||
gpt_base_url = None
|
||||
tts_module = None
|
||||
key_ali_tss_key_id = None
|
||||
key_ali_tss_key_secret = None
|
||||
key_ali_tss_app_key = None
|
||||
volcano_tts_appid = None
|
||||
volcano_tts_access_token = None
|
||||
volcano_tts_cluster = None
|
||||
volcano_tts_voice_type = None
|
||||
start_mode = None
|
||||
fay_url = None
|
||||
system_conf_path = None
|
||||
config_json_path = None
|
||||
import os
|
||||
import json
|
||||
import codecs
|
||||
from langsmith.schemas import Feedback
|
||||
import requests
|
||||
from configparser import ConfigParser
|
||||
import functools
|
||||
from threading import Lock
|
||||
import threading
|
||||
from utils import util
|
||||
|
||||
# 线程本地存储,用于支持多个项目配置
|
||||
_thread_local = threading.local()
|
||||
|
||||
# 全局锁,确保线程安全
|
||||
lock = Lock()
|
||||
def synchronized(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with lock:
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
# 默认配置,用于全局访问
|
||||
config: json = None
|
||||
system_config: ConfigParser = None
|
||||
system_chrome_driver = None
|
||||
key_ali_nls_key_id = None
|
||||
key_ali_nls_key_secret = None
|
||||
key_ali_nls_app_key = None
|
||||
key_ms_tts_key = None
|
||||
Key_ms_tts_region = None
|
||||
baidu_emotion_app_id = None
|
||||
baidu_emotion_api_key = None
|
||||
baidu_emotion_secret_key = None
|
||||
key_gpt_api_key = None
|
||||
gpt_model_engine = None
|
||||
proxy_config = None
|
||||
ASR_mode = None
|
||||
local_asr_ip = None
|
||||
local_asr_port = None
|
||||
ltp_mode = None
|
||||
gpt_base_url = None
|
||||
tts_module = None
|
||||
key_ali_tss_key_id = None
|
||||
key_ali_tss_key_secret = None
|
||||
key_ali_tss_app_key = None
|
||||
volcano_tts_appid = None
|
||||
volcano_tts_access_token = None
|
||||
volcano_tts_cluster = None
|
||||
volcano_tts_voice_type = None
|
||||
start_mode = None
|
||||
fay_url = None
|
||||
system_conf_path = None
|
||||
config_json_path = None
|
||||
use_bionic_memory = None
|
||||
|
||||
# Embedding API 配置全局变量
|
||||
@@ -60,303 +60,411 @@ embedding_api_model = None
|
||||
embedding_api_base_url = None
|
||||
embedding_api_key = None
|
||||
|
||||
# config server中心配置,system.conf与config.json存在时不会使用配置中心
|
||||
# 避免重复加载配置中心导致日志刷屏
|
||||
_last_loaded_project_id = None
|
||||
_last_loaded_config = None
|
||||
_last_loaded_from_api = False # 表示上次加载来自配置中心(含缓存)
|
||||
_warned_public_project_ids = set()
|
||||
|
||||
# config server中心配置,system.conf与config.json存在时不会使用配置中心
|
||||
CONFIG_SERVER = {
|
||||
'BASE_URL': 'http://1.12.69.110:5500', # 默认API服务器地址
|
||||
'API_KEY': 'your-api-key-here', # 默认API密钥
|
||||
'PROJECT_ID': 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' # 项目ID,需要在使用前设置
|
||||
}
|
||||
|
||||
def _refresh_config_center():
|
||||
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
|
||||
if env_project_id:
|
||||
CONFIG_SERVER['PROJECT_ID'] = env_project_id
|
||||
|
||||
_refresh_config_center()
|
||||
|
||||
def load_config_from_api(project_id=None):
|
||||
global CONFIG_SERVER
|
||||
|
||||
"""
|
||||
从API加载配置
|
||||
|
||||
Args:
|
||||
project_id: 项目ID,如果为None则使用全局设置的项目ID
|
||||
|
||||
Returns:
|
||||
包含配置信息的字典,加载失败则返回None
|
||||
"""
|
||||
# 使用参数提供的项目ID或全局设置的项目ID
|
||||
pid = project_id or CONFIG_SERVER['PROJECT_ID']
|
||||
if not pid:
|
||||
util.log(2, "错误: 未指定项目ID,无法从API加载配置")
|
||||
return None
|
||||
|
||||
# 构建API请求URL
|
||||
url = f"{CONFIG_SERVER['BASE_URL']}/api/projects/{pid}/config"
|
||||
|
||||
# 设置请求头
|
||||
headers = {
|
||||
'X-API-Key': CONFIG_SERVER['API_KEY'],
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
try:
|
||||
# 发送API请求
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result.get('success'):
|
||||
# 提取配置数据
|
||||
project_data = result.get('project', {})
|
||||
|
||||
# 创建并填充ConfigParser对象
|
||||
sys_config = ConfigParser()
|
||||
sys_config.add_section('key')
|
||||
|
||||
# 获取系统配置字典
|
||||
system_dict = project_data.get('system_config', {})
|
||||
for section, items in system_dict.items():
|
||||
if not sys_config.has_section(section):
|
||||
sys_config.add_section(section)
|
||||
for key, value in items.items():
|
||||
sys_config.set(section, key, str(value))
|
||||
|
||||
# 获取用户配置
|
||||
user_config = project_data.get('config_json', {})
|
||||
|
||||
# 创建配置字典
|
||||
config_dict = {
|
||||
'system_config': sys_config,
|
||||
'config': user_config,
|
||||
'project_id': pid,
|
||||
'name': project_data.get('name', ''),
|
||||
'description': project_data.get('description', ''),
|
||||
'source': 'api' # 标记配置来源
|
||||
}
|
||||
|
||||
# 提取所有配置项到配置字典
|
||||
for section in sys_config.sections():
|
||||
for key, value in sys_config.items(section):
|
||||
config_dict[f'{section}_{key}'] = value
|
||||
|
||||
return config_dict
|
||||
else:
|
||||
util.log(2, f"API错误: {result.get('message', '未知错误')}")
|
||||
else:
|
||||
util.log(2, f"API请求失败: HTTP状态码 {response.status_code}")
|
||||
except Exception as e:
|
||||
util.log(2, f"从API加载配置时出错: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
"""
|
||||
从API加载配置
|
||||
|
||||
Args:
|
||||
project_id: 项目ID,如果为None则使用全局设置的项目ID
|
||||
|
||||
Returns:
|
||||
包含配置信息的字典,加载失败则返回None
|
||||
"""
|
||||
# 使用参数提供的项目ID或全局设置的项目ID
|
||||
pid = project_id or CONFIG_SERVER['PROJECT_ID']
|
||||
if not pid:
|
||||
util.log(2, "错误: 未指定项目ID,无法从API加载配置")
|
||||
return None
|
||||
|
||||
# 构建API请求URL
|
||||
url = f"{CONFIG_SERVER['BASE_URL']}/api/projects/{pid}/config"
|
||||
|
||||
# 设置请求头
|
||||
headers = {
|
||||
'X-API-Key': CONFIG_SERVER['API_KEY'],
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
try:
|
||||
# 发送API请求
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if result.get('success'):
|
||||
# 提取配置数据
|
||||
project_data = result.get('project', {})
|
||||
|
||||
# 创建并填充ConfigParser对象
|
||||
sys_config = ConfigParser()
|
||||
sys_config.add_section('key')
|
||||
|
||||
# 获取系统配置字典
|
||||
system_dict = project_data.get('system_config', {})
|
||||
for section, items in system_dict.items():
|
||||
if not sys_config.has_section(section):
|
||||
sys_config.add_section(section)
|
||||
for key, value in items.items():
|
||||
sys_config.set(section, key, str(value))
|
||||
|
||||
# 获取用户配置
|
||||
user_config = project_data.get('config_json', {})
|
||||
|
||||
# 创建配置字典
|
||||
config_dict = {
|
||||
'system_config': sys_config,
|
||||
'config': user_config,
|
||||
'project_id': pid,
|
||||
'name': project_data.get('name', ''),
|
||||
'description': project_data.get('description', ''),
|
||||
'source': 'api' # 标记配置来源
|
||||
}
|
||||
|
||||
# 提取所有配置项到配置字典
|
||||
for section in sys_config.sections():
|
||||
for key, value in sys_config.items(section):
|
||||
config_dict[f'{section}_{key}'] = value
|
||||
|
||||
return config_dict
|
||||
else:
|
||||
util.log(2, f"API错误: {result.get('message', '未知错误')}")
|
||||
else:
|
||||
util.log(2, f"API请求失败: HTTP状态码 {response.status_code}")
|
||||
except Exception as e:
|
||||
util.log(2, f"从API加载配置时出错: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
@synchronized
|
||||
def load_config():
|
||||
def load_config(force_reload=False):
|
||||
"""
|
||||
加载配置文件,如果本地文件不存在则直接使用API加载
|
||||
|
||||
|
||||
Returns:
|
||||
包含配置信息的字典
|
||||
"""
|
||||
global config
|
||||
global system_config
|
||||
global key_ali_nls_key_id
|
||||
global key_ali_nls_key_secret
|
||||
global key_ali_nls_app_key
|
||||
global key_ms_tts_key
|
||||
global key_ms_tts_region
|
||||
global baidu_emotion_app_id
|
||||
global baidu_emotion_secret_key
|
||||
global baidu_emotion_api_key
|
||||
global key_gpt_api_key
|
||||
global gpt_model_engine
|
||||
global proxy_config
|
||||
global ASR_mode
|
||||
global local_asr_ip
|
||||
global local_asr_port
|
||||
global ltp_mode
|
||||
global gpt_base_url
|
||||
global tts_module
|
||||
global key_ali_tss_key_id
|
||||
global key_ali_tss_key_secret
|
||||
global key_ali_tss_app_key
|
||||
global volcano_tts_appid
|
||||
global volcano_tts_access_token
|
||||
global volcano_tts_cluster
|
||||
global volcano_tts_voice_type
|
||||
global start_mode
|
||||
global fay_url
|
||||
global use_bionic_memory
|
||||
global embedding_api_model
|
||||
global embedding_api_base_url
|
||||
global embedding_api_key
|
||||
|
||||
global system_config
|
||||
global key_ali_nls_key_id
|
||||
global key_ali_nls_key_secret
|
||||
global key_ali_nls_app_key
|
||||
global key_ms_tts_key
|
||||
global key_ms_tts_region
|
||||
global baidu_emotion_app_id
|
||||
global baidu_emotion_secret_key
|
||||
global baidu_emotion_api_key
|
||||
global key_gpt_api_key
|
||||
global gpt_model_engine
|
||||
global proxy_config
|
||||
global ASR_mode
|
||||
global local_asr_ip
|
||||
global local_asr_port
|
||||
global ltp_mode
|
||||
global gpt_base_url
|
||||
global tts_module
|
||||
global key_ali_tss_key_id
|
||||
global key_ali_tss_key_secret
|
||||
global key_ali_tss_app_key
|
||||
global volcano_tts_appid
|
||||
global volcano_tts_access_token
|
||||
global volcano_tts_cluster
|
||||
global volcano_tts_voice_type
|
||||
global start_mode
|
||||
global fay_url
|
||||
global use_bionic_memory
|
||||
global embedding_api_model
|
||||
global embedding_api_base_url
|
||||
global embedding_api_key
|
||||
|
||||
global CONFIG_SERVER
|
||||
global system_conf_path
|
||||
global config_json_path
|
||||
global _last_loaded_project_id
|
||||
global _last_loaded_config
|
||||
global _last_loaded_from_api
|
||||
|
||||
_refresh_config_center()
|
||||
|
||||
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
|
||||
using_config_center = bool(env_project_id)
|
||||
if (
|
||||
env_project_id
|
||||
and not force_reload
|
||||
and _last_loaded_config is not None
|
||||
and _last_loaded_project_id == env_project_id
|
||||
and _last_loaded_from_api
|
||||
):
|
||||
return _last_loaded_config
|
||||
|
||||
default_system_conf_path = os.path.join(os.getcwd(), 'system.conf')
|
||||
default_config_json_path = os.path.join(os.getcwd(), 'config.json')
|
||||
cache_system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf')
|
||||
cache_config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json')
|
||||
|
||||
# 构建system.conf和config.json的完整路径
|
||||
if system_conf_path is None or config_json_path is None:
|
||||
system_conf_path = os.path.join(os.getcwd(), 'system.conf')
|
||||
config_json_path = os.path.join(os.getcwd(), 'config.json')
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
|
||||
# 如果任一本地文件不存在,直接尝试从API加载
|
||||
if not sys_conf_exists or not config_json_exists:
|
||||
|
||||
# 使用提取的项目ID或全局项目ID
|
||||
util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'和' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在),尝试从API加载配置...")
|
||||
if using_config_center:
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
if (
|
||||
system_conf_path is None
|
||||
or config_json_path is None
|
||||
or system_conf_path == cache_system_conf_path
|
||||
or config_json_path == cache_config_json_path
|
||||
):
|
||||
system_conf_path = default_system_conf_path
|
||||
config_json_path = default_config_json_path
|
||||
|
||||
forced_loaded = False
|
||||
loaded_from_api = False
|
||||
api_attempted = False
|
||||
if using_config_center:
|
||||
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf')
|
||||
config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json')
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
forced_loaded = True
|
||||
|
||||
if CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69':
|
||||
if (
|
||||
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
|
||||
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
|
||||
):
|
||||
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
|
||||
print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m")
|
||||
# 如果本地文件存在,从本地文件加载
|
||||
# 加载system.conf
|
||||
system_config = ConfigParser()
|
||||
system_config.read(system_conf_path, encoding='UTF-8')
|
||||
else:
|
||||
util.log(2, "配置中心加载失败,尝试使用缓存配置")
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
|
||||
# 从system.conf中读取所有配置项
|
||||
key_ali_nls_key_id = system_config.get('key', 'ali_nls_key_id', fallback=None)
|
||||
key_ali_nls_key_secret = system_config.get('key', 'ali_nls_key_secret', fallback=None)
|
||||
key_ali_nls_app_key = system_config.get('key', 'ali_nls_app_key', fallback=None)
|
||||
key_ali_tss_key_id = system_config.get('key', 'ali_tss_key_id', fallback=None)
|
||||
key_ali_tss_key_secret = system_config.get('key', 'ali_tss_key_secret', fallback=None)
|
||||
key_ali_tss_app_key = system_config.get('key', 'ali_tss_app_key', fallback=None)
|
||||
key_ms_tts_key = system_config.get('key', 'ms_tts_key', fallback=None)
|
||||
key_ms_tts_region = system_config.get('key', 'ms_tts_region', fallback=None)
|
||||
baidu_emotion_app_id = system_config.get('key', 'baidu_emotion_app_id', fallback=None)
|
||||
baidu_emotion_api_key = system_config.get('key', 'baidu_emotion_api_key', fallback=None)
|
||||
baidu_emotion_secret_key = system_config.get('key', 'baidu_emotion_secret_key', fallback=None)
|
||||
key_gpt_api_key = system_config.get('key', 'gpt_api_key', fallback=None)
|
||||
gpt_model_engine = system_config.get('key', 'gpt_model_engine', fallback=None)
|
||||
ASR_mode = system_config.get('key', 'ASR_mode', fallback=None)
|
||||
local_asr_ip = system_config.get('key', 'local_asr_ip', fallback=None)
|
||||
local_asr_port = system_config.get('key', 'local_asr_port', fallback=None)
|
||||
proxy_config = system_config.get('key', 'proxy_config', fallback=None)
|
||||
ltp_mode = system_config.get('key', 'ltp_mode', fallback=None)
|
||||
gpt_base_url = system_config.get('key', 'gpt_base_url', fallback=None)
|
||||
tts_module = system_config.get('key', 'tts_module', fallback=None)
|
||||
volcano_tts_appid = system_config.get('key', 'volcano_tts_appid', fallback=None)
|
||||
volcano_tts_access_token = system_config.get('key', 'volcano_tts_access_token', fallback=None)
|
||||
volcano_tts_cluster = system_config.get('key', 'volcano_tts_cluster', fallback=None)
|
||||
volcano_tts_voice_type = system_config.get('key', 'volcano_tts_voice_type', fallback=None)
|
||||
# 如果任一本地文件不存在,直接尝试从API加载
|
||||
if (not sys_conf_exists or not config_json_exists) and not forced_loaded:
|
||||
if using_config_center:
|
||||
if not api_attempted:
|
||||
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
|
||||
# 读取 Embedding API 配置(复用 LLM 的 url 和 key)
|
||||
embedding_api_model = system_config.get('key', 'embedding_api_model', fallback='BAAI/bge-large-zh-v1.5')
|
||||
embedding_api_base_url = gpt_base_url # 复用 LLM base_url
|
||||
embedding_api_key = key_gpt_api_key # 复用 LLM api_key
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
|
||||
start_mode = system_config.get('key', 'start_mode', fallback=None)
|
||||
fay_url = system_config.get('key', 'fay_url', fallback=None)
|
||||
# 如果fay_url为空或None,则动态获取本机IP地址
|
||||
if not fay_url:
|
||||
from utils.util import get_local_ip
|
||||
local_ip = get_local_ip()
|
||||
fay_url = f"http://{local_ip}:5000"
|
||||
# 更新system_config中的值,但不写入文件
|
||||
if not system_config.has_section('key'):
|
||||
system_config.add_section('key')
|
||||
system_config.set('key', 'fay_url', fay_url)
|
||||
|
||||
# 读取用户配置
|
||||
with codecs.open(config_json_path, encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
if (
|
||||
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
|
||||
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
|
||||
):
|
||||
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
|
||||
print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m")
|
||||
else:
|
||||
# 使用提取的项目ID或全局项目ID
|
||||
util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'和' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在),尝试从API加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
|
||||
# 读取仿生记忆配置
|
||||
use_bionic_memory = config.get('memory', {}).get('use_bionic_memory', False)
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
|
||||
# 构建配置字典
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
|
||||
if (
|
||||
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
|
||||
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
|
||||
):
|
||||
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
|
||||
print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m")
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
if using_config_center and (not sys_conf_exists or not config_json_exists):
|
||||
if _last_loaded_config is not None and _last_loaded_from_api:
|
||||
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
|
||||
return _last_loaded_config
|
||||
# 如果本地文件存在,从本地文件加载
|
||||
# 加载system.conf
|
||||
system_config = ConfigParser()
|
||||
system_config.read(system_conf_path, encoding='UTF-8')
|
||||
|
||||
# 从system.conf中读取所有配置项
|
||||
key_ali_nls_key_id = system_config.get('key', 'ali_nls_key_id', fallback=None)
|
||||
key_ali_nls_key_secret = system_config.get('key', 'ali_nls_key_secret', fallback=None)
|
||||
key_ali_nls_app_key = system_config.get('key', 'ali_nls_app_key', fallback=None)
|
||||
key_ali_tss_key_id = system_config.get('key', 'ali_tss_key_id', fallback=None)
|
||||
key_ali_tss_key_secret = system_config.get('key', 'ali_tss_key_secret', fallback=None)
|
||||
key_ali_tss_app_key = system_config.get('key', 'ali_tss_app_key', fallback=None)
|
||||
key_ms_tts_key = system_config.get('key', 'ms_tts_key', fallback=None)
|
||||
key_ms_tts_region = system_config.get('key', 'ms_tts_region', fallback=None)
|
||||
baidu_emotion_app_id = system_config.get('key', 'baidu_emotion_app_id', fallback=None)
|
||||
baidu_emotion_api_key = system_config.get('key', 'baidu_emotion_api_key', fallback=None)
|
||||
baidu_emotion_secret_key = system_config.get('key', 'baidu_emotion_secret_key', fallback=None)
|
||||
key_gpt_api_key = system_config.get('key', 'gpt_api_key', fallback=None)
|
||||
gpt_model_engine = system_config.get('key', 'gpt_model_engine', fallback=None)
|
||||
ASR_mode = system_config.get('key', 'ASR_mode', fallback=None)
|
||||
local_asr_ip = system_config.get('key', 'local_asr_ip', fallback=None)
|
||||
local_asr_port = system_config.get('key', 'local_asr_port', fallback=None)
|
||||
proxy_config = system_config.get('key', 'proxy_config', fallback=None)
|
||||
ltp_mode = system_config.get('key', 'ltp_mode', fallback=None)
|
||||
gpt_base_url = system_config.get('key', 'gpt_base_url', fallback=None)
|
||||
tts_module = system_config.get('key', 'tts_module', fallback=None)
|
||||
volcano_tts_appid = system_config.get('key', 'volcano_tts_appid', fallback=None)
|
||||
volcano_tts_access_token = system_config.get('key', 'volcano_tts_access_token', fallback=None)
|
||||
volcano_tts_cluster = system_config.get('key', 'volcano_tts_cluster', fallback=None)
|
||||
volcano_tts_voice_type = system_config.get('key', 'volcano_tts_voice_type', fallback=None)
|
||||
|
||||
# 读取 Embedding API 配置(复用 LLM 的 url 和 key)
|
||||
embedding_api_model = system_config.get('key', 'embedding_api_model', fallback='BAAI/bge-large-zh-v1.5')
|
||||
embedding_api_base_url = gpt_base_url # 复用 LLM base_url
|
||||
embedding_api_key = key_gpt_api_key # 复用 LLM api_key
|
||||
|
||||
start_mode = system_config.get('key', 'start_mode', fallback=None)
|
||||
fay_url = system_config.get('key', 'fay_url', fallback=None)
|
||||
# 如果fay_url为空或None,则动态获取本机IP地址
|
||||
if not fay_url:
|
||||
from utils.util import get_local_ip
|
||||
local_ip = get_local_ip()
|
||||
fay_url = f"http://{local_ip}:5000"
|
||||
# 更新system_config中的值,但不写入文件
|
||||
if not system_config.has_section('key'):
|
||||
system_config.add_section('key')
|
||||
system_config.set('key', 'fay_url', fay_url)
|
||||
|
||||
# 读取用户配置
|
||||
with codecs.open(config_json_path, encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
# 读取仿生记忆配置
|
||||
use_bionic_memory = config.get('memory', {}).get('use_bionic_memory', False)
|
||||
|
||||
# 构建配置字典
|
||||
config_dict = {
|
||||
'system_config': system_config,
|
||||
'config': config,
|
||||
'ali_nls_key_id': key_ali_nls_key_id,
|
||||
'ali_nls_key_secret': key_ali_nls_key_secret,
|
||||
'ali_nls_app_key': key_ali_nls_app_key,
|
||||
'ms_tts_key': key_ms_tts_key,
|
||||
'ms_tts_region': key_ms_tts_region,
|
||||
'baidu_emotion_app_id': baidu_emotion_app_id,
|
||||
'baidu_emotion_api_key': baidu_emotion_api_key,
|
||||
'baidu_emotion_secret_key': baidu_emotion_secret_key,
|
||||
'gpt_api_key': key_gpt_api_key,
|
||||
'gpt_model_engine': gpt_model_engine,
|
||||
'ASR_mode': ASR_mode,
|
||||
'local_asr_ip': local_asr_ip,
|
||||
'local_asr_port': local_asr_port,
|
||||
'proxy_config': proxy_config,
|
||||
'ltp_mode': ltp_mode,
|
||||
|
||||
'gpt_base_url': gpt_base_url,
|
||||
'tts_module': tts_module,
|
||||
'ali_tss_key_id': key_ali_tss_key_id,
|
||||
'ali_tss_key_secret': key_ali_tss_key_secret,
|
||||
'ali_tss_app_key': key_ali_tss_app_key,
|
||||
'volcano_tts_appid': volcano_tts_appid,
|
||||
'volcano_tts_access_token': volcano_tts_access_token,
|
||||
'volcano_tts_cluster': volcano_tts_cluster,
|
||||
'volcano_tts_voice_type': volcano_tts_voice_type,
|
||||
|
||||
'start_mode': start_mode,
|
||||
'fay_url': fay_url,
|
||||
'use_bionic_memory': use_bionic_memory,
|
||||
|
||||
# Embedding API 配置
|
||||
'embedding_api_model': embedding_api_model,
|
||||
'embedding_api_base_url': embedding_api_base_url,
|
||||
'embedding_api_key': embedding_api_key,
|
||||
|
||||
'source': 'local' # 标记配置来源
|
||||
'ali_nls_key_id': key_ali_nls_key_id,
|
||||
'ali_nls_key_secret': key_ali_nls_key_secret,
|
||||
'ali_nls_app_key': key_ali_nls_app_key,
|
||||
'ms_tts_key': key_ms_tts_key,
|
||||
'ms_tts_region': key_ms_tts_region,
|
||||
'baidu_emotion_app_id': baidu_emotion_app_id,
|
||||
'baidu_emotion_api_key': baidu_emotion_api_key,
|
||||
'baidu_emotion_secret_key': baidu_emotion_secret_key,
|
||||
'gpt_api_key': key_gpt_api_key,
|
||||
'gpt_model_engine': gpt_model_engine,
|
||||
'ASR_mode': ASR_mode,
|
||||
'local_asr_ip': local_asr_ip,
|
||||
'local_asr_port': local_asr_port,
|
||||
'proxy_config': proxy_config,
|
||||
'ltp_mode': ltp_mode,
|
||||
|
||||
'gpt_base_url': gpt_base_url,
|
||||
'tts_module': tts_module,
|
||||
'ali_tss_key_id': key_ali_tss_key_id,
|
||||
'ali_tss_key_secret': key_ali_tss_key_secret,
|
||||
'ali_tss_app_key': key_ali_tss_app_key,
|
||||
'volcano_tts_appid': volcano_tts_appid,
|
||||
'volcano_tts_access_token': volcano_tts_access_token,
|
||||
'volcano_tts_cluster': volcano_tts_cluster,
|
||||
'volcano_tts_voice_type': volcano_tts_voice_type,
|
||||
|
||||
'start_mode': start_mode,
|
||||
'fay_url': fay_url,
|
||||
'use_bionic_memory': use_bionic_memory,
|
||||
|
||||
# Embedding API 配置
|
||||
'embedding_api_model': embedding_api_model,
|
||||
'embedding_api_base_url': embedding_api_base_url,
|
||||
'embedding_api_key': embedding_api_key,
|
||||
|
||||
'source': 'api' if using_config_center else 'local' # 标记配置来源
|
||||
}
|
||||
|
||||
_last_loaded_project_id = CONFIG_SERVER['PROJECT_ID'] if using_config_center else None
|
||||
_last_loaded_config = config_dict
|
||||
_last_loaded_from_api = using_config_center
|
||||
|
||||
return config_dict
|
||||
|
||||
def save_api_config_to_local(api_config, system_conf_path, config_json_path):
|
||||
"""
|
||||
将API加载的配置保存到本地文件
|
||||
|
||||
Args:
|
||||
api_config: API加载的配置字典
|
||||
system_conf_path: system.conf文件路径
|
||||
config_json_path: config.json文件路径
|
||||
"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(system_conf_path), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(config_json_path), exist_ok=True)
|
||||
|
||||
# 保存system.conf
|
||||
with open(system_conf_path, 'w', encoding='utf-8') as f:
|
||||
api_config['system_config'].write(f)
|
||||
|
||||
# 保存config.json
|
||||
with codecs.open(config_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(api_config['config'], f, ensure_ascii=False, indent=4)
|
||||
|
||||
util.log(1, f"已将配置中心配置缓存到本地文件: {system_conf_path} 和 {config_json_path}")
|
||||
except Exception as e:
|
||||
util.log(2, f"保存配置中心配置缓存到本地文件时出错: {str(e)}")
|
||||
|
||||
@synchronized
|
||||
def save_config(config_data):
|
||||
"""
|
||||
保存配置到config.json文件
|
||||
|
||||
Args:
|
||||
config_data: 要保存的配置数据
|
||||
config_dir: 配置文件目录,如果为None则使用当前目录
|
||||
"""
|
||||
global config
|
||||
global config_json_path
|
||||
|
||||
config = config_data
|
||||
|
||||
# 保存到文件
|
||||
with codecs.open(config_json_path, mode='w', encoding='utf-8') as file:
|
||||
file.write(json.dumps(config_data, sort_keys=True, indent=4, separators=(',', ': ')))
|
||||
|
||||
def save_api_config_to_local(api_config, system_conf_path, config_json_path):
|
||||
"""
|
||||
将API加载的配置保存到本地文件
|
||||
|
||||
Args:
|
||||
api_config: API加载的配置字典
|
||||
system_conf_path: system.conf文件路径
|
||||
config_json_path: config.json文件路径
|
||||
"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(system_conf_path), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(config_json_path), exist_ok=True)
|
||||
|
||||
# 保存system.conf
|
||||
with open(system_conf_path, 'w', encoding='utf-8') as f:
|
||||
api_config['system_config'].write(f)
|
||||
|
||||
# 保存config.json
|
||||
with codecs.open(config_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(api_config['config'], f, ensure_ascii=False, indent=4)
|
||||
|
||||
util.log(1, f"已将配置中心配置缓存到本地文件: {system_conf_path} 和 {config_json_path}")
|
||||
except Exception as e:
|
||||
util.log(2, f"保存配置中心配置缓存到本地文件时出错: {str(e)}")
|
||||
|
||||
@synchronized
|
||||
def save_config(config_data):
|
||||
"""
|
||||
保存配置到config.json文件
|
||||
|
||||
Args:
|
||||
config_data: 要保存的配置数据
|
||||
config_dir: 配置文件目录,如果为None则使用当前目录
|
||||
"""
|
||||
global config
|
||||
global config_json_path
|
||||
|
||||
config = config_data
|
||||
|
||||
# 保存到文件
|
||||
with codecs.open(config_json_path, mode='w', encoding='utf-8') as file:
|
||||
file.write(json.dumps(config_data, sort_keys=True, indent=4, separators=(',', ': ')))
|
||||
|
||||
Reference in New Issue
Block a user