交互模块

1. [增加]embedding维度检查自动修正逻辑。
This commit is contained in:
guo zebin
2026-01-29 15:09:37 +08:00
parent 96e467bbd6
commit eecbb931a9
5 changed files with 546 additions and 318 deletions

View File

@@ -369,6 +369,26 @@ def start():
util.log(1, '读取配置...')
config_util.load_config()
# 启动阶段预热 embedding 服务(避免首条消息时才初始化维度)
try:
util.log(1, '启动阶段预热 embedding 服务...')
from simulation_engine.gpt_structure import get_text_embedding
# 检查服务是否已经初始化
from utils.api_embedding_service import get_embedding_service
service = get_embedding_service()
if hasattr(service, 'embedding_dim') and service.embedding_dim is not None:
util.log(1, f'embedding 服务已初始化,维度: {service.embedding_dim}')
else:
util.log(1, '初始化 embedding 服务维度...')
get_text_embedding("dimension_check")
util.log(1, f'embedding 服务维度初始化完成: {service.embedding_dim}')
util.log(1, 'embedding 服务预热完成')
except Exception as e:
util.log(1, f'embedding 服务预热失败: {str(e)}')
#开启核心服务
util.log(1, '开启核心服务...')
feiFei = get_fay_core().FeiFei()

View File

