Files
Fay/utils/api_embedding_service.py
guo zebin eecbb931a9 交互模块
1. [增加]embedding维度检查自动修正逻辑。
2026-01-29 15:09:37 +08:00

315 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import re
import requests
from typing import List, Optional
import threading
import os
import sys
# Add the project root (the parent of this file's directory) to sys.path so the
# absolute import of ``utils.config_util`` below can be resolved.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# config_util is optional: when it cannot be imported, ApiEmbeddingService falls
# back to the EMBEDDING_API_* environment variables for its configuration.
try:
    import utils.config_util as cfg
    CONFIG_UTIL_AVAILABLE = True
except ImportError as e:
    CONFIG_UTIL_AVAILABLE = False
    cfg = None
# Use the project's shared logging configuration.
from bionicmemory.utils.logging_config import get_logger
logger = get_logger(__name__)
# The import failure is reported only here because the logger did not exist yet
# when the try/except above ran.
if not CONFIG_UTIL_AVAILABLE:
    logger.warning("无法导入 config_util将使用环境变量配置")
def _sanitize_text(text: str) -> str:
if not isinstance(text, str) or not text:
return text
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', text, flags=re.IGNORECASE)
cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
return cleaned
class ApiEmbeddingService:
    """Embedding service that calls an OpenAI-compatible ``/embeddings`` HTTP API.

    Singleton: every ``ApiEmbeddingService()`` call returns the same instance,
    and the configuration is resolved only on the first construction.
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Check, lock, re-check: concurrent first calls still create one instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every construction of the singleton; load config once.
        # If loading raises, the flag stays False so a later construction retries.
        if not self._initialized:
            with self._lock:
                if not self._initialized:
                    self._initialize_config()
                    ApiEmbeddingService._initialized = True

    def _initialize_config(self):
        """Resolve API settings: system.conf (via config_util) first, then env vars.

        Sets ``api_base_url``, ``api_key``, ``model_name``, ``embedding_dim``
        (learned on the first successful call), ``timeout`` and ``max_retries``.

        Raises:
            ValueError: if neither source provides a base URL or an API key.
        """
        try:
            # Prefer values from system.conf.
            api_base_url = None
            api_key = None
            model_name = None

            if CONFIG_UTIL_AVAILABLE and cfg:
                try:
                    # Make sure the configuration has been loaded.
                    if cfg.config is None:
                        cfg.load_config()

                    # Read the embedding settings from config_util (which reuses the LLM config).
                    api_base_url = cfg.embedding_api_base_url
                    api_key = cfg.embedding_api_key
                    model_name = cfg.embedding_api_model

                    logger.info(f"从 system.conf 读取配置:")
                    logger.info(f" - embedding_api_model: {model_name}")
                    logger.info(f" - embedding_api_base_url: {api_base_url}")
                    logger.info(f" - embedding_api_key: {'已配置' if api_key else '未配置'}")
                except Exception as e:
                    logger.warning(f"从 system.conf 读取配置失败: {e}")

            # Validate required settings, falling back to environment variables.
            if not api_base_url:
                api_base_url = os.getenv('EMBEDDING_API_BASE_URL')
                if not api_base_url:
                    error_msg = ("未配置 embedding_api_base_url\n"
                                 "请确保 system.conf 中配置了 gpt_base_url"
                                 "或设置环境变量 EMBEDDING_API_BASE_URL")
                    logger.error(error_msg)
                    raise ValueError(error_msg)
                logger.warning(f"使用环境变量配置: base_url={api_base_url}")

            if not api_key:
                api_key = os.getenv('EMBEDDING_API_KEY')
                if not api_key:
                    error_msg = ("未配置 embedding_api_key\n"
                                 "请确保 system.conf 中配置了 gpt_api_key"
                                 "或设置环境变量 EMBEDDING_API_KEY")
                    logger.error(error_msg)
                    raise ValueError(error_msg)
                logger.warning("使用环境变量配置: api_key")

            if not model_name:
                model_name = os.getenv('EMBEDDING_API_MODEL', 'text-embedding-ada-002')
                logger.warning(f"未配置 embedding_api_model使用默认值: {model_name}")

            # Store the resolved configuration.
            self.api_base_url = api_base_url.rstrip('/')  # drop trailing slashes
            self.api_key = api_key
            self.model_name = model_name
            self.embedding_dim = None  # learned from the first successful response
            self.timeout = 60  # per-request timeout in seconds
            self.max_retries = 2  # retries after the first attempt

            logger.info(f"API Embedding 服务初始化完成")
            logger.info(f"模型: {self.model_name}")
            logger.info(f"API 地址: {self.api_base_url}")

        except Exception as e:
            logger.error(f"API Embedding 服务初始化失败: {e}")
            raise

    def _embeddings_url(self) -> str:
        """Return the full URL of the ``/embeddings`` endpoint."""
        return f"{self.api_base_url}/embeddings"

    def _request_headers(self) -> dict:
        """Return the JSON content-type and bearer-auth headers sent with every request."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def encode_text(self, text: str) -> List[float]:
        """Embed a single text, retrying failures with exponential backoff.

        Up to ``self.max_retries + 1`` attempts are made; the wait before retry
        number k (0-based) is ``2 ** k`` seconds.

        Args:
            text: Text to embed. ``<think>`` blocks are stripped first.

        Returns:
            The embedding vector of the first item in the API response.

        Raises:
            requests.exceptions.HTTPError: immediately for 4xx responses other
                than 429; otherwise after the final attempt fails.
            requests.exceptions.Timeout, requests.exceptions.ConnectionError,
            Exception: re-raised after the final attempt fails.
        """
        import time
        import requests.exceptions

        text = _sanitize_text(text)

        # The request is identical for every attempt, so build it once.
        url = self._embeddings_url()
        headers = self._request_headers()
        payload = {
            "model": self.model_name,
            "input": text,
        }
        text_preview = text[:50] + "..." if len(text) > 50 else text

        for attempt in range(self.max_retries + 1):
            is_last_attempt = attempt >= self.max_retries
            wait_time = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
            try:
                logger.info(f"发送 embedding 请求 (尝试 {attempt + 1}/{self.max_retries + 1}): 文本长度={len(text)}, 预览='{text_preview}'")

                response = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
                response.raise_for_status()
                result = response.json()
                embedding = result['data'][0]['embedding']

                current_dim = len(embedding)
                if self.embedding_dim is None:
                    # First successful call: record the model's dimension.
                    self.embedding_dim = current_dim
                    logger.info(f"动态获取 embedding 维度: {self.embedding_dim}")
                elif current_dim != self.embedding_dim:
                    # The dimension changed between calls: warn, then adopt the new value.
                    logger.warning(f"⚠️ Embedding维度不一致! 期望={self.embedding_dim}, 实际={current_dim}, 文本='{text_preview}'")
                    logger.warning(f" 建议检查API配置或模型设置")
                    self.embedding_dim = current_dim

                logger.info(f"embedding 生成成功")
                return embedding

            except requests.exceptions.Timeout as e:
                logger.warning(f"请求超时 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
                if is_last_attempt:
                    logger.error(f"所有重试均失败,文本长度: {len(text)}")
                    raise
                logger.info(f"等待 {wait_time} 秒后重试...")
                time.sleep(wait_time)

            except requests.exceptions.ConnectionError as e:
                # Network failures, including DNS resolution errors.
                logger.warning(f"网络连接失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
                if is_last_attempt:
                    logger.error(f"网络连接持续失败请检查网络设置和API地址: {self.api_base_url}")
                    raise
                logger.info(f"等待 {wait_time} 秒后重试...")
                time.sleep(wait_time)

            except requests.exceptions.HTTPError as e:
                # A requests.Response is falsy for error statuses (bool(resp) is
                # resp.ok), so it must be compared with None, not truth-tested.
                error_response = e.response
                status_code = error_response.status_code if error_response is not None else "未知"
                logger.error(f"HTTP错误 {status_code} (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")

                # Client errors (4xx) will not succeed on retry. The exception is
                # 429 (rate limited), which is transient.
                if (error_response is not None
                        and 400 <= error_response.status_code < 500
                        and error_response.status_code != 429):
                    logger.error(f"客户端错误,不进行重试: {error_response.text}")
                    raise

                # Server errors (5xx) and 429 are retried.
                if is_last_attempt:
                    logger.error(f"服务器错误持续发生")
                    raise
                logger.info(f"服务器错误,等待 {wait_time} 秒后重试...")
                time.sleep(wait_time)

            except Exception as e:
                logger.error(f"文本编码失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
                if is_last_attempt:
                    raise
                logger.info(f"等待 {wait_time} 秒后重试...")
                time.sleep(wait_time)

    def encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts with a single API request (no retries).

        Uses twice the single-request timeout. Returned vectors are ordered to
        match ``texts`` using each response item's ``index`` field when present.

        Args:
            texts: Texts to embed. ``<think>`` blocks are stripped from each.

        Returns:
            One embedding per returned item; ``[]`` for empty input, without
            making a request.

        Raises:
            Exception: any request or parsing error is logged and re-raised.
        """
        try:
            texts = [_sanitize_text(text) for text in texts]
            if not texts:
                # Nothing to embed; an empty batch is not worth a round trip (and
                # OpenAI-style APIs presumably reject an empty ``input`` list).
                return []

            payload = {
                "model": self.model_name,
                "input": texts,
            }

            # Batch requests get double the normal timeout.
            batch_timeout = self.timeout * 2
            logger.info(f"发送批量 embedding 请求: 文本数={len(texts)}, 超时={batch_timeout}")

            response = requests.post(
                self._embeddings_url(),
                json=payload,
                headers=self._request_headers(),
                timeout=batch_timeout,
            )
            response.raise_for_status()
            result = response.json()

            # Each item carries the ``index`` of the input it belongs to. Sort by it
            # so vectors line up with ``texts``; items without an integer index
            # keep their position in the response.
            ordered_items = sorted(
                enumerate(result['data']),
                key=lambda pair: pair[1]['index'] if isinstance(pair[1].get('index'), int) else pair[0],
            )
            embeddings = [item['embedding'] for _, item in ordered_items]

            if len(embeddings) != len(texts):
                logger.warning(f"⚠️ 返回的embedding数量与文本数量不一致: 文本数={len(texts)}, 向量数={len(embeddings)}")

            # Check that the batch's dimensions are consistent.
            if embeddings:
                dimensions = [len(emb) for emb in embeddings]
                unique_dims = set(dimensions)
                if len(unique_dims) > 1:
                    logger.warning(f"⚠️ 批量embedding维度不一致: {dict(enumerate(dimensions))}")
                    logger.warning(f" 唯一维度: {unique_dims}")

                if self.embedding_dim is not None:
                    # Compare against the dimension already seen. zip() stops at the
                    # shorter list, so a count mismatch cannot raise IndexError here.
                    for i, (dim, text) in enumerate(zip(dimensions, texts)):
                        if dim != self.embedding_dim:
                            text_preview = text[:30] + "..." if len(text) > 30 else text
                            logger.warning(f"⚠️ 文本{i}维度不一致: 期望={self.embedding_dim}, 实际={dim}, 文本='{text_preview}'")
                else:
                    # First call overall: learn the dimension from the batch.
                    self.embedding_dim = dimensions[0]
                    logger.info(f"从批量请求动态获取 embedding 维度: {self.embedding_dim}")

            logger.info(f"批量 embedding 生成成功: {len(embeddings)} 个向量")
            return embeddings

        except Exception as e:
            logger.error(f"批量文本编码失败: {e}")
            raise

    def get_model_info(self) -> dict:
        """Return a summary of the service's configuration and state."""
        return {
            "model_name": self.model_name,
            "embedding_dim": self.embedding_dim,
            "api_base_url": self.api_base_url,
            "initialized": self._initialized,
            "service_type": "api"
        }

    def health_check(self) -> dict:
        """Embed a short probe text and report whether the service responded.

        Never raises: a failure is reported as ``status == "unhealthy"`` with the
        error message.
        """
        try:
            test_text = "health_check"
            embedding = self.encode_text(test_text)
            return {
                "status": "healthy",
                "model": self.model_name,
                "api_url": self.api_base_url,
                "embedding_dim": len(embedding) if embedding else None,
                "test_successful": True
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "model": self.model_name,
                "api_url": self.api_base_url,
                "error": str(e),
                "test_successful": False
            }
# Module-level cache of the shared service; filled on the first call below.
_global_embedding_service = None


def get_embedding_service() -> ApiEmbeddingService:
    """Return the shared ApiEmbeddingService, constructing it on first use."""
    global _global_embedding_service
    if _global_embedding_service is not None:
        return _global_embedding_service
    _global_embedding_service = ApiEmbeddingService()
    return _global_embedding_service