mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
交互模块
1. [增加]embedding维度检查自动修正逻辑。
This commit is contained in:
@@ -369,6 +369,26 @@ def start():
|
||||
util.log(1, '读取配置...')
|
||||
config_util.load_config()
|
||||
|
||||
# 启动阶段预热 embedding 服务(避免首条消息时才初始化维度)
|
||||
try:
|
||||
util.log(1, '启动阶段预热 embedding 服务...')
|
||||
from simulation_engine.gpt_structure import get_text_embedding
|
||||
|
||||
# 检查服务是否已经初始化
|
||||
from utils.api_embedding_service import get_embedding_service
|
||||
service = get_embedding_service()
|
||||
|
||||
if hasattr(service, 'embedding_dim') and service.embedding_dim is not None:
|
||||
util.log(1, f'embedding 服务已初始化,维度: {service.embedding_dim}')
|
||||
else:
|
||||
util.log(1, '初始化 embedding 服务维度...')
|
||||
get_text_embedding("dimension_check")
|
||||
util.log(1, f'embedding 服务维度初始化完成: {service.embedding_dim}')
|
||||
|
||||
util.log(1, 'embedding 服务预热完成')
|
||||
except Exception as e:
|
||||
util.log(1, f'embedding 服务预热失败: {str(e)}')
|
||||
|
||||
#开启核心服务
|
||||
util.log(1, '开启核心服务...')
|
||||
feiFei = get_fay_core().FeiFei()
|
||||
|
||||
@@ -12,6 +12,7 @@ from simulation_engine.settings import *
|
||||
from simulation_engine.global_methods import *
|
||||
from simulation_engine.gpt_structure import *
|
||||
from simulation_engine.llm_json_parser import *
|
||||
from utils import util
|
||||
|
||||
|
||||
def run_gpt_generate_importance(
|
||||
@@ -31,7 +32,7 @@ def run_gpt_generate_importance(
|
||||
gpt_response = extract_first_json_dict(gpt_response)
|
||||
# 处理gpt_response为None的情况
|
||||
if gpt_response is None:
|
||||
print("警告: extract_first_json_dict返回None,使用默认值")
|
||||
util.log(2, "警告: extract_first_json_dict返回None,使用默认值")
|
||||
return [50] # 返回默认重要性分数
|
||||
return list(gpt_response.values())
|
||||
|
||||
@@ -172,11 +173,11 @@ def normalize_dict_floats(d, target_min, target_max):
|
||||
"""
|
||||
# 检查字典是否为None或为空
|
||||
if d is None:
|
||||
print("警告: normalize_dict_floats接收到None字典")
|
||||
util.log(2, "警告: normalize_dict_floats接收到None字典")
|
||||
return {}
|
||||
|
||||
if not d:
|
||||
print("警告: normalize_dict_floats接收到空字典")
|
||||
util.log(2, "警告: normalize_dict_floats接收到空字典")
|
||||
return {}
|
||||
|
||||
try:
|
||||
@@ -193,7 +194,7 @@ def normalize_dict_floats(d, target_min, target_max):
|
||||
/ range_val + target_min)
|
||||
return d
|
||||
except Exception as e:
|
||||
print(f"normalize_dict_floats处理字典时出错: {str(e)}")
|
||||
util.log(3, f"normalize_dict_floats处理字典时出错: {str(e)}")
|
||||
# 返回原始字典,避免处理失败
|
||||
return d
|
||||
|
||||
@@ -238,11 +239,11 @@ def extract_recency(seq_nodes):
|
||||
"""
|
||||
# 检查seq_nodes是否为None或为空
|
||||
if seq_nodes is None:
|
||||
print("警告: extract_recency接收到None节点列表")
|
||||
util.log(2, "警告: extract_recency接收到None节点列表")
|
||||
return {}
|
||||
|
||||
if not seq_nodes:
|
||||
print("警告: extract_recency接收到空节点列表")
|
||||
util.log(2, "警告: extract_recency接收到空节点列表")
|
||||
return {}
|
||||
|
||||
try:
|
||||
@@ -250,11 +251,11 @@ def extract_recency(seq_nodes):
|
||||
normalized_timestamps = []
|
||||
for node in seq_nodes:
|
||||
if node is None:
|
||||
print("警告: 节点为None,跳过")
|
||||
util.log(2, "警告: 节点为None,跳过")
|
||||
continue
|
||||
|
||||
if not hasattr(node, 'last_retrieved'):
|
||||
print(f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0")
|
||||
util.log(2, f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0")
|
||||
normalized_timestamps.append(0)
|
||||
continue
|
||||
|
||||
@@ -284,18 +285,18 @@ def extract_recency(seq_nodes):
|
||||
recency_out[node.node_id] = (recency_decay
|
||||
** (max_timestep - last_retrieved))
|
||||
except Exception as e:
|
||||
print(f"计算节点 {node.node_id} 的recency时出错: {str(e)}")
|
||||
util.log(3, f"计算节点 {node.node_id} 的recency时出错: {str(e)}")
|
||||
# 使用默认值
|
||||
recency_out[node.node_id] = 1.0
|
||||
|
||||
return recency_out
|
||||
except Exception as e:
|
||||
print(f"extract_recency处理节点列表时出错: {str(e)}")
|
||||
util.log(3, f"extract_recency处理节点列表时出错: {str(e)}")
|
||||
# 返回一个默认字典
|
||||
return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
|
||||
|
||||
|
||||
def extract_importance(seq_nodes):
|
||||
def extract_importance(seq_nodes):
|
||||
"""
|
||||
Gets the current Persona object and a list of nodes that are in a
|
||||
chronological order, and outputs a dictionary that has the importance score
|
||||
@@ -309,22 +310,22 @@ def extract_importance(seq_nodes):
|
||||
"""
|
||||
# 检查seq_nodes是否为None或为空
|
||||
if seq_nodes is None:
|
||||
print("警告: extract_importance接收到None节点列表")
|
||||
util.log(2, "警告: extract_importance接收到None节点列表")
|
||||
return {}
|
||||
|
||||
if not seq_nodes:
|
||||
print("警告: extract_importance接收到空节点列表")
|
||||
util.log(2, "警告: extract_importance接收到空节点列表")
|
||||
return {}
|
||||
|
||||
try:
|
||||
importance_out = dict()
|
||||
for count, node in enumerate(seq_nodes):
|
||||
if node is None:
|
||||
print("警告: 节点为None,跳过")
|
||||
util.log(2, "警告: 节点为None,跳过")
|
||||
continue
|
||||
|
||||
if not hasattr(node, 'node_id') or not hasattr(node, 'importance'):
|
||||
print(f"警告: 节点缺少必要属性,跳过")
|
||||
util.log(2, f"警告: 节点缺少必要属性,跳过")
|
||||
continue
|
||||
|
||||
# 确保importance是数值类型
|
||||
@@ -333,32 +334,32 @@ def extract_importance(seq_nodes):
|
||||
importance_out[node.node_id] = float(node.importance)
|
||||
except ValueError:
|
||||
# 如果无法转换为数值,使用默认值
|
||||
print(f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值")
|
||||
util.log(2, f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值")
|
||||
importance_out[node.node_id] = 50.0
|
||||
else:
|
||||
importance_out[node.node_id] = node.importance
|
||||
|
||||
return importance_out
|
||||
except Exception as e:
|
||||
print(f"extract_importance处理节点列表时出错: {str(e)}")
|
||||
# 返回一个默认字典
|
||||
return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
|
||||
|
||||
|
||||
def _is_valid_embedding(vec, expected_dim):
|
||||
if vec is None:
|
||||
return False
|
||||
if not isinstance(vec, (list, tuple)):
|
||||
return False
|
||||
if expected_dim is not None and len(vec) != expected_dim:
|
||||
return False
|
||||
for val in vec:
|
||||
if not isinstance(val, (int, float)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def extract_relevance(seq_nodes, embeddings, focal_pt):
|
||||
except Exception as e:
|
||||
util.log(3, f"extract_importance处理节点列表时出错: {str(e)}")
|
||||
# 返回一个默认字典
|
||||
return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')}
|
||||
|
||||
|
||||
def _is_valid_embedding(vec, expected_dim):
|
||||
if vec is None:
|
||||
return False
|
||||
if not isinstance(vec, (list, tuple)):
|
||||
return False
|
||||
if expected_dim is not None and len(vec) != expected_dim:
|
||||
return False
|
||||
for val in vec:
|
||||
if not isinstance(val, (int, float)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def extract_relevance(seq_nodes, embeddings, focal_pt):
|
||||
"""
|
||||
Gets the current Persona object, a list of seq_nodes that are in a
|
||||
chronological order, and the focal_pt string and outputs a dictionary
|
||||
@@ -373,45 +374,39 @@ def extract_relevance(seq_nodes, embeddings, focal_pt):
|
||||
"""
|
||||
# 确保embeddings不为None
|
||||
if embeddings is None:
|
||||
print("警告: embeddings为None,使用空字典代替")
|
||||
util.log(2, "警告: embeddings为None,使用空字典代替")
|
||||
embeddings = {}
|
||||
|
||||
try:
|
||||
focal_embedding = get_text_embedding(focal_pt)
|
||||
except Exception as e:
|
||||
print(f"获取焦点嵌入向量时出错: {str(e)}")
|
||||
# 如果无法获取嵌入向量,返回默认值
|
||||
return {node.node_id: 0.5 for node in seq_nodes}
|
||||
|
||||
expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None
|
||||
relevance_out = dict()
|
||||
for count, node in enumerate(seq_nodes):
|
||||
try:
|
||||
# 检查节点内容是否在embeddings中
|
||||
if node.content in embeddings:
|
||||
node_embedding = embeddings[node.content]
|
||||
if not _is_valid_embedding(node_embedding, expected_dim):
|
||||
try:
|
||||
regenerated = get_text_embedding(node.content)
|
||||
if _is_valid_embedding(regenerated, expected_dim):
|
||||
embeddings[node.content] = regenerated
|
||||
node_embedding = regenerated
|
||||
else:
|
||||
print("Warning: regenerated embedding has unexpected dimension, using default relevance")
|
||||
node_embedding = None
|
||||
except Exception as e:
|
||||
print(f"Regenerate embedding failed: {str(e)}")
|
||||
node_embedding = None
|
||||
# 计算余弦相似度
|
||||
if node_embedding is None:
|
||||
relevance_out[node.node_id] = 0.5
|
||||
else:
|
||||
relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding)
|
||||
else:
|
||||
# 如果没有对应的嵌入向量,使用默认值
|
||||
relevance_out[node.node_id] = 0.5
|
||||
try:
|
||||
focal_embedding = get_text_embedding(focal_pt)
|
||||
except Exception as e:
|
||||
util.log(3, f"获取焦点嵌入向量时出错: {str(e)}")
|
||||
# 如果无法获取嵌入向量,返回默认值
|
||||
return {node.node_id: 0.5 for node in seq_nodes}
|
||||
|
||||
expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None
|
||||
relevance_out = dict()
|
||||
for count, node in enumerate(seq_nodes):
|
||||
try:
|
||||
# 检查节点内容是否在embeddings中
|
||||
if node.content in embeddings:
|
||||
node_embedding = embeddings[node.content]
|
||||
if not _is_valid_embedding(node_embedding, expected_dim):
|
||||
# 维度检查与修复在启动阶段完成,这里不再重算,避免首条消息扣时间
|
||||
current_dim = len(node_embedding) if isinstance(node_embedding, (list, tuple)) else "未知"
|
||||
util.log(2, f"检索时发现维度不一致的embedding: 节点ID={node.node_id}, 内容='{node.content[:30]}...', 当前维度={current_dim}, 期望维度={expected_dim}")
|
||||
util.log(2, f" -> 使用默认相关性分数 0.5 (建议重启系统进行维度修复)")
|
||||
node_embedding = None
|
||||
# 计算余弦相似度
|
||||
if node_embedding is None:
|
||||
relevance_out[node.node_id] = 0.5
|
||||
else:
|
||||
relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding)
|
||||
else:
|
||||
# 如果没有对应的嵌入向量,使用默认值
|
||||
relevance_out[node.node_id] = 0.5
|
||||
except Exception as e:
|
||||
print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}")
|
||||
util.log(3, f"计算节点 {node.node_id} 的相关性时出错: {str(e)}")
|
||||
# 如果计算过程中出错,使用默认值
|
||||
relevance_out[node.node_id] = 0.5
|
||||
|
||||
@@ -425,12 +420,12 @@ def extract_relevance(seq_nodes, embeddings, focal_pt):
|
||||
class ConceptNode:
|
||||
def __init__(self, node_dict):
|
||||
# Loading the content of a memory node in the memory stream.
|
||||
self.node_id = node_dict["node_id"]
|
||||
self.node_type = node_dict["node_type"]
|
||||
self.content = node_dict["content"]
|
||||
self.importance = node_dict["importance"]
|
||||
self.datetime = node_dict.get("datetime", "")
|
||||
# 确保created是整数类型
|
||||
self.node_id = node_dict["node_id"]
|
||||
self.node_type = node_dict["node_type"]
|
||||
self.content = node_dict["content"]
|
||||
self.importance = node_dict["importance"]
|
||||
self.datetime = node_dict.get("datetime", "")
|
||||
# 确保created是整数类型
|
||||
self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0
|
||||
# 确保last_retrieved是整数类型
|
||||
self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0
|
||||
@@ -449,10 +444,10 @@ class ConceptNode:
|
||||
curr_package = {}
|
||||
curr_package["node_id"] = self.node_id
|
||||
curr_package["node_type"] = self.node_type
|
||||
curr_package["content"] = self.content
|
||||
curr_package["importance"] = self.importance
|
||||
curr_package["datetime"] = self.datetime
|
||||
curr_package["created"] = self.created
|
||||
curr_package["content"] = self.content
|
||||
curr_package["importance"] = self.importance
|
||||
curr_package["datetime"] = self.datetime
|
||||
curr_package["created"] = self.created
|
||||
curr_package["last_retrieved"] = self.last_retrieved
|
||||
curr_package["pointer_id"] = self.pointer_id
|
||||
|
||||
@@ -474,6 +469,87 @@ class MemoryStream:
|
||||
self.id_to_node[new_node.node_id] = new_node
|
||||
|
||||
self.embeddings = embeddings
|
||||
self._embedding_dim_checked = False
|
||||
|
||||
|
||||
def precheck_embedding_dimensions(self, force: bool = False):
    """Check stored embeddings against the current dimension and repair them.

    Intended to run once during startup so that the first retrieval does not
    have to regenerate embeddings.

    Parameters:
      force: run the check even if it already ran, and even if
        ``self.embeddings`` was ``None``.
    Returns:
      dict with keys ``checked`` (bool, True only when the scan completed),
      ``expected_dim`` (int or None) and ``fixed`` (number of embeddings that
      were regenerated successfully).
    """
    result = {"checked": False, "expected_dim": None, "fixed": 0}
    if self._embedding_dim_checked and not force:
        return result

    # Make sure embeddings is a dict; with nothing stored there is nothing
    # to verify unless the caller explicitly forces the check.
    if self.embeddings is None:
        self.embeddings = {}
        if not force:
            return result

    # Imported lazily, after the cheap early exits above.
    from utils import util

    try:
        # Prefer the dimension already known to the embedding service so no
        # extra embedding request is issued.
        from utils.api_embedding_service import get_embedding_service
        service = get_embedding_service()
        expected_dim = getattr(service, 'embedding_dim', None)
        if expected_dim is not None:
            util.log(1, f"使用已初始化的embedding服务维度: {expected_dim}")
        else:
            # Only probe the service when it has not recorded a dimension.
            util.log(1, "embedding服务维度未初始化,进行维度检查...")
            sample_embedding = get_text_embedding("dimension_check")
            expected_dim = len(sample_embedding) if isinstance(sample_embedding, (list, tuple)) else None
    except Exception as e:
        util.log(2, f"启动阶段 embedding 维度检查失败: {str(e)}")
        # Mark as checked even on failure so the probe is not retried.
        self._embedding_dim_checked = True
        return result

    if expected_dim is None:
        util.log(2, "无法获取 embedding 维度,跳过维度检查")
        self._embedding_dim_checked = True
        return result

    if self.seq_nodes:
        candidates = (node.content for node in self.seq_nodes if node is not None)
    else:
        candidates = self.embeddings.keys()
    # De-duplicate while keeping order: several nodes may share one content
    # string, and a failed regeneration should not be retried per duplicate.
    contents = list(dict.fromkeys(candidates))

    fixed = 0
    for content in contents:
        if content not in self.embeddings:
            continue
        node_embedding = self.embeddings[content]
        if _is_valid_embedding(node_embedding, expected_dim):
            continue

        # Report the mismatch before attempting the repair.
        preview = content[:30]
        current_dim = len(node_embedding) if isinstance(node_embedding, (list, tuple)) else "未知"
        util.log(2, f"发现维度不一致的embedding: 内容='{preview}...', 当前维度={current_dim}, 期望维度={expected_dim}")

        try:
            util.log(1, f"正在重新生成embedding: '{preview}...'")
            regenerated = get_text_embedding(content)
            if _is_valid_embedding(regenerated, expected_dim):
                self.embeddings[content] = regenerated
                fixed += 1
                util.log(1, f"成功修复embedding维度: '{preview}...' ({current_dim} -> {len(regenerated)})")
            else:
                # Keep the old value; retrieval falls back to a default score.
                regenerated_dim = len(regenerated) if isinstance(regenerated, (list, tuple)) else "未知"
                util.log(2, f"重新生成的embedding维度仍不一致: '{preview}...' (期望={expected_dim}, 实际={regenerated_dim}),已忽略")
        except Exception as e:
            util.log(2, f"重新生成embedding失败: '{preview}...' - {str(e)}")

    if fixed > 0:
        util.log(1, f"启动阶段已修复 {fixed} 条记忆节点 embedding 维度")

    self._embedding_dim_checked = True
    result["checked"] = True
    result["expected_dim"] = expected_dim
    result["fixed"] = fixed
    return result
|
||||
|
||||
|
||||
def count_observations(self):
|
||||
@@ -503,8 +579,8 @@ class MemoryStream:
|
||||
the elemnts of the list are the query sentences.
|
||||
time_step: Current time_step
|
||||
n_count: The number of nodes that we want to retrieve.
|
||||
curr_filter: Filtering the node.type that we want to retrieve.
|
||||
Acceptable values are 'all', 'reflection', 'observation', 'conversation'
|
||||
curr_filter: Filtering the node.type that we want to retrieve.
|
||||
Acceptable values are 'all', 'reflection', 'observation', 'conversation'
|
||||
hp: Hyperparameter for [recency_w, relevance_w, importance_w]
|
||||
verbose: verbose
|
||||
Returns:
|
||||
@@ -517,8 +593,8 @@ class MemoryStream:
|
||||
if len(self.seq_nodes) == 0:
|
||||
return dict()
|
||||
|
||||
# Filtering for the desired node type. curr_filter can be one of the
|
||||
# elements: 'all', 'reflection', 'observation', 'conversation'
|
||||
# Filtering for the desired node type. curr_filter can be one of the
|
||||
# elements: 'all', 'reflection', 'observation', 'conversation'
|
||||
if curr_filter == "all":
|
||||
curr_nodes = self.seq_nodes
|
||||
else:
|
||||
@@ -528,7 +604,7 @@ class MemoryStream:
|
||||
|
||||
# 确保embeddings不为None
|
||||
if self.embeddings is None:
|
||||
print("警告: 在retrieve方法中,embeddings为None,初始化为空字典")
|
||||
util.log(2, "警告: 在retrieve方法中,embeddings为None,初始化为空字典")
|
||||
self.embeddings = {}
|
||||
|
||||
# <retrieved> is the main dictionary that we are returning
|
||||
@@ -587,7 +663,7 @@ class MemoryStream:
|
||||
|
||||
Parameters:
|
||||
time_step: Current time_step
|
||||
node_type: type of node -- it's either reflection, observation, conversation
|
||||
node_type: type of node -- it's either reflection, observation, conversation
|
||||
content: the str content of the memory record
|
||||
importance: int score of the importance score
|
||||
pointer_id: the str of the parent node
|
||||
@@ -597,11 +673,11 @@ class MemoryStream:
|
||||
"""
|
||||
node_dict = dict()
|
||||
node_dict["node_id"] = len(self.seq_nodes)
|
||||
node_dict["node_type"] = node_type
|
||||
node_dict["content"] = content
|
||||
node_dict["importance"] = importance
|
||||
node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")
|
||||
node_dict["created"] = time_step
|
||||
node_dict["node_type"] = node_type
|
||||
node_dict["content"] = content
|
||||
node_dict["importance"] = importance
|
||||
node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S")
|
||||
node_dict["created"] = time_step
|
||||
node_dict["last_retrieved"] = time_step
|
||||
node_dict["pointer_id"] = pointer_id
|
||||
new_node = ConceptNode(node_dict)
|
||||
@@ -616,19 +692,19 @@ class MemoryStream:
|
||||
try:
|
||||
self.embeddings[content] = get_text_embedding(content)
|
||||
except Exception as e:
|
||||
print(f"获取文本嵌入时出错: {str(e)}")
|
||||
util.log(3, f"获取文本嵌入时出错: {str(e)}")
|
||||
# 如果获取嵌入失败,使用空列表代替
|
||||
self.embeddings[content] = []
|
||||
|
||||
|
||||
def remember(self, content, time_step=0):
|
||||
score = generate_importance_score([content])[0]
|
||||
self._add_node(time_step, "observation", content, score, None)
|
||||
|
||||
def remember_conversation(self, content, time_step=0):
|
||||
score = generate_importance_score([content])[0]
|
||||
self._add_node(time_step, "conversation", content, score, None)
|
||||
|
||||
def remember(self, content, time_step=0):
    """Score ``content`` for importance and add it as an "observation" node.

    Parameters:
      content: the text of the memory record.
      time_step: time step recorded on the new node (defaults to 0).
    """
    # The scorer takes a list of records; only one is passed, so take the
    # first score.
    importance = generate_importance_score([content])[0]
    self._add_node(time_step, "observation", content, importance, None)
|
||||
|
||||
def remember_conversation(self, content, time_step=0):
    """Score ``content`` for importance and add it as a "conversation" node.

    Parameters:
      content: the text of the memory record.
      time_step: time step recorded on the new node (defaults to 0).
    """
    # The scorer takes a list of records; only one is passed, so take the
    # first score.
    importance = generate_importance_score([content])[0]
    self._add_node(time_step, "conversation", content, importance, None)
|
||||
|
||||
|
||||
def reflect(self, anchor, reflection_count=5,
|
||||
retrieval_count=120, time_step=0):
|
||||
|
||||
@@ -1517,9 +1517,28 @@ def init_memory_scheduler():
|
||||
global agents
|
||||
|
||||
# 确保agent已经创建
|
||||
if not agents:
|
||||
util.log(1, '创建代理实例...')
|
||||
create_agent()
|
||||
agent = None
|
||||
if not agents:
|
||||
util.log(1, '创建代理实例...')
|
||||
agent = create_agent()
|
||||
else:
|
||||
agent = agents.get("User")
|
||||
if agent is None and len(agents) > 0:
|
||||
agent = next(iter(agents.values()))
|
||||
|
||||
# 启动阶段做一次 embedding 维度检查,避免首条消息时触发
|
||||
try:
|
||||
if agent and agent.memory_stream and hasattr(agent.memory_stream, "precheck_embedding_dimensions"):
|
||||
result = agent.memory_stream.precheck_embedding_dimensions()
|
||||
if result.get("checked"):
|
||||
util.log(
|
||||
1,
|
||||
f"启动阶段记忆 embedding 维度检查完成: dim={result.get('expected_dim')}, 修复={result.get('fixed')}"
|
||||
)
|
||||
else:
|
||||
util.log(1, "启动阶段记忆 embedding 维度检查跳过(无记忆/无embedding)")
|
||||
except Exception as e:
|
||||
util.log(1, f"启动阶段 embedding 维度检查失败: {str(e)}")
|
||||
|
||||
# 设置每天0点保存记忆
|
||||
schedule.every().day.at("00:00").do(save_agent_memory)
|
||||
@@ -1640,8 +1659,28 @@ def create_agent(username=None):
|
||||
agent.scratch = scratch_data
|
||||
|
||||
# 如果memory目录存在且不为空,则加载之前保存的记忆(不包括scratch数据)
|
||||
if is_exist:
|
||||
load_agent_memory(agent, username)
|
||||
if is_exist:
|
||||
load_agent_memory(agent, username)
|
||||
try:
|
||||
if agent.memory_stream and hasattr(agent.memory_stream, "precheck_embedding_dimensions"):
|
||||
result = agent.memory_stream.precheck_embedding_dimensions(force=True)
|
||||
if result.get("checked"):
|
||||
util.log(
|
||||
1,
|
||||
f"启动阶段记忆 embedding 维度检查完成: dim={result.get('expected_dim')}, 修复={result.get('fixed')}"
|
||||
)
|
||||
if result.get("fixed"):
|
||||
try:
|
||||
embeddings_path = os.path.join(memory_dir, "memory_stream", "embeddings.json")
|
||||
with open(embeddings_path, "w", encoding="utf-8") as f:
|
||||
json.dump(agent.memory_stream.embeddings or {}, f, ensure_ascii=False, indent=2)
|
||||
util.log(1, f"启动阶段已写回 embeddings.json (修复={result.get('fixed')})")
|
||||
except Exception as write_err:
|
||||
util.log(1, f"写回 embeddings.json 失败: {str(write_err)}")
|
||||
else:
|
||||
util.log(1, "启动阶段记忆 embedding 维度检查跳过(无记忆/无embedding)")
|
||||
except Exception as e:
|
||||
util.log(1, f"启动阶段 embedding 维度检查失败: {str(e)}")
|
||||
|
||||
# 缓存到字典
|
||||
agents[username] = agent
|
||||
|
||||
@@ -123,6 +123,7 @@ class ApiEmbeddingService:
|
||||
def encode_text(self, text: str) -> List[float]:
|
||||
"""编码单个文本(带重试机制)"""
|
||||
import time
|
||||
import requests.exceptions
|
||||
|
||||
text = _sanitize_text(text)
|
||||
last_error = None
|
||||
@@ -152,7 +153,15 @@ class ApiEmbeddingService:
|
||||
# 首次调用时获取实际维度
|
||||
if self.embedding_dim is None:
|
||||
self.embedding_dim = len(embedding)
|
||||
logger.info(f"动态获取 embedding 维度: {self.embedding_dim},与原记忆节点不一致,将重新生成记忆节点的 embedding维度")
|
||||
logger.info(f"动态获取 embedding 维度: {self.embedding_dim}")
|
||||
else:
|
||||
# 检查维度一致性
|
||||
current_dim = len(embedding)
|
||||
if current_dim != self.embedding_dim:
|
||||
logger.warning(f"⚠️ Embedding维度不一致! 期望={self.embedding_dim}, 实际={current_dim}, 文本='{text_preview}'")
|
||||
logger.warning(f" 建议检查API配置或模型设置")
|
||||
# 更新维度记录
|
||||
self.embedding_dim = current_dim
|
||||
|
||||
logger.info(f"embedding 生成成功")
|
||||
return embedding
|
||||
@@ -168,6 +177,38 @@ class ApiEmbeddingService:
|
||||
logger.error(f"所有重试均失败,文本长度: {len(text)}")
|
||||
raise
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
last_error = e
|
||||
# 网络连接错误,包括DNS解析失败
|
||||
logger.warning(f"网络连接失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
if attempt < self.max_retries:
|
||||
wait_time = 2 ** attempt # 指数退避
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"网络连接持续失败,请检查网络设置和API地址: {self.api_base_url}")
|
||||
raise
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
last_error = e
|
||||
# HTTP错误(4xx, 5xx)
|
||||
status_code = e.response.status_code if e.response else "未知"
|
||||
logger.error(f"HTTP错误 {status_code} (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
|
||||
# 对于客户端错误(4xx),通常不需要重试
|
||||
if e.response and 400 <= e.response.status_code < 500:
|
||||
logger.error(f"客户端错误,不进行重试: {e.response.text}")
|
||||
raise
|
||||
|
||||
# 对于服务器错误(5xx),可以重试
|
||||
if attempt < self.max_retries:
|
||||
wait_time = 2 ** attempt
|
||||
logger.info(f"服务器错误,等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"服务器错误持续发生")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.error(f"文本编码失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
@@ -202,6 +243,28 @@ class ApiEmbeddingService:
|
||||
result = response.json()
|
||||
embeddings = [item['embedding'] for item in result['data']]
|
||||
|
||||
# 检查批量embedding的维度一致性
|
||||
if embeddings:
|
||||
dimensions = [len(emb) for emb in embeddings]
|
||||
unique_dims = set(dimensions)
|
||||
|
||||
if len(unique_dims) > 1:
|
||||
logger.warning(f"⚠️ 批量embedding维度不一致: {dict(zip(range(len(dimensions)), dimensions))}")
|
||||
logger.warning(f" 唯一维度: {unique_dims}")
|
||||
|
||||
# 检查与已知维度的一致性
|
||||
if self.embedding_dim is not None:
|
||||
for i, dim in enumerate(dimensions):
|
||||
if dim != self.embedding_dim:
|
||||
text_preview = texts[i][:30] + "..." if len(texts[i]) > 30 else texts[i]
|
||||
logger.warning(f"⚠️ 文本{i}维度不一致: 期望={self.embedding_dim}, 实际={dim}, 文本='{text_preview}'")
|
||||
else:
|
||||
# 首次批量调用,设置维度
|
||||
if dimensions:
|
||||
self.embedding_dim = dimensions[0]
|
||||
logger.info(f"从批量请求动态获取 embedding 维度: {self.embedding_dim}")
|
||||
|
||||
logger.info(f"批量 embedding 生成成功: {len(embeddings)} 个向量")
|
||||
return embeddings
|
||||
except Exception as e:
|
||||
logger.error(f"批量文本编码失败: {e}")
|
||||
@@ -217,6 +280,29 @@ class ApiEmbeddingService:
|
||||
"service_type": "api"
|
||||
}
|
||||
|
||||
def health_check(self) -> dict:
    """Probe the embedding service by encoding a short test string.

    Returns:
      dict -- on success: ``status`` "healthy", ``model``, ``api_url``,
      ``embedding_dim`` (length of the returned vector) and
      ``test_successful`` True. On any failure, including the service
      returning no vector: ``status`` "unhealthy", ``model``, ``api_url``,
      ``error`` (message text) and ``test_successful`` False.
      This method never raises.
    """
    try:
        # A simple fixed text is enough to exercise the service.
        embedding = self.encode_text("health_check")
        # A missing or empty vector means the service is not usable, so it
        # must not be reported as healthy.
        if not embedding:
            raise ValueError("embedding service returned an empty embedding")
        embedding_dim = len(embedding)
    except Exception as e:
        return {
            "status": "unhealthy",
            "model": self.model_name,
            "api_url": self.api_base_url,
            "error": str(e),
            "test_successful": False
        }

    return {
        "status": "healthy",
        "model": self.model_name,
        "api_url": self.api_base_url,
        "embedding_dim": embedding_dim,
        "test_successful": True
    }
|
||||
|
||||
# 全局实例
|
||||
_global_embedding_service = None
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import codecs
|
||||
from langsmith.schemas import Feedback
|
||||
# from langsmith.schemas import Feedback # 临时注释
|
||||
import requests
|
||||
from configparser import ConfigParser
|
||||
import functools
|
||||
@@ -53,36 +53,58 @@ start_mode = None
|
||||
fay_url = None
|
||||
system_conf_path = None
|
||||
config_json_path = None
|
||||
use_bionic_memory = None
|
||||
|
||||
# Embedding API 配置全局变量
|
||||
embedding_api_model = None
|
||||
embedding_api_base_url = None
|
||||
embedding_api_key = None
|
||||
|
||||
# 避免重复加载配置中心导致日志刷屏
|
||||
_last_loaded_project_id = None
|
||||
_last_loaded_config = None
|
||||
_last_loaded_from_api = False # 表示上次加载来自配置中心(含缓存)
|
||||
_bootstrap_loaded_from_api = False # 无本地配置时启动阶段已从配置中心加载过
|
||||
_warned_public_project_ids = set()
|
||||
use_bionic_memory = None
|
||||
|
||||
# Embedding API 配置全局变量
|
||||
embedding_api_model = None
|
||||
embedding_api_base_url = None
|
||||
embedding_api_key = None
|
||||
|
||||
# 避免重复加载配置中心导致日志刷屏
|
||||
_last_loaded_project_id = None
|
||||
_last_loaded_config = None
|
||||
_last_loaded_from_api = False # 表示上次加载来自配置中心(含缓存)
|
||||
_bootstrap_loaded_from_api = False # 无本地配置时启动阶段已从配置中心加载过
|
||||
_warned_public_config_keys = set()
|
||||
|
||||
# Public config center identifiers (warn users if matched)
|
||||
PUBLIC_CONFIG_PROJECT_ID = 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
|
||||
PUBLIC_CONFIG_BASE_URL = 'http://1.12.69.110:5500'
|
||||
|
||||
|
||||
def _public_config_warn_key():
    """Return a de-duplication key if the public config center is in use.

    Returns ``"base_url:<url>"`` when the configured base URL (trailing
    slashes ignored) equals the public one, otherwise
    ``"project_id:<id>"`` when the configured project id equals the public
    one, otherwise ``None``.
    """
    configured_url = (CONFIG_SERVER.get('BASE_URL') or '').rstrip('/')
    if configured_url == PUBLIC_CONFIG_BASE_URL.rstrip('/'):
        return f"base_url:{configured_url}"

    configured_project = CONFIG_SERVER.get('PROJECT_ID')
    if configured_project == PUBLIC_CONFIG_PROJECT_ID:
        return f"project_id:{configured_project}"

    return None
|
||||
|
||||
|
||||
def _warn_public_config_once():
    """Print the public-config warning, at most once per warn key.

    Does nothing when the public config center is not in use or when the
    warning for the current key has already been printed.
    """
    warn_key = _public_config_warn_key()
    if warn_key and warn_key not in _warned_public_config_keys:
        # Record the key first so the warning is not repeated.
        _warned_public_config_keys.add(warn_key)
        print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m")
|
||||
|
||||
# config server中心配置,system.conf与config.json存在时不会使用配置中心
|
||||
CONFIG_SERVER = {
|
||||
'BASE_URL': 'http://1.12.69.110:5500', # 默认API服务器地址
|
||||
'API_KEY': 'your-api-key-here', # 默认API密钥
|
||||
'PROJECT_ID': 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' # 项目ID,需要在使用前设置
|
||||
}
|
||||
|
||||
def _refresh_config_center():
|
||||
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
|
||||
if env_project_id:
|
||||
CONFIG_SERVER['PROJECT_ID'] = env_project_id
|
||||
|
||||
_refresh_config_center()
|
||||
|
||||
def load_config_from_api(project_id=None):
|
||||
global CONFIG_SERVER
|
||||
CONFIG_SERVER = {
|
||||
'BASE_URL': 'http://1.12.69.110:5500', # 默认API服务器地址
|
||||
'API_KEY': 'your-api-key-here', # 默认API密钥
|
||||
'PROJECT_ID': 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' # 项目ID,需要在使用前设置
|
||||
}
|
||||
|
||||
def _refresh_config_center():
    """Override the config-center project id from the environment.

    When ``FAY_CONFIG_CENTER_ID`` is set to a non-empty value it replaces
    ``CONFIG_SERVER['PROJECT_ID']``; otherwise the current value is kept.
    """
    project_id_override = os.getenv('FAY_CONFIG_CENTER_ID')
    if not project_id_override:
        return
    CONFIG_SERVER['PROJECT_ID'] = project_id_override
|
||||
|
||||
_refresh_config_center()
|
||||
|
||||
def load_config_from_api(project_id=None):
|
||||
global CONFIG_SERVER
|
||||
|
||||
"""
|
||||
从API加载配置
|
||||
@@ -159,15 +181,15 @@ def load_config_from_api(project_id=None):
|
||||
|
||||
return None
|
||||
|
||||
@synchronized
|
||||
def load_config(force_reload=False):
|
||||
"""
|
||||
加载配置文件,如果本地文件不存在则直接使用API加载
|
||||
@synchronized
|
||||
def load_config(force_reload=False):
|
||||
"""
|
||||
加载配置文件,如果本地文件不存在则直接使用API加载
|
||||
|
||||
Returns:
|
||||
包含配置信息的字典
|
||||
"""
|
||||
global config
|
||||
Returns:
|
||||
包含配置信息的字典
|
||||
"""
|
||||
global config
|
||||
global system_config
|
||||
global key_ali_nls_key_id
|
||||
global key_ali_nls_key_secret
|
||||
@@ -200,168 +222,153 @@ def load_config(force_reload=False):
|
||||
global embedding_api_base_url
|
||||
global embedding_api_key
|
||||
|
||||
global CONFIG_SERVER
|
||||
global system_conf_path
|
||||
global config_json_path
|
||||
global _last_loaded_project_id
|
||||
global _last_loaded_config
|
||||
global _last_loaded_from_api
|
||||
global _bootstrap_loaded_from_api
|
||||
|
||||
_refresh_config_center()
|
||||
|
||||
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
|
||||
explicit_config_center = bool(env_project_id)
|
||||
using_config_center = explicit_config_center
|
||||
if (
|
||||
env_project_id
|
||||
and not force_reload
|
||||
and _last_loaded_config is not None
|
||||
and _last_loaded_project_id == env_project_id
|
||||
and _last_loaded_from_api
|
||||
):
|
||||
return _last_loaded_config
|
||||
|
||||
default_system_conf_path = os.path.join(os.getcwd(), 'system.conf')
|
||||
default_config_json_path = os.path.join(os.getcwd(), 'config.json')
|
||||
cache_system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf')
|
||||
cache_config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json')
|
||||
root_system_conf_exists = os.path.exists(default_system_conf_path)
|
||||
root_config_json_exists = os.path.exists(default_config_json_path)
|
||||
root_config_complete = root_system_conf_exists and root_config_json_exists
|
||||
|
||||
# 构建system.conf和config.json的完整路径
|
||||
config_center_fallback = False
|
||||
if using_config_center:
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
if (
|
||||
system_conf_path is None
|
||||
or config_json_path is None
|
||||
or system_conf_path == cache_system_conf_path
|
||||
or config_json_path == cache_config_json_path
|
||||
):
|
||||
system_conf_path = default_system_conf_path
|
||||
config_json_path = default_config_json_path
|
||||
|
||||
if not root_config_complete:
|
||||
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
|
||||
if (not _bootstrap_loaded_from_api) or (not cache_ready):
|
||||
using_config_center = True
|
||||
config_center_fallback = True
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
|
||||
forced_loaded = False
|
||||
loaded_from_api = False
|
||||
api_attempted = False
|
||||
if using_config_center:
|
||||
if explicit_config_center:
|
||||
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
else:
|
||||
util.log(1, f"未检测到本地system.conf或config.json,尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
if config_center_fallback:
|
||||
_bootstrap_loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
forced_loaded = True
|
||||
|
||||
if (
|
||||
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
|
||||
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
|
||||
):
|
||||
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
|
||||
print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m")
|
||||
else:
|
||||
util.log(2, "配置中心加载失败,尝试使用缓存配置")
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
|
||||
# 如果任一本地文件不存在,直接尝试从API加载
|
||||
if (not sys_conf_exists or not config_json_exists) and not forced_loaded:
|
||||
if using_config_center:
|
||||
if not api_attempted:
|
||||
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
if config_center_fallback:
|
||||
_bootstrap_loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
|
||||
if (
|
||||
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
|
||||
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
|
||||
):
|
||||
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
|
||||
print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m")
|
||||
else:
|
||||
# 使用提取的项目ID或全局项目ID
|
||||
util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'和' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在),尝试从API加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
|
||||
if (
|
||||
CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69'
|
||||
and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids
|
||||
):
|
||||
_warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID'])
|
||||
print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m")
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
if using_config_center and (not sys_conf_exists or not config_json_exists):
|
||||
if _last_loaded_config is not None and _last_loaded_from_api:
|
||||
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
|
||||
return _last_loaded_config
|
||||
if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists):
|
||||
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
|
||||
if cache_ready:
|
||||
util.log(2, "配置中心不可用,回退使用缓存配置")
|
||||
using_config_center = False
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
util.log(2, "配置中心不可用,回退使用本地配置文件")
|
||||
using_config_center = False
|
||||
system_conf_path = default_system_conf_path
|
||||
config_json_path = default_config_json_path
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
# 如果本地文件存在,从本地文件加载
|
||||
global CONFIG_SERVER
|
||||
global system_conf_path
|
||||
global config_json_path
|
||||
global _last_loaded_project_id
|
||||
global _last_loaded_config
|
||||
global _last_loaded_from_api
|
||||
global _bootstrap_loaded_from_api
|
||||
|
||||
_refresh_config_center()
|
||||
|
||||
env_project_id = os.getenv('FAY_CONFIG_CENTER_ID')
|
||||
explicit_config_center = bool(env_project_id)
|
||||
using_config_center = explicit_config_center
|
||||
if (
|
||||
env_project_id
|
||||
and not force_reload
|
||||
and _last_loaded_config is not None
|
||||
and _last_loaded_project_id == env_project_id
|
||||
and _last_loaded_from_api
|
||||
):
|
||||
return _last_loaded_config
|
||||
|
||||
default_system_conf_path = os.path.join(os.getcwd(), 'system.conf')
|
||||
default_config_json_path = os.path.join(os.getcwd(), 'config.json')
|
||||
cache_system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf')
|
||||
cache_config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json')
|
||||
root_system_conf_exists = os.path.exists(default_system_conf_path)
|
||||
root_config_json_exists = os.path.exists(default_config_json_path)
|
||||
root_config_complete = root_system_conf_exists and root_config_json_exists
|
||||
|
||||
# 构建system.conf和config.json的完整路径
|
||||
config_center_fallback = False
|
||||
if using_config_center:
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
if (
|
||||
system_conf_path is None
|
||||
or config_json_path is None
|
||||
or system_conf_path == cache_system_conf_path
|
||||
or config_json_path == cache_config_json_path
|
||||
):
|
||||
system_conf_path = default_system_conf_path
|
||||
config_json_path = default_config_json_path
|
||||
|
||||
if not root_config_complete:
|
||||
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
|
||||
if (not _bootstrap_loaded_from_api) or (not cache_ready):
|
||||
using_config_center = True
|
||||
config_center_fallback = True
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
|
||||
forced_loaded = False
|
||||
loaded_from_api = False
|
||||
api_attempted = False
|
||||
if using_config_center:
|
||||
if explicit_config_center:
|
||||
util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
else:
|
||||
util.log(1, f"未检测到本地system.conf或config.json,尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
if config_center_fallback:
|
||||
_bootstrap_loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
forced_loaded = True
|
||||
|
||||
_warn_public_config_once()
|
||||
else:
|
||||
util.log(2, "配置中心加载失败,尝试使用缓存配置")
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
|
||||
# 如果任一本地文件不存在,直接尝试从API加载
|
||||
if (not sys_conf_exists or not config_json_exists) and not forced_loaded:
|
||||
if using_config_center:
|
||||
if not api_attempted:
|
||||
util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
api_attempted = True
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
if config_center_fallback:
|
||||
_bootstrap_loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
|
||||
_warn_public_config_once()
|
||||
else:
|
||||
# 使用提取的项目ID或全局项目ID
|
||||
util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'和' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在),尝试从API加载配置...")
|
||||
api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID'])
|
||||
|
||||
if api_config:
|
||||
util.log(1, "成功从配置中心加载配置")
|
||||
system_config = api_config['system_config']
|
||||
config = api_config['config']
|
||||
loaded_from_api = True
|
||||
|
||||
# 缓存API配置到本地文件
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
save_api_config_to_local(api_config, system_conf_path, config_json_path)
|
||||
|
||||
_warn_public_config_once()
|
||||
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
if using_config_center and (not sys_conf_exists or not config_json_exists):
|
||||
if _last_loaded_config is not None and _last_loaded_from_api:
|
||||
util.log(2, "配置中心缓存不可用,继续使用内存中的配置")
|
||||
return _last_loaded_config
|
||||
if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists):
|
||||
cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path)
|
||||
if cache_ready:
|
||||
util.log(2, "配置中心不可用,回退使用缓存配置")
|
||||
using_config_center = False
|
||||
system_conf_path = cache_system_conf_path
|
||||
config_json_path = cache_config_json_path
|
||||
else:
|
||||
util.log(2, "配置中心不可用,回退使用本地配置文件")
|
||||
using_config_center = False
|
||||
system_conf_path = default_system_conf_path
|
||||
config_json_path = default_config_json_path
|
||||
sys_conf_exists = os.path.exists(system_conf_path)
|
||||
config_json_exists = os.path.exists(config_json_path)
|
||||
# 如果本地文件存在,从本地文件加载
|
||||
# 加载system.conf
|
||||
system_config = ConfigParser()
|
||||
system_config.read(system_conf_path, encoding='UTF-8')
|
||||
@@ -417,9 +424,9 @@ def load_config(force_reload=False):
|
||||
use_bionic_memory = config.get('memory', {}).get('use_bionic_memory', False)
|
||||
|
||||
# 构建配置字典
|
||||
config_dict = {
|
||||
'system_config': system_config,
|
||||
'config': config,
|
||||
config_dict = {
|
||||
'system_config': system_config,
|
||||
'config': config,
|
||||
'ali_nls_key_id': key_ali_nls_key_id,
|
||||
'ali_nls_key_secret': key_ali_nls_key_secret,
|
||||
'ali_nls_app_key': key_ali_nls_app_key,
|
||||
@@ -455,14 +462,14 @@ def load_config(force_reload=False):
|
||||
'embedding_api_base_url': embedding_api_base_url,
|
||||
'embedding_api_key': embedding_api_key,
|
||||
|
||||
'source': 'api' if using_config_center else 'local' # 标记配置来源
|
||||
}
|
||||
|
||||
_last_loaded_project_id = CONFIG_SERVER['PROJECT_ID'] if using_config_center else None
|
||||
_last_loaded_config = config_dict
|
||||
_last_loaded_from_api = using_config_center
|
||||
|
||||
return config_dict
|
||||
'source': 'api' if using_config_center else 'local' # 标记配置来源
|
||||
}
|
||||
|
||||
_last_loaded_project_id = CONFIG_SERVER['PROJECT_ID'] if using_config_center else None
|
||||
_last_loaded_config = config_dict
|
||||
_last_loaded_from_api = using_config_center
|
||||
|
||||
return config_dict
|
||||
|
||||
def save_api_config_to_local(api_config, system_conf_path, config_json_path):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user