@@ -12,6 +12,7 @@ from simulation_engine.settings import *
from simulation_engine.global_methods import *
from simulation_engine.gpt_structure import *
from simulation_engine.llm_json_parser import *
from utils import util
def run_gpt_generate_importance(
@@ -31,7 +32,7 @@ def run_gpt_generate_importance(
gpt_response = extract_first_json_dict(gpt_response)
# 处理gpt_response为None的情况
if gpt_response is None:
print("警告: extract_first_json_dict返回None使用默认值")
util.log(2, "警告: extract_first_json_dict返回None使用默认值")
return [50] # 返回默认重要性分数
return list(gpt_response.values())
@@ -172,11 +173,11 @@ def normalize_dict_floats(d, target_min, target_max):
"""
# 检查字典是否为None或为空
if d is None:
print("警告: normalize_dict_floats接收到None字典")
util.log(2, "警告: normalize_dict_floats接收到None字典")
return {}
if not d:
print("警告: normalize_dict_floats接收到空字典")
util.log(2, "警告: normalize_dict_floats接收到空字典")
return {}
try:
@@ -193,7 +194,7 @@ def normalize_dict_floats(d, target_min, target_max):
/ range_val + target_min)
return d
except Exception as e:
print(f"normalize_dict_floats处理字典时出错: {str(e)}")
util.log(3, f"normalize_dict_floats处理字典时出错: {str(e)}")
# 返回原始字典,避免处理失败
return d
@@ -238,11 +239,11 @@ def extract_recency(seq_nodes):
"""
# 检查seq_nodes是否为None或为空
if seq_nodes is None:
print("警告: extract_recency接收到None节点列表")
util.log(2, "警告: extract_recency接收到None节点列表")
return {}
if not seq_nodes:
print("警告: extract_recency接收到空节点列表")
util.log(2, "警告: extract_recency接收到空节点列表")
return {}
try:
@@ -250,11 +251,11 @@ def extract_recency(seq_nodes):
normalized_timestamps = []
for node in seq_nodes:
if node is None:
print("警告: 节点为None跳过")
util.log(2, "警告: 节点为None跳过")
continue
if not hasattr(node, 'last_retrieved'):
print(f"警告: 节点 {node} 没有last_retrieved属性使用默认值0")
util.log(2, f"警告: 节点 {node} 没有last_retrieved属性使用默认值0")
normalized_timestamps.append(0)
continue
@@ -284,18 +285,18 @@ def extract_recency(seq_nodes):
recency_out[node.node_id] = (recency_decay
** (max_timestep - last_retrieved))
except Exception as e:
print(f"计算节点 {node.node_id} 的recency时出错: {str(e)}")
util.log(3, f"计算节点 {node.node_id} 的recency时出错: {str(e)}")
# 使用默认值
recency_out[node.node_id] = 1.0
return recency_out
except Exception as e:
print(f"extract_recency处理节点列表时出错: {str(e)}")
util.log(3, f"extract_recency处理节点列表时出错: {str(e)}")
# 返回一个默认字典
return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
def extract_importance(seq_nodes):
def extract_importance(seq_nodes):
"""
Gets the current Persona object and a list of nodes that are in a
chronological order, and outputs a dictionary that has the importance score
@@ -309,22 +310,22 @@ def extract_importance(seq_nodes):
"""
# 检查seq_nodes是否为None或为空
if seq_nodes is None:
print("警告: extract_importance接收到None节点列表")
util.log(2, "警告: extract_importance接收到None节点列表")
return {}
if not seq_nodes:
print("警告: extract_importance接收到空节点列表")
util.log(2, "警告: extract_importance接收到空节点列表")
return {}
try:
importance_out = dict()
for count, node in enumerate(seq_nodes):
if node is None:
print("警告: 节点为None跳过")
util.log(2, "警告: 节点为None跳过")
continue
if not hasattr(node, 'node_id') or not hasattr(node, 'importance'):
print(f"警告: 节点缺少必要属性,跳过")
util.log(2, f"警告: 节点缺少必要属性,跳过")
continue
# 确保importance是数值类型
@@ -333,32 +334,32 @@ def extract_importance(seq_nodes):
importance_out[node.node_id] = float(node.importance)
except ValueError:
# 如果无法转换为数值,使用默认值
print(f"警告: 节点 {node.node_id} 的importance无法转换为数值使用默认值")
util.log(2, f"警告: 节点 {node.node_id} 的importance无法转换为数值使用默认值")
importance_out[node.node_id] = 50.0
else:
importance_out[node.node_id] = node.importance
return importance_out
except Exception as e:
print(f"extract_importance处理节点列表时出错: {str(e)}")
# 返回一个默认字典
return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
def _is_valid_embedding(vec, expected_dim):
if vec is None:
return False
if not isinstance(vec, (list, tuple)):
return False
if expected_dim is not None and len(vec) != expected_dim:
return False
for val in vec:
if not isinstance(val, (int, float)):
return False
return True
def extract_relevance(seq_nodes, embeddings, focal_pt):
except Exception as e:
util.log(3, f"extract_importance处理节点列表时出错: {str(e)}")
# 返回一个默认字典
return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
def _is_valid_embedding(vec, expected_dim):
if vec is None:
return False
if not isinstance(vec, (list, tuple)):
return False
if expected_dim is not None and len(vec) != expected_dim:
return False
for val in vec:
if not isinstance(val, (int, float)):
return False
return True
def extract_relevance(seq_nodes, embeddings, focal_pt):
"""
Gets the current Persona object, a list of seq_nodes that are in a
chronological order, and the focal_pt string and outputs a dictionary
@@ -373,45 +374,39 @@ def extract_relevance(seq_nodes, embeddings, focal_pt):
"""
# 确保embeddings不为None
if embeddings is None:
print("警告: embeddings为None使用空字典代替")
util.log(2, "警告: embeddings为None使用空字典代替")
embeddings = {}
try:
focal_embedding = get_text_embedding(focal_pt)
except Exception as e:
print(f"获取焦点嵌入向量时出错: {str(e)}")
# 如果无法获取嵌入向量,返回默认值
return {node.node_id: 0.5 for node in seq_nodes}
expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None
relevance_out = dict()
for count, node in enumerate(seq_nodes):
try:
# 检查节点内容是否在embeddings中
if node.content in embeddings:
node_embedding = embeddings[node.content]
if not _is_valid_embedding(node_embedding, expected_dim):
try:
regenerated = get_text_embedding(node.content)
if _is_valid_embedding(regenerated, expected_dim):
embeddings[node.content] = regenerated
node_embedding = regenerated
else:
print("Warning: regenerated embedding has unexpected dimension, using default relevance")
node_embedding = None
except Exception as e:
print(f"Regenerate embedding failed: {str(e)}")
node_embedding = None
# 计算余弦相似度
if node_embedding is None:
relevance_out[node.node_id] = 0.5
else:
relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding)
else:
# 如果没有对应的嵌入向量,使用默认值
relevance_out[node.node_id] = 0.5
try:
focal_embedding = get_text_embedding(focal_pt)
except Exception as e:
util.log(3, f"获取焦点嵌入向量时出错: {str(e)}")
# 如果无法获取嵌入向量,返回默认值
return {node.node_id: 0.5 for node in seq_nodes}
expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None
relevance_out = dict()
for count, node in enumerate(seq_nodes):
try:
# 检查节点内容是否在embeddings中
if node.content in embeddings:
node_embedding = embeddings[node.content]
if not _is_valid_embedding(node_embedding, expected_dim):
# 维度检查与修复在启动阶段完成,这里不再重算,避免首条消息扣时间
current_dim = len(node_embedding) if isinstance(node_embedding, (list, tuple)) else "未知"
util.log(2, f"检索时发现维度不一致的embedding: 节点ID={node.node_id}, 内容='{node.content[:30]}...', 当前维度={current_dim}, 期望维度={expected_dim}")
util.log(2, f" -> 使用默认相关性分数 0.5 (建议重启系统进行维度修复)")
node_embedding = None
# 计算余弦相似度
if node_embedding is None:
relevance_out[node.node_id] = 0.5
else:
relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding)
else:
# 如果没有对应的嵌入向量,使用默认值
relevance_out[node.node_id] = 0.5
except Exception as e:
print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}")
util.log(3, f"计算节点 {node.node_id} 的相关性时出错: {str(e)}")
# 如果计算过程中出错,使用默认值
relevance_out[node.node_id] = 0.5
@@ -425,12 +420,12 @@ def extract_relevance(seq_nodes, embeddings, focal_pt):
class ConceptNode:
def __init__(self, node_dict):
# Loading the content of a memory node in the memory stream.
self.node_id = node_dict["node_id"]
self.node_type = node_dict["node_type"]
self.content = node_dict["content"]
self.importance = node_dict["importance"]
self.datetime = node_dict.get("datetime", "")
# 确保created是整数类型
self.node_id = node_dict["node_id"]
self.node_type = node_dict["node_type"]
self.content = node_dict["content"]
self.importance = node_dict["importance"]
self.datetime = node_dict.get("datetime", "")
# 确保created是整数类型
self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0
# 确保last_retrieved是整数类型
self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0
@@ -449,10 +444,10 @@ class ConceptNode:
curr_package = {}
curr_package["node_id"] = self.node_id
curr_package["node_type"] = self.node_type
curr_package["content"] = self.content
curr_package["importance"] = self.importance
curr_package["datetime"] = self.datetime
curr_package["created"] = self.created
curr_package["content"] = self.content
curr_package["importance"] = self.importance
curr_package["datetime"] = self.datetime
curr_package["created"] = self.created
curr_package["last_retrieved"] = self.last_retrieved
curr_package["pointer_id"] = self.pointer_id
@@ -474,6 +469,87 @@ class MemoryStream:
self.id_to_node[new_node.node_id] = new_node
self.embeddings = embeddings
self._embedding_dim_checked = False
def precheck_embedding_dimensions(self, force: bool = False):
    """Check stored memory embeddings against the expected dimension and repair mismatches.

    Meant to run during startup so that mismatched embeddings are regenerated
    before the first retrieval instead of during it.

    Args:
        force: Run the check even if it already ran for this stream, and even
            if ``self.embeddings`` was None on entry.

    Returns:
        dict with keys:
            "checked": True only when the full scan ran to completion.
            "expected_dim": the dimension the scan compared against, or None.
            "fixed": number of embeddings that were regenerated and replaced.
    """
    # Imported once up front so the function-local name is bound on every
    # path (including the exception handler below).
    from utils import util

    result = {"checked": False, "expected_dim": None, "fixed": 0}
    if self._embedding_dim_checked and not force:
        return result

    # Make sure embeddings is a dict; with nothing stored there is nothing to
    # scan unless the caller explicitly forces the check.
    if self.embeddings is None:
        self.embeddings = {}
        if not force:
            return result

    try:
        # Prefer the dimension already recorded by the embedding service so
        # no extra embedding request is made.
        from utils.api_embedding_service import get_embedding_service
        service = get_embedding_service()
        known_dim = getattr(service, 'embedding_dim', None)
        if known_dim is not None:
            expected_dim = known_dim
            util.log(1, f"使用已初始化的embedding服务维度: {expected_dim}")
        else:
            # Only probe with a sample text when the service has no recorded dimension.
            util.log(1, "embedding服务维度未初始化进行维度检查...")
            sample_embedding = get_text_embedding("dimension_check")
            expected_dim = len(sample_embedding) if isinstance(sample_embedding, (list, tuple)) else None
    except Exception as e:
        util.log(2, f"启动阶段 embedding 维度检查失败: {str(e)}")
        # Mark as checked even on failure so the probe is not retried repeatedly.
        self._embedding_dim_checked = True
        return result

    if expected_dim is None:
        util.log(2, "无法获取 embedding 维度,跳过维度检查")
        self._embedding_dim_checked = True
        return result

    if self.seq_nodes:
        candidates = (node.content for node in self.seq_nodes if node is not None)
    else:
        candidates = self.embeddings.keys()
    # De-duplicate while keeping first-seen order: several nodes may share the
    # same content, and a failing regeneration should be attempted only once
    # per distinct content rather than once per node.
    contents = list(dict.fromkeys(candidates))

    fixed = 0
    for content in contents:
        if content not in self.embeddings:
            continue
        node_embedding = self.embeddings[content]
        if _is_valid_embedding(node_embedding, expected_dim):
            continue

        # Log the details of the mismatch before attempting a repair.
        current_dim = len(node_embedding) if isinstance(node_embedding, (list, tuple)) else "未知"
        util.log(2, f"发现维度不一致的embedding: 内容='{content[:30]}...', 当前维度={current_dim}, 期望维度={expected_dim}")
        try:
            util.log(1, f"正在重新生成embedding: '{content[:30]}...'")
            regenerated = get_text_embedding(content)
            if _is_valid_embedding(regenerated, expected_dim):
                self.embeddings[content] = regenerated
                fixed += 1
                util.log(1, f"成功修复embedding维度: '{content[:30]}...' ({current_dim} -> {len(regenerated)})")
            else:
                regenerated_dim = len(regenerated) if isinstance(regenerated, (list, tuple)) else "未知"
                util.log(2, f"重新生成的embedding维度仍不一致: '{content[:30]}...' (期望={expected_dim}, 实际={regenerated_dim}),已忽略")
        except Exception as e:
            # Best effort: keep the old entry and move on to the next content.
            util.log(2, f"重新生成embedding失败: '{content[:30]}...' - {str(e)}")

    if fixed > 0:
        util.log(1, f"启动阶段已修复 {fixed} 条记忆节点 embedding 维度")

    self._embedding_dim_checked = True
    result["checked"] = True
    result["expected_dim"] = expected_dim
    result["fixed"] = fixed
    return result
