From eecbb931a91206b8f4ba085c0778acfe4650a2b3 Mon Sep 17 00:00:00 2001 From: guo zebin Date: Thu, 29 Jan 2026 15:09:37 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BA=A4=E4=BA=92=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. [增加]embedding维度检查自动修正逻辑。 --- fay_booter.py | 20 ++ genagents/modules/memory_stream.py | 280 ++++++++++++------- llm/nlp_cognitive_stream.py | 49 +++- utils/api_embedding_service.py | 88 +++++- utils/config_util.py | 427 +++++++++++++++-------------- 5 files changed, 546 insertions(+), 318 deletions(-) diff --git a/fay_booter.py b/fay_booter.py index ca5df93..1eba21b 100644 --- a/fay_booter.py +++ b/fay_booter.py @@ -369,6 +369,26 @@ def start(): util.log(1, '读取配置...') config_util.load_config() + # 启动阶段预热 embedding 服务(避免首条消息时才初始化维度) + try: + util.log(1, '启动阶段预热 embedding 服务...') + from simulation_engine.gpt_structure import get_text_embedding + + # 检查服务是否已经初始化 + from utils.api_embedding_service import get_embedding_service + service = get_embedding_service() + + if hasattr(service, 'embedding_dim') and service.embedding_dim is not None: + util.log(1, f'embedding 服务已初始化,维度: {service.embedding_dim}') + else: + util.log(1, '初始化 embedding 服务维度...') + get_text_embedding("dimension_check") + util.log(1, f'embedding 服务维度初始化完成: {service.embedding_dim}') + + util.log(1, 'embedding 服务预热完成') + except Exception as e: + util.log(1, f'embedding 服务预热失败: {str(e)}') + #开启核心服务 util.log(1, '开启核心服务...') feiFei = get_fay_core().FeiFei() diff --git a/genagents/modules/memory_stream.py b/genagents/modules/memory_stream.py index 27ca8bd..d3ffb7b 100644 --- a/genagents/modules/memory_stream.py +++ b/genagents/modules/memory_stream.py @@ -12,6 +12,7 @@ from simulation_engine.settings import * from simulation_engine.global_methods import * from simulation_engine.gpt_structure import * from simulation_engine.llm_json_parser import * +from utils import util def run_gpt_generate_importance( @@ -31,7 +32,7 @@ def run_gpt_generate_importance( gpt_response = extract_first_json_dict(gpt_response) # 处理gpt_response为None的情况 if gpt_response is None: - print("警告: extract_first_json_dict返回None,使用默认值") + util.log(2, "警告: extract_first_json_dict返回None,使用默认值") return [50] # 返回默认重要性分数 return list(gpt_response.values()) @@ -172,11 +173,11 @@ def normalize_dict_floats(d, target_min, target_max): """ # 检查字典是否为None或为空 if d is None: - print("警告: normalize_dict_floats接收到None字典") + util.log(2, "警告: normalize_dict_floats接收到None字典") return {} if not d: - print("警告: normalize_dict_floats接收到空字典") + util.log(2, "警告: normalize_dict_floats接收到空字典") return {} try: @@ -193,7 +194,7 @@ def normalize_dict_floats(d, target_min, target_max): / range_val + target_min) return d except Exception as e: - print(f"normalize_dict_floats处理字典时出错: {str(e)}") + util.log(3, f"normalize_dict_floats处理字典时出错: {str(e)}") # 返回原始字典,避免处理失败 return d @@ -238,11 +239,11 @@ def extract_recency(seq_nodes): """ # 检查seq_nodes是否为None或为空 if seq_nodes is None: - print("警告: extract_recency接收到None节点列表") + util.log(2, "警告: extract_recency接收到None节点列表") return {} if not seq_nodes: - print("警告: extract_recency接收到空节点列表") + util.log(2, "警告: extract_recency接收到空节点列表") return {} try: @@ -250,11 +251,11 @@ def extract_recency(seq_nodes): normalized_timestamps = [] for node in seq_nodes: if node is None: - print("警告: 节点为None,跳过") + util.log(2, "警告: 节点为None,跳过") continue if not hasattr(node, 'last_retrieved'): - print(f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0") + util.log(2, f"警告: 节点 {node} 没有last_retrieved属性,使用默认值0") normalized_timestamps.append(0) continue @@ -284,18 +285,18 @@ def extract_recency(seq_nodes): recency_out[node.node_id] = (recency_decay ** (max_timestep - last_retrieved)) except Exception as e: - print(f"计算节点 {node.node_id} 的recency时出错: {str(e)}") + util.log(3, f"计算节点 {node.node_id} 的recency时出错: {str(e)}") # 使用默认值 recency_out[node.node_id] = 1.0 return recency_out except Exception as e: - print(f"extract_recency处理节点列表时出错: {str(e)}") + util.log(3, f"extract_recency处理节点列表时出错: {str(e)}") # 返回一个默认字典 return {node.node_id: 1.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} -def extract_importance(seq_nodes): +def extract_importance(seq_nodes): """ Gets the current Persona object and a list of nodes that are in a chronological order, and outputs a dictionary that has the importance score @@ -309,22 +310,22 @@ def extract_importance(seq_nodes): """ # 检查seq_nodes是否为None或为空 if seq_nodes is None: - print("警告: extract_importance接收到None节点列表") + util.log(2, "警告: extract_importance接收到None节点列表") return {} if not seq_nodes: - print("警告: extract_importance接收到空节点列表") + util.log(2, "警告: extract_importance接收到空节点列表") return {} try: importance_out = dict() for count, node in enumerate(seq_nodes): if node is None: - print("警告: 节点为None,跳过") + util.log(2, "警告: 节点为None,跳过") continue if not hasattr(node, 'node_id') or not hasattr(node, 'importance'): - print(f"警告: 节点缺少必要属性,跳过") + util.log(2, f"警告: 节点缺少必要属性,跳过") continue # 确保importance是数值类型 @@ -333,32 +334,32 @@ def extract_importance(seq_nodes): importance_out[node.node_id] = float(node.importance) except ValueError: # 如果无法转换为数值,使用默认值 - print(f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值") + util.log(2, f"警告: 节点 {node.node_id} 的importance无法转换为数值,使用默认值") importance_out[node.node_id] = 50.0 else: importance_out[node.node_id] = node.importance return importance_out - except Exception as e: - print(f"extract_importance处理节点列表时出错: {str(e)}") - # 返回一个默认字典 - return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} - - -def _is_valid_embedding(vec, expected_dim): - if vec is None: - return False - if not isinstance(vec, (list, tuple)): - return False - if expected_dim is not None and len(vec) != expected_dim: - return False - for val in vec: - if not isinstance(val, (int, float)): - return False - return True - - -def extract_relevance(seq_nodes, embeddings, focal_pt): + except Exception as e: + util.log(3, f"extract_importance处理节点列表时出错: {str(e)}") + # 返回一个默认字典 + return {node.node_id: 50.0 for node in seq_nodes if node is not None and hasattr(node, 'node_id')} + + +def _is_valid_embedding(vec, expected_dim): + if vec is None: + return False + if not isinstance(vec, (list, tuple)): + return False + if expected_dim is not None and len(vec) != expected_dim: + return False + for val in vec: + if not isinstance(val, (int, float)): + return False + return True + + +def extract_relevance(seq_nodes, embeddings, focal_pt): """ Gets the current Persona object, a list of seq_nodes that are in a chronological order, and the focal_pt string and outputs a dictionary @@ -373,45 +374,39 @@ def extract_relevance(seq_nodes, embeddings, focal_pt): """ # 确保embeddings不为None if embeddings is None: - print("警告: embeddings为None,使用空字典代替") + util.log(2, "警告: embeddings为None,使用空字典代替") embeddings = {} - try: - focal_embedding = get_text_embedding(focal_pt) - except Exception as e: - print(f"获取焦点嵌入向量时出错: {str(e)}") - # 如果无法获取嵌入向量,返回默认值 - return {node.node_id: 0.5 for node in seq_nodes} - - expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None - relevance_out = dict() - for count, node in enumerate(seq_nodes): - try: - # 检查节点内容是否在embeddings中 - if node.content in embeddings: - node_embedding = embeddings[node.content] - if not _is_valid_embedding(node_embedding, expected_dim): - try: - regenerated = get_text_embedding(node.content) - if _is_valid_embedding(regenerated, expected_dim): - embeddings[node.content] = regenerated - node_embedding = regenerated - else: - print("Warning: regenerated embedding has unexpected dimension, using default relevance") - node_embedding = None - except Exception as e: - print(f"Regenerate embedding failed: {str(e)}") - node_embedding = None - # 计算余弦相似度 - if node_embedding is None: - relevance_out[node.node_id] = 0.5 - else: - relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding) - else: - # 如果没有对应的嵌入向量,使用默认值 - relevance_out[node.node_id] = 0.5 + try: + focal_embedding = get_text_embedding(focal_pt) + except Exception as e: + util.log(3, f"获取焦点嵌入向量时出错: {str(e)}") + # 如果无法获取嵌入向量,返回默认值 + return {node.node_id: 0.5 for node in seq_nodes} + + expected_dim = len(focal_embedding) if isinstance(focal_embedding, (list, tuple)) else None + relevance_out = dict() + for count, node in enumerate(seq_nodes): + try: + # 检查节点内容是否在embeddings中 + if node.content in embeddings: + node_embedding = embeddings[node.content] + if not _is_valid_embedding(node_embedding, expected_dim): + # 维度检查与修复在启动阶段完成,这里不再重算,避免首条消息扣时间 + current_dim = len(node_embedding) if isinstance(node_embedding, (list, tuple)) else "未知" + util.log(2, f"检索时发现维度不一致的embedding: 节点ID={node.node_id}, 内容='{node.content[:30]}...', 当前维度={current_dim}, 期望维度={expected_dim}") + util.log(2, f" -> 使用默认相关性分数 0.5 (建议重启系统进行维度修复)") + node_embedding = None + # 计算余弦相似度 + if node_embedding is None: + relevance_out[node.node_id] = 0.5 + else: + relevance_out[node.node_id] = cos_sim(node_embedding, focal_embedding) + else: + # 如果没有对应的嵌入向量,使用默认值 + relevance_out[node.node_id] = 0.5 except Exception as e: - print(f"计算节点 {node.node_id} 的相关性时出错: {str(e)}") + util.log(3, f"计算节点 {node.node_id} 的相关性时出错: {str(e)}") # 如果计算过程中出错,使用默认值 relevance_out[node.node_id] = 0.5 @@ -425,12 +420,12 @@ def extract_relevance(seq_nodes, embeddings, focal_pt): class ConceptNode: def __init__(self, node_dict): # Loading the content of a memory node in the memory stream. - self.node_id = node_dict["node_id"] - self.node_type = node_dict["node_type"] - self.content = node_dict["content"] - self.importance = node_dict["importance"] - self.datetime = node_dict.get("datetime", "") - # 确保created是整数类型 + self.node_id = node_dict["node_id"] + self.node_type = node_dict["node_type"] + self.content = node_dict["content"] + self.importance = node_dict["importance"] + self.datetime = node_dict.get("datetime", "") + # 确保created是整数类型 self.created = int(node_dict["created"]) if node_dict["created"] is not None else 0 # 确保last_retrieved是整数类型 self.last_retrieved = int(node_dict["last_retrieved"]) if node_dict["last_retrieved"] is not None else 0 @@ -449,10 +444,10 @@ class ConceptNode: curr_package = {} curr_package["node_id"] = self.node_id curr_package["node_type"] = self.node_type - curr_package["content"] = self.content - curr_package["importance"] = self.importance - curr_package["datetime"] = self.datetime - curr_package["created"] = self.created + curr_package["content"] = self.content + curr_package["importance"] = self.importance + curr_package["datetime"] = self.datetime + curr_package["created"] = self.created curr_package["last_retrieved"] = self.last_retrieved curr_package["pointer_id"] = self.pointer_id @@ -474,6 +469,87 @@ class MemoryStream: self.id_to_node[new_node.node_id] = new_node self.embeddings = embeddings + self._embedding_dim_checked = False + + + def precheck_embedding_dimensions(self, force: bool = False): + """ + 启动阶段检查并修复记忆节点 embedding 维度,避免首条消息检索时重算。 + """ + result = {"checked": False, "expected_dim": None, "fixed": 0} + if self._embedding_dim_checked and not force: + return result + # 确保embeddings不为None + if self.embeddings is None: + self.embeddings = {} + if not force: + return result + + try: + # 首先尝试从已初始化的embedding服务获取维度,避免重复调用 + from utils.api_embedding_service import get_embedding_service + from utils import util + service = get_embedding_service() + + # 如果服务已经有维度信息,直接使用 + if hasattr(service, 'embedding_dim') and service.embedding_dim is not None: + expected_dim = service.embedding_dim + util.log(1, f"使用已初始化的embedding服务维度: {expected_dim}") + else: + # 只有在服务未初始化维度时才调用dimension_check + util.log(1, "embedding服务维度未初始化,进行维度检查...") + sample_embedding = get_text_embedding("dimension_check") + expected_dim = len(sample_embedding) if isinstance(sample_embedding, (list, tuple)) else None + + except Exception as e: + from utils import util + util.log(2, f"启动阶段 embedding 维度检查失败: {str(e)}") + # 即使检查失败,也标记为已检查,避免重复尝试 + self._embedding_dim_checked = True + return result + + if expected_dim is None: + from utils import util + util.log(2, "无法获取 embedding 维度,跳过维度检查") + self._embedding_dim_checked = True + return result + + fixed = 0 + if self.seq_nodes: + contents = [node.content for node in self.seq_nodes if node is not None] + else: + contents = list(self.embeddings.keys()) + + for content in contents: + if content in self.embeddings: + node_embedding = self.embeddings[content] + if not _is_valid_embedding(node_embedding, expected_dim): + # 记录维度不一致的详细信息 + current_dim = len(node_embedding) if isinstance(node_embedding, (list, tuple)) else "未知" + from utils import util + util.log(2, f"发现维度不一致的embedding: 内容='{content[:30]}...', 当前维度={current_dim}, 期望维度={expected_dim}") + + try: + util.log(1, f"正在重新生成embedding: '{content[:30]}...'") + regenerated = get_text_embedding(content) + if regenerated is not None and _is_valid_embedding(regenerated, expected_dim): + self.embeddings[content] = regenerated + fixed += 1 + util.log(1, f"成功修复embedding维度: '{content[:30]}...' ({current_dim} -> {len(regenerated)})") + else: + regenerated_dim = len(regenerated) if isinstance(regenerated, (list, tuple)) else "未知" + util.log(2, f"重新生成的embedding维度仍不一致: '{content[:30]}...' (期望={expected_dim}, 实际={regenerated_dim}),已忽略") + except Exception as e: + util.log(2, f"重新生成embedding失败: '{content[:30]}...' - {str(e)}") + + if fixed > 0: + from utils import util + util.log(1, f"启动阶段已修复 {fixed} 条记忆节点 embedding 维度") + self._embedding_dim_checked = True + result["checked"] = True + result["expected_dim"] = expected_dim + result["fixed"] = fixed + return result def count_observations(self): @@ -503,8 +579,8 @@ class MemoryStream: the elemnts of the list are the query sentences. time_step: Current time_step n_count: The number of nodes that we want to retrieve. - curr_filter: Filtering the node.type that we want to retrieve. - Acceptable values are 'all', 'reflection', 'observation', 'conversation' + curr_filter: Filtering the node.type that we want to retrieve. + Acceptable values are 'all', 'reflection', 'observation', 'conversation' hp: Hyperparameter for [recency_w, relevance_w, importance_w] verbose: verbose Returns: @@ -517,8 +593,8 @@ class MemoryStream: if len(self.seq_nodes) == 0: return dict() - # Filtering for the desired node type. curr_filter can be one of the - # elements: 'all', 'reflection', 'observation', 'conversation' + # Filtering for the desired node type. curr_filter can be one of the + # elements: 'all', 'reflection', 'observation', 'conversation' if curr_filter == "all": curr_nodes = self.seq_nodes else: @@ -528,7 +604,7 @@ class MemoryStream: # 确保embeddings不为None if self.embeddings is None: - print("警告: 在retrieve方法中,embeddings为None,初始化为空字典") + util.log(2, "警告: 在retrieve方法中,embeddings为None,初始化为空字典") self.embeddings = {} # is the main dictionary that we are returning @@ -587,7 +663,7 @@ class MemoryStream: Parameters: time_step: Current time_step - node_type: type of node -- it's either reflection, observation, conversation + node_type: type of node -- it's either reflection, observation, conversation content: the str content of the memory record importance: int score of the importance score pointer_id: the str of the parent node @@ -597,11 +673,11 @@ class MemoryStream: """ node_dict = dict() node_dict["node_id"] = len(self.seq_nodes) - node_dict["node_type"] = node_type - node_dict["content"] = content - node_dict["importance"] = importance - node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S") - node_dict["created"] = time_step + node_dict["node_type"] = node_type + node_dict["content"] = content + node_dict["importance"] = importance + node_dict["datetime"] = datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S") + node_dict["created"] = time_step node_dict["last_retrieved"] = time_step node_dict["pointer_id"] = pointer_id new_node = ConceptNode(node_dict) @@ -616,19 +692,19 @@ class MemoryStream: try: self.embeddings[content] = get_text_embedding(content) except Exception as e: - print(f"获取文本嵌入时出错: {str(e)}") + util.log(3, f"获取文本嵌入时出错: {str(e)}") # 如果获取嵌入失败,使用空列表代替 self.embeddings[content] = [] - def remember(self, content, time_step=0): - score = generate_importance_score([content])[0] - self._add_node(time_step, "observation", content, score, None) - - def remember_conversation(self, content, time_step=0): - score = generate_importance_score([content])[0] - self._add_node(time_step, "conversation", content, score, None) - + def remember(self, content, time_step=0): + score = generate_importance_score([content])[0] + self._add_node(time_step, "observation", content, score, None) + + def remember_conversation(self, content, time_step=0): + score = generate_importance_score([content])[0] + self._add_node(time_step, "conversation", content, score, None) + def reflect(self, anchor, reflection_count=5, retrieval_count=120, time_step=0): diff --git a/llm/nlp_cognitive_stream.py b/llm/nlp_cognitive_stream.py index c41aade..1d512f9 100644 --- a/llm/nlp_cognitive_stream.py +++ b/llm/nlp_cognitive_stream.py @@ -1517,9 +1517,28 @@ def init_memory_scheduler(): global agents # 确保agent已经创建 - if not agents: - util.log(1, '创建代理实例...') - create_agent() + agent = None + if not agents: + util.log(1, '创建代理实例...') + agent = create_agent() + else: + agent = agents.get("User") + if agent is None and len(agents) > 0: + agent = next(iter(agents.values())) + + # 启动阶段做一次 embedding 维度检查,避免首条消息时触发 + try: + if agent and agent.memory_stream and hasattr(agent.memory_stream, "precheck_embedding_dimensions"): + result = agent.memory_stream.precheck_embedding_dimensions() + if result.get("checked"): + util.log( + 1, + f"启动阶段记忆 embedding 维度检查完成: dim={result.get('expected_dim')}, 修复={result.get('fixed')}" + ) + else: + util.log(1, "启动阶段记忆 embedding 维度检查跳过(无记忆/无embedding)") + except Exception as e: + util.log(1, f"启动阶段 embedding 维度检查失败: {str(e)}") # 设置每天0点保存记忆 schedule.every().day.at("00:00").do(save_agent_memory) @@ -1640,8 +1659,28 @@ def create_agent(username=None): agent.scratch = scratch_data # 如果memory目录存在且不为空,则加载之前保存的记忆(不包括scratch数据) - if is_exist: - load_agent_memory(agent, username) + if is_exist: + load_agent_memory(agent, username) + try: + if agent.memory_stream and hasattr(agent.memory_stream, "precheck_embedding_dimensions"): + result = agent.memory_stream.precheck_embedding_dimensions(force=True) + if result.get("checked"): + util.log( + 1, + f"启动阶段记忆 embedding 维度检查完成: dim={result.get('expected_dim')}, 修复={result.get('fixed')}" + ) + if result.get("fixed"): + try: + embeddings_path = os.path.join(memory_dir, "memory_stream", "embeddings.json") + with open(embeddings_path, "w", encoding="utf-8") as f: + json.dump(agent.memory_stream.embeddings or {}, f, ensure_ascii=False, indent=2) + util.log(1, f"启动阶段已写回 embeddings.json (修复={result.get('fixed')})") + except Exception as write_err: + util.log(1, f"写回 embeddings.json 失败: {str(write_err)}") + else: + util.log(1, "启动阶段记忆 embedding 维度检查跳过(无记忆/无embedding)") + except Exception as e: + util.log(1, f"启动阶段 embedding 维度检查失败: {str(e)}") # 缓存到字典 agents[username] = agent diff --git a/utils/api_embedding_service.py b/utils/api_embedding_service.py index 22220c7..3d8e4f2 100644 --- a/utils/api_embedding_service.py +++ b/utils/api_embedding_service.py @@ -123,6 +123,7 @@ class ApiEmbeddingService: def encode_text(self, text: str) -> List[float]: """编码单个文本(带重试机制)""" import time + import requests.exceptions text = _sanitize_text(text) last_error = None @@ -152,7 +153,15 @@ class ApiEmbeddingService: # 首次调用时获取实际维度 if self.embedding_dim is None: self.embedding_dim = len(embedding) - logger.info(f"动态获取 embedding 维度: {self.embedding_dim},与原记忆节点不一致,将重新生成记忆节点的 embedding维度") + logger.info(f"动态获取 embedding 维度: {self.embedding_dim}") + else: + # 检查维度一致性 + current_dim = len(embedding) + if current_dim != self.embedding_dim: + logger.warning(f"⚠️ Embedding维度不一致! 期望={self.embedding_dim}, 实际={current_dim}, 文本='{text_preview}'") + logger.warning(f" 建议检查API配置或模型设置") + # 更新维度记录 + self.embedding_dim = current_dim logger.info(f"embedding 生成成功") return embedding @@ -168,6 +177,38 @@ class ApiEmbeddingService: logger.error(f"所有重试均失败,文本长度: {len(text)}") raise + except requests.exceptions.ConnectionError as e: + last_error = e + # 网络连接错误,包括DNS解析失败 + logger.warning(f"网络连接失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}") + if attempt < self.max_retries: + wait_time = 2 ** attempt # 指数退避 + logger.info(f"等待 {wait_time} 秒后重试...") + time.sleep(wait_time) + else: + logger.error(f"网络连接持续失败,请检查网络设置和API地址: {self.api_base_url}") + raise + + except requests.exceptions.HTTPError as e: + last_error = e + # HTTP错误(4xx, 5xx) + status_code = e.response.status_code if e.response else "未知" + logger.error(f"HTTP错误 {status_code} (尝试 {attempt + 1}/{self.max_retries + 1}): {e}") + + # 对于客户端错误(4xx),通常不需要重试 + if e.response and 400 <= e.response.status_code < 500: + logger.error(f"客户端错误,不进行重试: {e.response.text}") + raise + + # 对于服务器错误(5xx),可以重试 + if attempt < self.max_retries: + wait_time = 2 ** attempt + logger.info(f"服务器错误,等待 {wait_time} 秒后重试...") + time.sleep(wait_time) + else: + logger.error(f"服务器错误持续发生") + raise + except Exception as e: last_error = e logger.error(f"文本编码失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}") @@ -202,6 +243,28 @@ class ApiEmbeddingService: result = response.json() embeddings = [item['embedding'] for item in result['data']] + # 检查批量embedding的维度一致性 + if embeddings: + dimensions = [len(emb) for emb in embeddings] + unique_dims = set(dimensions) + + if len(unique_dims) > 1: + logger.warning(f"⚠️ 批量embedding维度不一致: {dict(zip(range(len(dimensions)), dimensions))}") + logger.warning(f" 唯一维度: {unique_dims}") + + # 检查与已知维度的一致性 + if self.embedding_dim is not None: + for i, dim in enumerate(dimensions): + if dim != self.embedding_dim: + text_preview = texts[i][:30] + "..." if len(texts[i]) > 30 else texts[i] + logger.warning(f"⚠️ 文本{i}维度不一致: 期望={self.embedding_dim}, 实际={dim}, 文本='{text_preview}'") + else: + # 首次批量调用,设置维度 + if dimensions: + self.embedding_dim = dimensions[0] + logger.info(f"从批量请求动态获取 embedding 维度: {self.embedding_dim}") + + logger.info(f"批量 embedding 生成成功: {len(embeddings)} 个向量") return embeddings except Exception as e: logger.error(f"批量文本编码失败: {e}") @@ -217,6 +280,29 @@ class ApiEmbeddingService: "service_type": "api" } + def health_check(self) -> dict: + """健康检查:测试embedding服务是否正常工作""" + try: + # 使用简单文本测试服务 + test_text = "health_check" + embedding = self.encode_text(test_text) + + return { + "status": "healthy", + "model": self.model_name, + "api_url": self.api_base_url, + "embedding_dim": len(embedding) if embedding else None, + "test_successful": True + } + except Exception as e: + return { + "status": "unhealthy", + "model": self.model_name, + "api_url": self.api_base_url, + "error": str(e), + "test_successful": False + } + # 全局实例 _global_embedding_service = None diff --git a/utils/config_util.py b/utils/config_util.py index 27aecfb..d600d61 100644 --- a/utils/config_util.py +++ b/utils/config_util.py @@ -1,7 +1,7 @@ import os import json import codecs -from langsmith.schemas import Feedback +# from langsmith.schemas import Feedback # 临时注释 import requests from configparser import ConfigParser import functools @@ -53,36 +53,58 @@ start_mode = None fay_url = None system_conf_path = None config_json_path = None -use_bionic_memory = None - -# Embedding API 配置全局变量 -embedding_api_model = None -embedding_api_base_url = None -embedding_api_key = None - -# 避免重复加载配置中心导致日志刷屏 -_last_loaded_project_id = None -_last_loaded_config = None -_last_loaded_from_api = False # 表示上次加载来自配置中心(含缓存) -_bootstrap_loaded_from_api = False # 无本地配置时启动阶段已从配置中心加载过 -_warned_public_project_ids = set() +use_bionic_memory = None + +# Embedding API 配置全局变量 +embedding_api_model = None +embedding_api_base_url = None +embedding_api_key = None + +# 避免重复加载配置中心导致日志刷屏 +_last_loaded_project_id = None +_last_loaded_config = None +_last_loaded_from_api = False # 表示上次加载来自配置中心(含缓存) +_bootstrap_loaded_from_api = False # 无本地配置时启动阶段已从配置中心加载过 +_warned_public_config_keys = set() + +# Public config center identifiers (warn users if matched) +PUBLIC_CONFIG_PROJECT_ID = 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' +PUBLIC_CONFIG_BASE_URL = 'http://1.12.69.110:5500' + + +def _public_config_warn_key(): + base_url = (CONFIG_SERVER.get('BASE_URL') or '').rstrip('/') + public_base = PUBLIC_CONFIG_BASE_URL.rstrip('/') + if base_url == public_base: + return f"base_url:{base_url}" + if CONFIG_SERVER.get('PROJECT_ID') == PUBLIC_CONFIG_PROJECT_ID: + return f"project_id:{CONFIG_SERVER.get('PROJECT_ID')}" + return None + + +def _warn_public_config_once(): + key = _public_config_warn_key() + if not key or key in _warned_public_config_keys: + return + _warned_public_config_keys.add(key) + print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m") # config server中心配置,system.conf与config.json存在时不会使用配置中心 -CONFIG_SERVER = { - 'BASE_URL': 'http://1.12.69.110:5500', # 默认API服务器地址 - 'API_KEY': 'your-api-key-here', # 默认API密钥 - 'PROJECT_ID': 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' # 项目ID,需要在使用前设置 -} - -def _refresh_config_center(): - env_project_id = os.getenv('FAY_CONFIG_CENTER_ID') - if env_project_id: - CONFIG_SERVER['PROJECT_ID'] = env_project_id - -_refresh_config_center() - -def load_config_from_api(project_id=None): - global CONFIG_SERVER +CONFIG_SERVER = { + 'BASE_URL': 'http://1.12.69.110:5500', # 默认API服务器地址 + 'API_KEY': 'your-api-key-here', # 默认API密钥 + 'PROJECT_ID': 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' # 项目ID,需要在使用前设置 +} + +def _refresh_config_center(): + env_project_id = os.getenv('FAY_CONFIG_CENTER_ID') + if env_project_id: + CONFIG_SERVER['PROJECT_ID'] = env_project_id + +_refresh_config_center() + +def load_config_from_api(project_id=None): + global CONFIG_SERVER """ 从API加载配置 @@ -159,15 +181,15 @@ def load_config_from_api(project_id=None): return None -@synchronized -def load_config(force_reload=False): - """ - 加载配置文件,如果本地文件不存在则直接使用API加载 +@synchronized +def load_config(force_reload=False): + """ + 加载配置文件,如果本地文件不存在则直接使用API加载 - Returns: - 包含配置信息的字典 - """ - global config + Returns: + 包含配置信息的字典 + """ + global config global system_config global key_ali_nls_key_id global key_ali_nls_key_secret @@ -200,168 +222,153 @@ def load_config(force_reload=False): global embedding_api_base_url global embedding_api_key - global CONFIG_SERVER - global system_conf_path - global config_json_path - global _last_loaded_project_id - global _last_loaded_config - global _last_loaded_from_api - global _bootstrap_loaded_from_api - - _refresh_config_center() - - env_project_id = os.getenv('FAY_CONFIG_CENTER_ID') - explicit_config_center = bool(env_project_id) - using_config_center = explicit_config_center - if ( - env_project_id - and not force_reload - and _last_loaded_config is not None - and _last_loaded_project_id == env_project_id - and _last_loaded_from_api - ): - return _last_loaded_config - - default_system_conf_path = os.path.join(os.getcwd(), 'system.conf') - default_config_json_path = os.path.join(os.getcwd(), 'config.json') - cache_system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf') - cache_config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json') - root_system_conf_exists = os.path.exists(default_system_conf_path) - root_config_json_exists = os.path.exists(default_config_json_path) - root_config_complete = root_system_conf_exists and root_config_json_exists - - # 构建system.conf和config.json的完整路径 - config_center_fallback = False - if using_config_center: - system_conf_path = cache_system_conf_path - config_json_path = cache_config_json_path - else: - if ( - system_conf_path is None - or config_json_path is None - or system_conf_path == cache_system_conf_path - or config_json_path == cache_config_json_path - ): - system_conf_path = default_system_conf_path - config_json_path = default_config_json_path - - if not root_config_complete: - cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path) - if (not _bootstrap_loaded_from_api) or (not cache_ready): - using_config_center = True - config_center_fallback = True - system_conf_path = cache_system_conf_path - config_json_path = cache_config_json_path - else: - system_conf_path = cache_system_conf_path - config_json_path = cache_config_json_path - - forced_loaded = False - loaded_from_api = False - api_attempted = False - if using_config_center: - if explicit_config_center: - util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}") - else: - util.log(1, f"未检测到本地system.conf或config.json,尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}") - api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID']) - api_attempted = True - if api_config: - util.log(1, "成功从配置中心加载配置") - system_config = api_config['system_config'] - config = api_config['config'] - loaded_from_api = True - if config_center_fallback: - _bootstrap_loaded_from_api = True - - # 缓存API配置到本地文件 - system_conf_path = cache_system_conf_path - config_json_path = cache_config_json_path - save_api_config_to_local(api_config, system_conf_path, config_json_path) - forced_loaded = True - - if ( - CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' - and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids - ): - _warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID']) - print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m") - else: - util.log(2, "配置中心加载失败,尝试使用缓存配置") - - sys_conf_exists = os.path.exists(system_conf_path) - config_json_exists = os.path.exists(config_json_path) - - # 如果任一本地文件不存在,直接尝试从API加载 - if (not sys_conf_exists or not config_json_exists) and not forced_loaded: - if using_config_center: - if not api_attempted: - util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...") - api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID']) - api_attempted = True - if api_config: - util.log(1, "成功从配置中心加载配置") - system_config = api_config['system_config'] - config = api_config['config'] - loaded_from_api = True - if config_center_fallback: - _bootstrap_loaded_from_api = True - - # 缓存API配置到本地文件 - system_conf_path = cache_system_conf_path - config_json_path = cache_config_json_path - save_api_config_to_local(api_config, system_conf_path, config_json_path) - - if ( - CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' - and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids - ): - _warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID']) - print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m") - else: - # 使用提取的项目ID或全局项目ID - util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'和' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在),尝试从API加载配置...") - api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID']) - - if api_config: - util.log(1, "成功从配置中心加载配置") - system_config = api_config['system_config'] - config = api_config['config'] - loaded_from_api = True - - # 缓存API配置到本地文件 - system_conf_path = cache_system_conf_path - config_json_path = cache_config_json_path - save_api_config_to_local(api_config, system_conf_path, config_json_path) - - if ( - CONFIG_SERVER['PROJECT_ID'] == 'd19f7b0a-2b8a-4503-8c0d-1a587b90eb69' - and CONFIG_SERVER['PROJECT_ID'] not in _warned_public_project_ids - ): - _warned_public_project_ids.add(CONFIG_SERVER['PROJECT_ID']) - print("\033[1;33;41m警告:你正在使用社区公共配置,请尽快更换!\033[0m") - - sys_conf_exists = os.path.exists(system_conf_path) - config_json_exists = os.path.exists(config_json_path) - if using_config_center and (not sys_conf_exists or not config_json_exists): - if _last_loaded_config is not None and _last_loaded_from_api: - util.log(2, "配置中心缓存不可用,继续使用内存中的配置") - return _last_loaded_config - if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists): - cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path) - if cache_ready: - util.log(2, "配置中心不可用,回退使用缓存配置") - using_config_center = False - system_conf_path = cache_system_conf_path - config_json_path = cache_config_json_path - else: - util.log(2, "配置中心不可用,回退使用本地配置文件") - using_config_center = False - system_conf_path = default_system_conf_path - config_json_path = default_config_json_path - sys_conf_exists = os.path.exists(system_conf_path) - config_json_exists = os.path.exists(config_json_path) - # 如果本地文件存在,从本地文件加载 + global CONFIG_SERVER + global system_conf_path + global config_json_path + global _last_loaded_project_id + global _last_loaded_config + global _last_loaded_from_api + global _bootstrap_loaded_from_api + + _refresh_config_center() + + env_project_id = os.getenv('FAY_CONFIG_CENTER_ID') + explicit_config_center = bool(env_project_id) + using_config_center = explicit_config_center + if ( + env_project_id + and not force_reload + and _last_loaded_config is not None + and _last_loaded_project_id == env_project_id + and _last_loaded_from_api + ): + return _last_loaded_config + + default_system_conf_path = os.path.join(os.getcwd(), 'system.conf') + default_config_json_path = os.path.join(os.getcwd(), 'config.json') + cache_system_conf_path = os.path.join(os.getcwd(), 'cache_data', 'system.conf') + cache_config_json_path = os.path.join(os.getcwd(), 'cache_data', 'config.json') + root_system_conf_exists = os.path.exists(default_system_conf_path) + root_config_json_exists = os.path.exists(default_config_json_path) + root_config_complete = root_system_conf_exists and root_config_json_exists + + # 构建system.conf和config.json的完整路径 + config_center_fallback = False + if using_config_center: + system_conf_path = cache_system_conf_path + config_json_path = cache_config_json_path + else: + if ( + system_conf_path is None + or config_json_path is None + or system_conf_path == cache_system_conf_path + or config_json_path == cache_config_json_path + ): + system_conf_path = default_system_conf_path + config_json_path = default_config_json_path + + if not root_config_complete: + cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path) + if (not _bootstrap_loaded_from_api) or (not cache_ready): + using_config_center = True + config_center_fallback = True + system_conf_path = cache_system_conf_path + config_json_path = cache_config_json_path + else: + system_conf_path = cache_system_conf_path + config_json_path = cache_config_json_path + + forced_loaded = False + loaded_from_api = False + api_attempted = False + if using_config_center: + if explicit_config_center: + util.log(1, f"检测到配置中心参数,优先加载项目配置: {CONFIG_SERVER['PROJECT_ID']}") + else: + util.log(1, f"未检测到本地system.conf或config.json,尝试从配置中心加载配置: {CONFIG_SERVER['PROJECT_ID']}") + api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID']) + api_attempted = True + if api_config: + util.log(1, "成功从配置中心加载配置") + system_config = api_config['system_config'] + config = api_config['config'] + loaded_from_api = True + if config_center_fallback: + _bootstrap_loaded_from_api = True + + # 缓存API配置到本地文件 + system_conf_path = cache_system_conf_path + config_json_path = cache_config_json_path + save_api_config_to_local(api_config, system_conf_path, config_json_path) + forced_loaded = True + + _warn_public_config_once() + else: + util.log(2, "配置中心加载失败,尝试使用缓存配置") + + sys_conf_exists = os.path.exists(system_conf_path) + config_json_exists = os.path.exists(config_json_path) + + # 如果任一本地文件不存在,直接尝试从API加载 + if (not sys_conf_exists or not config_json_exists) and not forced_loaded: + if using_config_center: + if not api_attempted: + util.log(1, "配置中心缓存缺失,尝试从配置中心加载配置...") + api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID']) + api_attempted = True + if api_config: + util.log(1, "成功从配置中心加载配置") + system_config = api_config['system_config'] + config = api_config['config'] + loaded_from_api = True + if config_center_fallback: + _bootstrap_loaded_from_api = True + + # 缓存API配置到本地文件 + system_conf_path = cache_system_conf_path + config_json_path = cache_config_json_path + save_api_config_to_local(api_config, system_conf_path, config_json_path) + + _warn_public_config_once() + else: + # 使用提取的项目ID或全局项目ID + util.log(1, f"本地配置文件不完整({system_conf_path if not sys_conf_exists else ''}{'和' if not sys_conf_exists and not config_json_exists else ''}{config_json_path if not config_json_exists else ''}不存在),尝试从API加载配置...") + api_config = load_config_from_api(CONFIG_SERVER['PROJECT_ID']) + + if api_config: + util.log(1, "成功从配置中心加载配置") + system_config = api_config['system_config'] + config = api_config['config'] + loaded_from_api = True + + # 缓存API配置到本地文件 + system_conf_path = cache_system_conf_path + config_json_path = cache_config_json_path + save_api_config_to_local(api_config, system_conf_path, config_json_path) + + _warn_public_config_once() + + sys_conf_exists = os.path.exists(system_conf_path) + config_json_exists = os.path.exists(config_json_path) + if using_config_center and (not sys_conf_exists or not config_json_exists): + if _last_loaded_config is not None and _last_loaded_from_api: + util.log(2, "配置中心缓存不可用,继续使用内存中的配置") + return _last_loaded_config + if config_center_fallback and using_config_center and (not sys_conf_exists or not config_json_exists): + cache_ready = os.path.exists(cache_system_conf_path) and os.path.exists(cache_config_json_path) + if cache_ready: + util.log(2, "配置中心不可用,回退使用缓存配置") + using_config_center = False + system_conf_path = cache_system_conf_path + config_json_path = cache_config_json_path + else: + util.log(2, "配置中心不可用,回退使用本地配置文件") + using_config_center = False + system_conf_path = default_system_conf_path + config_json_path = default_config_json_path + sys_conf_exists = os.path.exists(system_conf_path) + config_json_exists = os.path.exists(config_json_path) + # 如果本地文件存在,从本地文件加载 # 加载system.conf system_config = ConfigParser() system_config.read(system_conf_path, encoding='UTF-8') @@ -417,9 +424,9 @@ def load_config(force_reload=False): use_bionic_memory = config.get('memory', {}).get('use_bionic_memory', False) # 构建配置字典 - config_dict = { - 'system_config': system_config, - 'config': config, + config_dict = { + 'system_config': system_config, + 'config': config, 'ali_nls_key_id': key_ali_nls_key_id, 'ali_nls_key_secret': key_ali_nls_key_secret, 'ali_nls_app_key': key_ali_nls_app_key, @@ -455,14 +462,14 @@ def load_config(force_reload=False): 'embedding_api_base_url': embedding_api_base_url, 'embedding_api_key': embedding_api_key, - 'source': 'api' if using_config_center else 'local' # 标记配置来源 - } - - _last_loaded_project_id = CONFIG_SERVER['PROJECT_ID'] if using_config_center else None - _last_loaded_config = config_dict - _last_loaded_from_api = using_config_center - - return config_dict + 'source': 'api' if using_config_center else 'local' # 标记配置来源 + } + + _last_loaded_project_id = CONFIG_SERVER['PROJECT_ID'] if using_config_center else None + _last_loaded_config = config_dict + _last_loaded_from_api = using_config_center + + return config_dict def save_api_config_to_local(api_config, system_conf_path, config_json_path): """