def count_observations(self):
@@ -503,8 +579,8 @@ class MemoryStream:
the elemnts of the list are the query sentences.
time_step: Current time_step
n_count: The number of nodes that we want to retrieve.
curr_filter: Filtering the node.type that we want to retrieve.
Acceptable values are 'all', 'reflection', 'observation', 'conversation'
curr_filter: Filtering the node.type that we want to retrieve.
Acceptable values are 'all', 'reflection', 'observation', 'conversation'
hp: Hyperparameter for [recency_w, relevance_w, importance_w]
verbose: verbose
Returns:
@@ -517,8 +593,8 @@ class MemoryStream:
if len(self.seq_nodes) == 0:
return dict()
# Filtering for the desired node type. curr_filter can be one of the
# elements: 'all', 'reflection', 'observation', 'conversation'
# Filtering for the desired node type. curr_filter can be one of the
# elements: 'all', 'reflection', 'observation', 'conversation'
if curr_filter == "all":
curr_nodes = self.seq_nodes
else:
@@ -528,7 +604,7 @@ class MemoryStream:
# 确保embeddings不为None
if self.embeddings is None:
print("警告: 在retrieve方法中embeddings为None初始化为空字典")
util.log(2, "警告: 在retrieve方法中embeddings为None初始化为空字典")
self.embeddings = {}
# <retrieved> is the main dictionary that we are returning
@@ -587,7 +663,7 @@ class MemoryStream:
Parameters:
time_step: Current time_step
node_type: type of node -- it's either reflection, observation, conversation
node_type: type of node -- it's either reflection, observation, conversation
content: the str content of the memory record
importance: int score of the importance score
pointer_id: the str of the parent node
@@ -597,11 +673,11 @@ class MemoryStream:
"""
node_dict = dict()
node_dict["node_id"] = len(self.seq_nodes)
node_dict["node_type"] = node_type
node_dict["content"] = content
node_dict["importance"] = importance
node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")
node_dict["created"] = time_step
node_dict["node_type"] = node_type
node_dict["content"] = content
node_dict["importance"] = importance
node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")
node_dict["created"] = time_step
node_dict["last_retrieved"] = time_step
node_dict["pointer_id"] = pointer_id
new_node = ConceptNode(node_dict)
@@ -616,19 +692,19 @@ class MemoryStream:
try:
self.embeddings[content] = get_text_embedding(content)
except Exception as e:
print(f"获取文本嵌入时出错: {str(e)}")
util.log(3, f"获取文本嵌入时出错: {str(e)}")
# 如果获取嵌入失败,使用空列表代替
self.embeddings[content] = []
def remember(self, content, time_step=0):
score = generate_importance_score([content])[0]
self._add_node(time_step, "observation", content, score, None)
def remember_conversation(self, content, time_step=0):
score = generate_importance_score([content])[0]
self._add_node(time_step, "conversation", content, score, None)
def remember(self, content, time_step=0):
    """Score ``content`` for importance and store it as an "observation" node.

    Args:
        content: Text of the memory record.
        time_step: Time step recorded on the new node (defaults to 0).
    """
    # The scorer takes a list of records; only one record is passed here, so
    # the first returned score is the one for ``content``.
    importance_scores = generate_importance_score([content])
    self._add_node(time_step, "observation", content, importance_scores[0], None)
def remember_conversation(self, content, time_step=0):
    """Score ``content`` for importance and store it as a "conversation" node.

    Args:
        content: Text of the conversation record.
        time_step: Time step recorded on the new node (defaults to 0).
    """
    # The scorer takes a list of records; only one record is passed here, so
    # the first returned score is the one for ``content``.
    importance_scores = generate_importance_score([content])
    self._add_node(time_step, "conversation", content, importance_scores[0], None)
def reflect(self, anchor, reflection_count=5,
retrieval_count=120, time_step=0):

View File

@@ -1517,9 +1517,28 @@ def init_memory_scheduler():
global agents
# 确保agent已经创建
if not agents:
util.log(1, '创建代理实例...')
create_agent()
agent = None
if not agents:
util.log(1, '创建代理实例...')
agent = create_agent()
else:
agent = agents.get("User")
if agent is None and len(agents) > 0:
agent = next(iter(agents.values()))
# 启动阶段做一次 embedding 维度检查,避免首条消息时触发
try:
if agent and agent.memory_stream and hasattr(agent.memory_stream, "precheck_embedding_dimensions"):
result = agent.memory_stream.precheck_embedding_dimensions()
if result.get("checked"):
util.log(
1,
f"启动阶段记忆 embedding 维度检查完成: dim={result.get('expected_dim')}, 修复={result.get('fixed')}"
)
else:
util.log(1, "启动阶段记忆 embedding 维度检查跳过(无记忆/无embedding")
except Exception as e:
util.log(1, f"启动阶段 embedding 维度检查失败: {str(e)}")
# 设置每天0点保存记忆
schedule.every().day.at("00:00").do(save_agent_memory)
@@ -1640,8 +1659,28 @@ def create_agent(username=None):
agent.scratch = scratch_data
# 如果memory目录存在且不为空则加载之前保存的记忆不包括scratch数据
if is_exist:
load_agent_memory(agent, username)
if is_exist:
load_agent_memory(agent, username)
try:
if agent.memory_stream and hasattr(agent.memory_stream, "precheck_embedding_dimensions"):
result = agent.memory_stream.precheck_embedding_dimensions(force=True)
if result.get("checked"):
util.log(
1,
f"启动阶段记忆 embedding 维度检查完成: dim={result.get('expected_dim')}, 修复={result.get('fixed')}"
)
if result.get("fixed"):
try:
embeddings_path = os.path.join(memory_dir, "memory_stream", "embeddings.json")
with open(embeddings_path, "w", encoding="utf-8") as f:
json.dump(agent.memory_stream.embeddings or {}, f, ensure_ascii=False, indent=2)
util.log(1, f"启动阶段已写回 embeddings.json (修复={result.get('fixed')})")
except Exception as write_err:
util.log(1, f"写回 embeddings.json 失败: {str(write_err)}")
else:
util.log(1, "启动阶段记忆 embedding 维度检查跳过(无记忆/无embedding")
except Exception as e:
util.log(1, f"启动阶段 embedding 维度检查失败: {str(e)}")
# 缓存到字典
agents[username] = agent

View File

@@ -123,6 +123,7 @@ class ApiEmbeddingService:
def encode_text(self, text: str) -> List[float]:
"""编码单个文本(带重试机制)"""
import time
import requests.exceptions
text = _sanitize_text(text)
last_error = None
@@ -152,7 +153,15 @@ class ApiEmbeddingService:
# 首次调用时获取实际维度
if self.embedding_dim is None:
self.embedding_dim = len(embedding)
logger.info(f"动态获取 embedding 维度: {self.embedding_dim},与原记忆节点不一致,将重新生成记忆节点的 embedding维度")
logger.info(f"动态获取 embedding 维度: {self.embedding_dim}")
else:
# 检查维度一致性
current_dim = len(embedding)
if current_dim != self.embedding_dim:
logger.warning(f"⚠️ Embedding维度不一致! 期望={self.embedding_dim}, 实际={current_dim}, 文本='{text_preview}'")
logger.warning(f" 建议检查API配置或模型设置")
# 更新维度记录
self.embedding_dim = current_dim
logger.info(f"embedding 生成成功")
return embedding
@@ -168,6 +177,38 @@ class ApiEmbeddingService:
logger.error(f"所有重试均失败,文本长度: {len(text)}")
raise
except requests.exceptions.ConnectionError as e:
last_error = e
# 网络连接错误包括DNS解析失败
logger.warning(f"网络连接失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
if attempt < self.max_retries:
wait_time = 2 ** attempt # 指数退避
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
else:
logger.error(f"网络连接持续失败请检查网络设置和API地址: {self.api_base_url}")
raise
except requests.exceptions.HTTPError as e:
last_error = e
# HTTP错误4xx, 5xx
status_code = e.response.status_code if e.response else "未知"
logger.error(f"HTTP错误 {status_code} (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
# 对于客户端错误4xx通常不需要重试
if e.response and 400 <= e.response.status_code < 500:
logger.error(f"客户端错误,不进行重试: {e.response.text}")
raise
# 对于服务器错误5xx可以重试
if attempt < self.max_retries:
wait_time = 2 ** attempt
logger.info(f"服务器错误,等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
else:
logger.error(f"服务器错误持续发生")
raise
except Exception as e:
last_error = e
logger.error(f"文本编码失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
@@ -202,6 +243,28 @@ class ApiEmbeddingService:
result = response.json()
embeddings = [item['embedding'] for item in result['data']]
# 检查批量embedding的维度一致性
if embeddings:
dimensions = [len(emb) for emb in embeddings]
unique_dims = set(dimensions)
if len(unique_dims) > 1:
logger.warning(f"⚠️ 批量embedding维度不一致: {dict(zip(range(len(dimensions)), dimensions))}")
logger.warning(f" 唯一维度: {unique_dims}")
# 检查与已知维度的一致性
if self.embedding_dim is not None:
for i, dim in enumerate(dimensions):
if dim != self.embedding_dim:
text_preview = texts[i][:30] + "..." if len(texts[i]) > 30 else texts[i]
logger.warning(f"⚠️ 文本{i}维度不一致: 期望={self.embedding_dim}, 实际={dim}, 文本='{text_preview}'")
else:
# 首次批量调用,设置维度
if dimensions:
self.embedding_dim = dimensions[0]
logger.info(f"从批量请求动态获取 embedding 维度: {self.embedding_dim}")
logger.info(f"批量 embedding 生成成功: {len(embeddings)} 个向量")
return embeddings
except Exception as e:
logger.error(f"批量文本编码失败: {e}")
@@ -217,6 +280,29 @@ class ApiEmbeddingService:
"service_type": "api"
}
def health_check(self) -> dict:
    """Probe the embedding service by encoding a short test string.

    Returns:
        dict: On success, a report with "status" set to "healthy", the model
        name, the API URL, the length of the returned embedding (None when
        the embedding is empty) and "test_successful" set to True. On any
        exception, a report with "status" set to "unhealthy", the model name,
        the API URL, the error text and "test_successful" set to False.
    """
    try:
        # A fixed, trivial input is enough to exercise the encode path.
        probe_vector = self.encode_text("health_check")
        probe_dim = len(probe_vector) if probe_vector else None
        report = {
            "status": "healthy",
            "model": self.model_name,
            "api_url": self.api_base_url,
            "embedding_dim": probe_dim,
            "test_successful": True
        }
    except Exception as exc:
        # Any failure is reported rather than raised.
        report = {
            "status": "unhealthy",
            "model": self.model_name,
            "api_url": self.api_base_url,
            "error": str(exc),
            "test_successful": False
        }
    return report
# 全局实例
_global_embedding_service = None

View File

@@ -1,7 +1,7 @@
import os
import json
import codecs
from langsmith.schemas import Feedback
# from langsmith.schemas import Feedback # 临时注释
import requests
from configparser import ConfigParser
import functools
@@ -53,36 +53,58 @@ start_mode = None
fay_url = None
system_conf_path = None
config_json_path = None
use_bionic_memory = None
# Embedding API 配置全局变量
embedding_api_model = None
embedding_api_base_url = None
embedding_api_key = None
# 避免重复加载配置中心导致日志刷屏
_last_loaded_project_id = None
_last_loaded_config = None
_last_loaded_from_api = False # 表示上次加载来自配置中心(含缓存)
_bootstrap_loaded_from_api = False # 无本地配置时启动阶段已从配置中心加载过
_warned_public_project_ids = set()
use_bionic_memory = None
# Embedding API 配置全局变量
embedding_api_model = None
embedding_api_base_url = None
embedding_api_key = None
# 避免重复加载配置中心导致日志刷屏
_last_loaded_project_id = None
_last_loaded_config = None
_last_loaded_from_api = False # 表示上次加载来自配置中心(含缓存)
_bootstrap_loaded_from_api = False # 无本地配置时启动阶段已从配置中心加载过
_warned_public_config_keys = set()
# Public config center identifiers (warn users if matched)
PUBLIC_CONFIG_PROJECT_ID = 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
PUBLIC_CONFIG_BASE_URL = 'http://1.12.69.110:5500'
def _public_config_warn_key():
    """Return a key identifying which public-config setting is in use, if any.

    The configured base URL is compared first (ignoring trailing slashes)
    against the public base URL; if it does not match, the configured project
    id is compared against the public project id.

    Returns:
        str | None: "base_url:<url>" or "project_id:<id>" for the first
        match, or None when neither setting matches the public values.
    """
    configured_url = (CONFIG_SERVER.get('BASE_URL') or '').rstrip('/')
    if configured_url == PUBLIC_CONFIG_BASE_URL.rstrip('/'):
        return f"base_url:{configured_url}"

    configured_project = CONFIG_SERVER.get('PROJECT_ID')
    if configured_project == PUBLIC_CONFIG_PROJECT_ID:
        return f"project_id:{configured_project}"

    return None
def _warn_public_config_once():
    """Print the public-config warning at most once per distinct warn key.

    Does nothing when the current settings do not match the public config
    (no key) or when the warning for this key has already been printed.
    """
    warn_key = _public_config_warn_key()
    if warn_key and warn_key not in _warned_public_config_keys:
        # Record the key first so later calls with the same key stay silent.
        _warned_public_config_keys.add(warn_key)
        print("\033[1;33;41m警告你正在使用社区公共配置,请尽快更换!\033[0m")
# config server中心配置system.conf与config.json存在时不会使用配置中心
CONFIG_SERVER = {
'BASE_URL': 'http://1.12.69.110:5500', # 默认API服务器地址
'API_KEY': 'your-api-key-here', # 默认API密钥
'PROJECT_ID': 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' # 项目ID需要在使用前设置
}
def _refresh_config_center():
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
if env_project_id:
CONFIG_SERVER['PROJECT_ID'] = env_project_id
_refresh_config_center()
def load_config_from_api(project_id=None):
global CONFIG_SERVER
CONFIG_SERVER = {
'BASE_URL': 'http://1.12.69.110:5500', # 默认API服务器地址
'API_KEY': 'your-api-key-here', # 默认API密钥
'PROJECT_ID': 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' # 项目ID需要在使用前设置
}
def _refresh_config_center():
    """Override the config-center project id from the environment, if set.

    Reads ``FAY_CONFIG_CENTER_ID``; a missing or empty value leaves
    ``CONFIG_SERVER['PROJECT_ID']`` unchanged.
    """
    project_id_override = os.environ.get('FAY_CONFIG_CENTER_ID')
    if not project_id_override:
        return
    CONFIG_SERVER['PROJECT_ID'] = project_id_override
_refresh_config_center()
def load_config_from_api(project_id=None):
global CONFIG_SERVER
"""
从API加载配置
@@ -159,15 +181,15 @@ def load_config_from_api(project_id=None):
return None
@synchronized
def load_config(force_reload=False):
"""
加载配置文件如果本地文件不存在则直接使用API加载
@synchronized
def load_config(force_reload=False):
"""
加载配置文件如果本地文件不存在则直接使用API加载
Returns:
包含配置信息的字典
"""
global config
Returns:
包含配置信息的字典
"""
global config
global system_config
global key_ali_nls_key_id
global key_ali_nls_key_secret
@@ -200,168 +222,153 @@ def load_config(force_reload=False):
global embedding_api_base_url
global embedding_api_key
global CONFIG_SERVER
global system_conf_path
global config_json_path
global _last_loaded_project_id
global _last_loaded_config
global _last_loaded_from_api
global _bootstrap_loaded_from_api
_refresh_config_center()
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
explicit_config_center = bool(env_project_id)
using_config_center = explicit_config_center
if (
env_project_id
and not force_reload
and _last_loaded_config is not None
and _last_loaded_project_id == env_project_id
and _last_loaded_from_api
):
return _last_loaded_config
default_system_conf_path = os.path.join(os.getcwd(), 'system.conf')
default_config_json_path = os.path.join(os.getcwd(), 'config.json')
cache_system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf')
cache_config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json')
root_system_conf_exists = os.path.exists(default_system_conf_path)
root_config_json_exists = os.path.exists(default_config_json_path)
root_config_complete = root_system_conf_exists and root_config_json_exists
# 构建system.conf和config.json的完整路径
config_center_fallback = False
if using_config_center:
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
if (
system_conf_path is None
or config_json_path is None
or system_conf_path == cache_system_conf_path
or config_json_path == cache_config_json_path
):
system_conf_path = default_system_conf_path
config_json_path = default_config_json_path
if not root_config_complete:
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
if (not _bootstrap_loaded_from_api) or (not cache_ready):
using_config_center = True
config_center_fallback = True
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
forced_loaded = False
loaded_from_api = False
api_attempted = False
if using_config_center:
if explicit_config_center:
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
else:
util.log(1, f"未检测到本地system.conf或config.json尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
api_attempted = True
if api_config:
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
if config_center_fallback:
_bootstrap_loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
forced_loaded = True
if (
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
):
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
print("\033[1;33;41m警告你正在使用社区公共配置,请尽快更换!\033[0m")
else:
util.log(2, "配置中心加载失败,尝试使用缓存配置")
sys_conf_exists = os.path.exists(system_conf_path)
config_json_exists = os.path.exists(config_json_path)
# 如果任一本地文件不存在直接尝试从API加载
if (not sys_conf_exists or not config_json_exists) and not forced_loaded:
if using_config_center:
if not api_attempted:
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
api_attempted = True
if api_config:
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
if config_center_fallback:
_bootstrap_loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
if (
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
):
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
print("\033[1;33;41m警告你正在使用社区公共配置,请尽快更换!\033[0m")
else:
# 使用提取的项目ID或全局项目ID
util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在尝试从API加载配置...")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
if api_config:
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
if (
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
):
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
print("\033[1;33;41m警告你正在使用社区公共配置,请尽快更换!\033[0m")
sys_conf_exists = os.path.exists(system_conf_path)
config_json_exists = os.path.exists(config_json_path)
if using_config_center and (not sys_conf_exists or not config_json_exists):
if _last_loaded_config is not None and _last_loaded_from_api:
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
return _last_loaded_config
if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists):
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
if cache_ready:
util.log(2, "配置中心不可用,回退使用缓存配置")
using_config_center = False
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
util.log(2, "配置中心不可用,回退使用本地配置文件")
using_config_center = False
system_conf_path = default_system_conf_path
config_json_path = default_config_json_path
sys_conf_exists = os.path.exists(system_conf_path)
config_json_exists = os.path.exists(config_json_path)
# 如果本地文件存在,从本地文件加载
global CONFIG_SERVER
global system_conf_path
global config_json_path
global _last_loaded_project_id
global _last_loaded_config
global _last_loaded_from_api
global _bootstrap_loaded_from_api
_refresh_config_center()
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
explicit_config_center = bool(env_project_id)
using_config_center = explicit_config_center
if (
env_project_id
and not force_reload
and _last_loaded_config is not None
and _last_loaded_project_id == env_project_id
and _last_loaded_from_api
):
return _last_loaded_config
default_system_conf_path = os.path.join(os.getcwd(), 'system.conf')
default_config_json_path = os.path.join(os.getcwd(), 'config.json')
cache_system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf')
cache_config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json')
root_system_conf_exists = os.path.exists(default_system_conf_path)
root_config_json_exists = os.path.exists(default_config_json_path)
root_config_complete = root_system_conf_exists and root_config_json_exists
# 构建system.conf和config.json的完整路径
config_center_fallback = False
if using_config_center:
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
if (
system_conf_path is None
or config_json_path is None
or system_conf_path == cache_system_conf_path
or config_json_path == cache_config_json_path
):
system_conf_path = default_system_conf_path
config_json_path = default_config_json_path
if not root_config_complete:
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
if (not _bootstrap_loaded_from_api) or (not cache_ready):
using_config_center = True
config_center_fallback = True
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
forced_loaded = False
loaded_from_api = False
api_attempted = False
if using_config_center:
if explicit_config_center:
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
else:
util.log(1, f"未检测到本地system.conf或config.json尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
api_attempted = True
if api_config:
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
if config_center_fallback:
_bootstrap_loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
forced_loaded = True
_warn_public_config_once()
else:
util.log(2, "配置中心加载失败,尝试使用缓存配置")
sys_conf_exists = os.path.exists(system_conf_path)
config_json_exists = os.path.exists(config_json_path)
# 如果任一本地文件不存在直接尝试从API加载
if (not sys_conf_exists or not config_json_exists) and not forced_loaded:
if using_config_center:
if not api_attempted:
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
api_attempted = True
if api_config:
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
if config_center_fallback:
_bootstrap_loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
_warn_public_config_once()
else:
# 使用提取的项目ID或全局项目ID
util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在尝试从API加载配置...")
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
if api_config:
util.log(1, "成功从配置中心加载配置")
system_config = api_config['system_config']
config = api_config['config']
loaded_from_api = True
# 缓存API配置到本地文件
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
save_api_config_to_local(api_config, system_conf_path, config_json_path)
_warn_public_config_once()
sys_conf_exists = os.path.exists(system_conf_path)
config_json_exists = os.path.exists(config_json_path)
if using_config_center and (not sys_conf_exists or not config_json_exists):
if _last_loaded_config is not None and _last_loaded_from_api:
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
return _last_loaded_config
if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists):
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
if cache_ready:
util.log(2, "配置中心不可用,回退使用缓存配置")
using_config_center = False
system_conf_path = cache_system_conf_path
config_json_path = cache_config_json_path
else:
util.log(2, "配置中心不可用,回退使用本地配置文件")
using_config_center = False
system_conf_path = default_system_conf_path
config_json_path = default_config_json_path
sys_conf_exists = os.path.exists(system_conf_path)
config_json_exists = os.path.exists(config_json_path)
# 如果本地文件存在,从本地文件加载
# 加载system.conf
system_config = ConfigParser()
system_config.read(system_conf_path, encoding='UTF-8')
@@ -417,9 +424,9 @@ def load_config(force_reload=False):
use_bionic_memory = config.get('memory', {}).get('use_bionic_memory', False)
# 构建配置字典
config_dict = {
'system_config': system_config,
'config': config,
config_dict = {
'system_config': system_config,
'config': config,
'ali_nls_key_id': key_ali_nls_key_id,
'ali_nls_key_secret': key_ali_nls_key_secret,
'ali_nls_app_key': key_ali_nls_app_key,
@@ -455,14 +462,14 @@ def load_config(force_reload=False):
'embedding_api_base_url': embedding_api_base_url,
'embedding_api_key': embedding_api_key,
'source': 'api' if using_config_center else 'local' # 标记配置来源
}
_last_loaded_project_id = CONFIG_SERVER['PROJECT_ID'] if using_config_center else None
_last_loaded_config = config_dict
_last_loaded_from_api = using_config_center
return config_dict
'source': 'api' if using_config_center else 'local' # 标记配置来源
}
_last_loaded_project_id = CONFIG_SERVER['PROJECT_ID'] if using_config_center else None
_last_loaded_config = config_dict
_last_loaded_from_api = using_config_center
return config_dict
def save_api_config_to_local(api_config, system_conf_path, config_json_path):
"""