mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
自然进化
1.加入仿生记忆功能。
This commit is contained in:
17
bionicmemory/__init__.py
Normal file
17
bionicmemory/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
BionicMemory - 仿生记忆系统
|
||||
|
||||
基于仿生学原理的AI记忆管理系统,模拟人类大脑的长短期记忆机制,
|
||||
通过科学的遗忘算法和智能的记忆管理,为AI应用提供真正个性化的记忆体验。
|
||||
|
||||
主要特性:
|
||||
- 长短期记忆分层管理
|
||||
- 牛顿冷却遗忘算法
|
||||
- 聚类抑制机制
|
||||
- 上下文增强技术
|
||||
- 多租户安全隔离
|
||||
"""
|
||||
|
||||
__version__ = "2.0.0"
|
||||
__author__ = "BionicMemory Team"
|
||||
__email__ = "contact@bionicmemory.ai"
|
||||
7
bionicmemory/algorithms/__init__.py
Normal file
7
bionicmemory/algorithms/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
算法模块
|
||||
|
||||
包含仿生记忆系统的核心算法:
|
||||
- 牛顿冷却遗忘算法
|
||||
- 聚类抑制机制
|
||||
"""
|
||||
147
bionicmemory/algorithms/clustering_suppression.py
Normal file
147
bionicmemory/algorithms/clustering_suppression.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
基于聚类的记忆抑制机制
|
||||
实现从短期记忆中加载数倍目标条数,进行k-means聚类,每簇取最相似的代表
|
||||
|
||||
主要思路:
|
||||
1. 从短期记忆中加载数倍(t:聚类平均条数)目标所需条数(k*n:从n倍的检索结果中取topk)的相关记录(含embedding),总检索条数=t*k*n
|
||||
2. 对结果根据embedding进行k-means聚类,簇数为k*n(同条数/t)
|
||||
3. 每簇取与检索最相似的代表当前簇,返回k个簇代表作为最终结果
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from sklearn.cluster import KMeans
|
||||
from typing import List, Dict, Tuple
|
||||
import logging
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class ClusteringSuppression:
    """
    Clustering-based suppression of near-duplicate memories.

    Groups similar memory records via k-means and keeps only the most
    query-relevant representative of each group, so retrieval results are
    diverse rather than dominated by one topic.
    """

    def __init__(self,
                 cluster_multiplier: int = 3,
                 retrieval_multiplier: int = 2):
        """
        Initialise the suppression mechanism.

        Args:
            cluster_multiplier: expected number of records per cluster (default 3).
            retrieval_multiplier: over-retrieval factor (default 2).
        """
        self.cluster_multiplier = cluster_multiplier
        self.retrieval_multiplier = retrieval_multiplier
        logger.info(f"聚类抑制机制初始化: 每簇期望记录数={cluster_multiplier}, 检索倍数={retrieval_multiplier}")

    def calculate_retrieval_parameters(self, target_k: int) -> Tuple[int, int]:
        """
        Compute how many records to retrieve and how many clusters to build.

        Args:
            target_k: number of results the caller ultimately wants.

        Returns:
            (total number of records to retrieve, number of clusters)
        """
        # clusters = target size * over-retrieval factor
        cluster_count = target_k * self.retrieval_multiplier
        # total records = clusters * expected records per cluster
        total_retrieval = cluster_count * self.cluster_multiplier
        return total_retrieval, cluster_count

    @staticmethod
    def _merge_top_k(candidates: List[Dict], target_k: int) -> List[Dict]:
        """Union of top-k by relevance and top-k by valid_access_count, de-duplicated by doc_id.

        This logic was previously duplicated in both branches of
        cluster_by_query_similarity_and_aggregate.
        """
        by_relevance = sorted(candidates, key=lambda x: float(x.get("distance", float("inf"))))[:target_k]
        by_count = sorted(candidates, key=lambda x: float(x.get("valid_access_count", 0.0)), reverse=True)[:target_k]
        seen = set()
        merged = []
        for record in by_relevance + by_count:
            rid = record.get("doc_id")
            if rid not in seen:
                seen.add(rid)
                merged.append(record)
        return merged

    def cluster_by_query_similarity_and_aggregate(self,
                                                  records: List[Dict],
                                                  embeddings_array: np.ndarray,
                                                  distances: List[float],
                                                  cluster_count: int,
                                                  target_k: int) -> List[Dict]:
        """
        Cluster candidates by embedding and aggregate per-cluster statistics.

        - Each cluster's representative is the record with the smallest
          query distance.
        - The representative's valid_access_count becomes the sum over the
          whole cluster.
        - Final result: top target_k by relevance plus top target_k by
          valid_access_count, de-duplicated by doc_id (so the result may
          contain up to 2 * target_k records).

        Args:
            records: records aligned index-for-index with embeddings_array and
                distances (each containing embedding, distance, valid_access_count).
            embeddings_array: array of shape (N, D).
            distances: N query distances (smaller = more similar).
            cluster_count: number of k-means clusters.
            target_k: per-criterion top-k size.

        Returns:
            List[Dict]: selected representative records.
        """
        # Fix: the original re-imported numpy and sklearn locally, shadowing
        # the module-level imports; the redundant imports were removed.
        if not isinstance(cluster_count, int) or cluster_count < 1:
            cluster_count = 1

        n = len(records)
        if n == 0:
            return []

        # Too few samples to cluster meaningfully: run the two-way top-k
        # directly over the raw records (each is its own cluster of size 1).
        if n <= cluster_count:
            base = []
            for i in range(n):
                rep = dict(records[i])
                rep["cluster_size"] = 1
                base.append(rep)
            return self._merge_top_k(base, target_k)

        # K-means over the embeddings; fixed seed for reproducibility.
        kmeans = KMeans(n_clusters=cluster_count, random_state=42, n_init=10)
        labels = kmeans.fit_predict(embeddings_array)

        # Pick each cluster's representative and aggregate its access count.
        representatives = []
        for cid in np.unique(labels):
            idx = np.where(labels == cid)[0]
            if len(idx) == 0:
                continue

            # Representative: smallest distance to the query within the cluster.
            local_dist = [(i, float(distances[i]) if distances[i] is not None else float("inf")) for i in idx]
            rep_idx, _ = min(local_dist, key=lambda t: t[1])

            # Aggregate valid_access_count across the cluster.
            sum_valid = float(sum(float(records[i].get("valid_access_count", 0.0)) for i in idx))

            rep = dict(records[rep_idx])
            rep["valid_access_count"] = sum_valid
            rep["cluster_size"] = len(idx)
            representatives.append(rep)

        return self._merge_top_k(representatives, target_k)
|
||||
48
bionicmemory/algorithms/newton_cooling_helper.py
Normal file
48
bionicmemory/algorithms/newton_cooling_helper.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import math
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
class CoolingRate(Enum):
    """Preset (remaining-ratio, interval-in-seconds) pairs for Newton cooling.

    Each member stores the fraction of the initial "temperature" that should
    remain after the given interval; the figures appear to follow classic
    forgetting-curve retention rates — TODO confirm the source.
    """

    MINUTES_20 = (0.582, 20 * 60)
    HOURS_1 = (0.442, 1 * 60 * 60)
    HOURS_9 = (0.358, 9 * 60 * 60)
    DAYS_1 = (0.337, 1 * 24 * 60 * 60)
    DAYS_2 = (0.278, 2 * 24 * 60 * 60)
    DAYS_6 = (0.254, 6 * 24 * 60 * 60)
    DAYS_31 = (0.211, 31 * 24 * 60 * 60)
|
||||
|
||||
class NewtonCoolingHelper:
    """Static helpers implementing Newton's law of cooling for memory decay.

    T(t) = T0 * exp(-alpha * t); the alpha coefficient is derived from a
    CoolingRate preset.  (Sibling-class annotations use string literals so
    the class does not hard-fail if CoolingRate moves; behaviour unchanged.)
    """

    @staticmethod
    def calculate_cooling_rate(enum_value: "CoolingRate") -> float:
        """
        Derive the cooling coefficient (alpha) from a CoolingRate preset.

        Solves T(t) = T0 * exp(-alpha * t) for alpha given the preset's
        (final_temperature_ratio, time_interval) pair.
        """
        final_temperature_ratio, time_interval = enum_value.value
        return -math.log(final_temperature_ratio) / time_interval

    @staticmethod
    def calculate_newton_cooling_effect(initial_temperature: float,
                                        time_interval: float,
                                        cooling_rate: float = None) -> float:
        """
        Temperature after `time_interval` seconds of Newton cooling.

        Args:
            initial_temperature: starting value T0.
            time_interval: elapsed time in seconds.
            cooling_rate: precomputed alpha; defaults to the DAYS_31 preset.
        """
        if cooling_rate is None:
            cooling_rate = NewtonCoolingHelper.calculate_cooling_rate(CoolingRate.DAYS_31)
        return initial_temperature * math.exp(-cooling_rate * time_interval)

    @staticmethod
    def calculate_time_difference(update_time: datetime, current_time: datetime) -> float:
        """
        Seconds between `update_time` and `current_time`.

        Accepts datetime objects or ISO-format strings for either argument.
        """
        if isinstance(update_time, str):
            update_time = datetime.fromisoformat(update_time)
        if isinstance(current_time, str):
            current_time = datetime.fromisoformat(current_time)
        return (current_time - update_time).total_seconds()

    @staticmethod
    def get_threshold(cooling_rate: "CoolingRate" = None) -> float:
        """Remaining-temperature ratio of the given preset (default DAYS_31)."""
        if cooling_rate is None:
            cooling_rate = CoolingRate.DAYS_31
        return cooling_rate.value[0]
|
||||
6
bionicmemory/api/__init__.py
Normal file
6
bionicmemory/api/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
API模块
|
||||
|
||||
包含仿生记忆系统的API接口:
|
||||
- FastAPI代理服务器
|
||||
"""
|
||||
511
bionicmemory/api/proxy_server.py
Normal file
511
bionicmemory/api/proxy_server.py
Normal file
@@ -0,0 +1,511 @@
|
||||
"""
|
||||
基于OpenAI官方库的代理服务器
|
||||
使用OpenAI官方客户端处理所有请求,确保完全兼容
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from fastapi import FastAPI, Request, Response, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
import uvicorn
|
||||
|
||||
# OpenAI官方库
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.embedding import Embedding
|
||||
|
||||
# BionicMemory核心组件
|
||||
from bionicmemory.core.memory_system import LongShortTermMemorySystem, SourceType
|
||||
from bionicmemory.services.memory_cleanup_scheduler import MemoryCleanupScheduler
|
||||
from bionicmemory.core.chroma_service import ChromaService
|
||||
from bionicmemory.algorithms.newton_cooling_helper import CoolingRate
|
||||
from bionicmemory.services.local_embedding_service import get_embedding_service
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ========== 环境变量配置 ==========
|
||||
# 禁用ChromaDB遥测
|
||||
os.environ["ANONYMIZED_TELEMETRY"] = "False"
|
||||
|
||||
API_HOST = os.getenv("API_HOST", "0.0.0.0")
|
||||
API_PORT = int(os.getenv("API_PORT", "8000"))
|
||||
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
|
||||
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8001"))
|
||||
CHROMA_CLIENT_TYPE = os.getenv("CHROMA_CLIENT_TYPE", "persistent")
|
||||
|
||||
# ========== OpenAI配置 ==========
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://api.deepseek.com")
|
||||
OPENAI_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "deepseek-chat")
|
||||
|
||||
# 记忆系统配置
|
||||
SUMMARY_MAX_LENGTH = int(os.getenv('SUMMARY_MAX_LENGTH', '500'))
|
||||
MAX_RETRIEVAL_RESULTS = int(os.getenv('MAX_RETRIEVAL_RESULTS', '7'))
|
||||
CLUSTER_MULTIPLIER = int(os.getenv('CLUSTER_MULTIPLIER', '3'))
|
||||
RETRIEVAL_MULTIPLIER = int(os.getenv('RETRIEVAL_MULTIPLIER', '2'))
|
||||
|
||||
# ========== 工具函数 ==========
|
||||
|
||||
def extract_user_message(messages: List[Dict]) -> Optional[str]:
    """Return the content of the most recent user message, or None if absent."""
    # Walk backwards so the newest user turn wins.
    return next(
        (msg.get("content", "") for msg in reversed(messages) if msg.get("role") == "user"),
        None,
    )
|
||||
|
||||
def extract_user_id_from_request(body_data: Dict) -> str:
    """Resolve the caller's user ID from an OpenAI-style request body.

    Falls back to "default_user" when the `user` field is missing, blank,
    not a string, or when anything goes wrong.
    """
    try:
        logger.info("🔍 开始提取用户ID...")

        # Prefer the chat protocol's optional `user` field.
        raw_user = body_data.get("user")
        if isinstance(raw_user, str) and raw_user.strip():
            user_id = raw_user.strip()
            logger.info(f"✅ 使用对话协议user字段: {user_id}")
            return user_id

        # Otherwise use the shared default identity.
        user_id = "default_user"
        logger.info(f"✅ 使用默认用户ID: {user_id}")
        return user_id

    except Exception as e:
        logger.error(f"❌ 提取用户ID失败: {e}")
        return "default_user"
|
||||
|
||||
def enhance_chat_with_memory(body_data: Dict, user_id: str) -> Tuple[Dict, List[float]]:
    """
    Enrich a chat request with retrieved memories.

    Args:
        body_data: parsed request body (mutated in place when enhanced).
        user_id: tenant/user identifier used for memory lookup.

    Returns:
        (possibly-enhanced body_data, query embedding or None)
    """
    global memory_system

    if not memory_system:
        logger.warning("⚠️ 记忆系统未初始化,跳过记忆增强")
        return body_data, None

    try:
        messages = body_data.get("messages", [])
        if not messages:
            return body_data, None

        # The latest user turn drives the memory retrieval.
        user_message = extract_user_message(messages)
        if not user_message:
            return body_data, None

        # Memory system retrieves related records and builds the prompt.
        short_term_records, system_prompt, query_embedding = memory_system.process_user_message(
            user_message, user_id
        )

        if short_term_records:
            logger.info(f"🧠 找到 {len(short_term_records)} 条相关记忆")
            logger.info(f"🧠 生成的系统提示语长度: {len(system_prompt)}")

            # Use the memory system's generated prompt as a system message.
            system_message = {
                "role": "system",
                "content": system_prompt
            }

            # Prepend the system message.  NOTE(review): only the last 3
            # original messages are kept, which drops any caller-supplied
            # system prompt and older history — confirm this is intended.
            enhanced_messages = [system_message] + (messages[-3:] if len(messages) > 3 else messages)
            body_data["messages"] = enhanced_messages

            logger.info(f"🧠 记忆增强完成,消息数量: {len(messages)} -> {len(enhanced_messages)}")
            logger.info(f"🧠 记忆增强完成,消息内容: {enhanced_messages}")

        return body_data, query_embedding

    except Exception as e:
        logger.error(f"❌ 记忆增强失败: {e}")
        return body_data, None
|
||||
|
||||
async def process_ai_reply_async(response_content: str, user_id: str, current_user_content: str = None):
    """Persist an AI reply into memory without blocking the response path."""
    global memory_system

    if not memory_system:
        return

    try:
        # Hand the finished turn to the memory system (reply plus the user
        # message that prompted it).
        await memory_system.process_agent_reply_async(response_content, user_id, current_user_content)
    except Exception as exc:
        logger.error(f"❌ 异步处理AI回复失败: {exc}")
|
||||
|
||||
# ========== Module-level singletons (populated by the initialisers below) ==========
memory_system = None
memory_cleanup_scheduler = None
chroma_service = None

# OpenAI clients: sync for non-streaming calls, async for streaming.
openai_client = None
async_openai_client = None

# ========== Initialisation functions ==========
|
||||
|
||||
def initialize_memory_system():
    """Build the global memory system, wipe short-term memory, start cleanup.

    Returns:
        bool: True on success, False on any failure (details are logged).
    """
    global memory_system, memory_cleanup_scheduler, chroma_service

    try:
        logger.info("正在初始化记忆系统...")

        # Vector store (local-embedding mode only).
        chroma_service = ChromaService()
        logger.info("ChromaDB服务初始化完成(本地embedding模式)")

        # Long/short-term memory system wired to the vector store.
        memory_system = LongShortTermMemorySystem(
            chroma_service=chroma_service,
            summary_threshold=SUMMARY_MAX_LENGTH,
            max_retrieval_results=MAX_RETRIEVAL_RESULTS,
            cluster_multiplier=CLUSTER_MULTIPLIER,
            retrieval_multiplier=RETRIEVAL_MULTIPLIER,
        )

        # Short-term memory does not survive restarts: clear it on startup.
        try:
            short_term_deleted_ids = chroma_service.delete_documents(
                memory_system.short_term_collection_name
            )
            logger.info(f"启动清空短期记忆库,删除 {len(short_term_deleted_ids)} 条记录")

        except Exception as _e:
            # Best-effort: a failed wipe must not block startup.
            logger.warning("启动清空短期记忆库失败", exc_info=True)

        # Periodic forgetting/cleanup job.
        memory_cleanup_scheduler = MemoryCleanupScheduler(memory_system=memory_system)
        memory_cleanup_scheduler.start()

        logger.info("记忆系统初始化完成")
        return True
    except Exception as e:
        logger.error(f"记忆系统初始化失败: {str(e)}", exc_info=True)
        return False
|
||||
|
||||
def initialize_openai_clients():
    """Create the global sync and async OpenAI clients.

    Returns:
        bool: True on success, False on failure (error is logged).
    """
    global openai_client, async_openai_client

    try:
        logger.info("正在初始化OpenAI客户端...")

        # Sync client serves non-streaming calls, async client streaming;
        # both target the same upstream.
        openai_client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)
        async_openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_BASE)

        logger.info("OpenAI客户端初始化完成")
        return True
    except Exception as exc:
        logger.error(f"OpenAI客户端初始化失败: {exc}")
        return False
|
||||
|
||||
# ========== Lifespan handlers ==========
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: bring services up on start, tear down on exit."""
    # Startup.
    initialize_memory_system()
    initialize_openai_clients()

    yield

    # Shutdown.
    if memory_cleanup_scheduler:
        memory_cleanup_scheduler.stop()
        logger.info("记忆清理调度器已停止")
|
||||
|
||||
# ========== FastAPI application ==========
app = FastAPI(title="BionicMemory OpenAI Proxy", version="2.0.0", lifespan=lifespan)

# CORS middleware.  NOTE(review): wide-open origins combined with
# allow_credentials=True — tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# ========== Health-check endpoint ==========
@app.get("/health")
async def health_check():
    """Report liveness plus the initialisation state of each subsystem."""
    status_report = {
        "status": "healthy",
        "service": "BionicMemory OpenAI Proxy",
        "timestamp": datetime.now().isoformat(),
        "memory_system_initialized": memory_system is not None,
        "openai_client_initialized": openai_client is not None,
        "cleanup_scheduler_running": memory_cleanup_scheduler is not None if memory_cleanup_scheduler else False,
    }
    return status_report
|
||||
|
||||
# ========== Main routing ==========
@app.api_route("/v1/{path:path}", methods=["POST", "GET"])
async def proxy(request: Request, path: str):
    """
    Proxy every /v1/* request.

    Dispatches to the embedding, chat or passthrough handler; all upstream
    calls go through the official OpenAI client for full compatibility.
    """
    body = await request.body()

    logger.info(f"📥 收到请求: {request.method} /v1/{path}")

    # Embeddings are served by the local embedding model.
    if path.startswith("embeddings"):
        return await handle_embedding_request(request, path, body)

    # Chat completions receive memory enhancement before hitting the LLM.
    if path == "chat/completions":
        return await handle_chat_request(request, path, body)

    # Everything else is passed straight through.
    return await handle_other_request(request, path, body)
|
||||
|
||||
# ========== Request handlers ==========

async def handle_embedding_request(request: Request, path: str, body: bytes):
    """Serve /v1/embeddings with the local embedding service (OpenAI-compatible).

    Returns a JSONResponse shaped like the official embeddings reply; errors
    yield a 4xx/5xx JSON error payload.
    """
    try:
        if not body:
            # Fix: an empty body previously fell through to an unbound
            # `response_data` (NameError -> opaque 500); reject it explicitly.
            return JSONResponse(
                status_code=400,
                content={"error": "请求体不能为空"}
            )

        body_data = json.loads(body)
        input_text = body_data.get("input", "")
        model = body_data.get("model", "")

        # NOTE(review): the OpenAI API also allows `input` to be a list of
        # strings; only the single-string form is handled here — confirm.
        logger.info("使用本地embedding服务")
        embedding_service = get_embedding_service()
        embeddings = embedding_service.get_embeddings([input_text])

        # Shape the reply like the official OpenAI embeddings response.
        response_data = {
            "object": "list",
            "data": [{
                "object": "embedding",
                "index": 0,
                "embedding": embeddings[0]
            }],
            "model": model,
            "usage": {
                # Rough token estimate: whitespace word count.
                "prompt_tokens": len(input_text.split()),
                "total_tokens": len(input_text.split())
            }
        }

        return JSONResponse(content=response_data)

    except Exception as e:
        logger.error(f"❌ 处理embedding请求失败: {e}")
        return JSONResponse(
            status_code=500,
            content={"error": f"处理embedding请求失败: {str(e)}"}
        )
|
||||
|
||||
async def handle_chat_request(request: Request, path: str, body: bytes):
    """
    Serve /v1/chat/completions.

    Memory-enhances the prompt, forwards the request via the OpenAI client
    (streaming or non-streaming), and persists the reply asynchronously.
    """
    try:
        body_data = None
        user_id = None
        enhanced_query_embedding = None
        current_user_content = None

        if body:
            body_data = json.loads(body)
            # Identify the tenant for memory isolation.
            user_id = extract_user_id_from_request(body_data)

            # Force the configured upstream model name.
            if "model" in body_data:
                body_data["model"] = OPENAI_MODEL_NAME

            # Inject retrieved memories as a system prompt.
            enhanced_body_data, query_embedding = enhance_chat_with_memory(body_data, user_id)
            # NOTE(review): assumes `messages` is non-empty — an empty list
            # raises IndexError here (caught below as a 500). Confirm.
            current_user_content = body_data.get("messages", [])[-1].get("content", "")
            body_data = enhanced_body_data

        # Streaming vs. plain response.
        is_stream = body_data and body_data.get("stream", False) if body_data else False

        if is_stream:
            # Streaming path: async OpenAI client re-emitted as SSE.
            logger.info("🌊 处理流式响应(使用OpenAI客户端)")

            try:
                # Forward any extra OpenAI parameters untouched.
                stream = await async_openai_client.chat.completions.create(
                    model=body_data.get("model", OPENAI_MODEL_NAME),
                    messages=body_data.get("messages", []),
                    stream=True,
                    **{k: v for k, v in body_data.items()
                       if k not in ["model", "messages", "stream"]}
                )

                async def openai_stream_wrapper():
                    # Re-emit upstream chunks while accumulating the full
                    # reply text for memory storage.
                    full_content = ""
                    async for chunk in stream:
                        # Keep OpenAI's native chunk structure.
                        chunk_data = chunk.model_dump()
                        content = chunk_data.get('choices', [{}])[0].get('delta', {}).get('content', '')
                        if content:
                            full_content += content

                        # Serialise as a server-sent event.
                        yield f"data: {json.dumps(chunk_data)}\n\n"

                    # After the stream ends, store the reply off-path.
                    if full_content and body_data:
                        asyncio.create_task(process_ai_reply_async(
                            full_content, user_id, current_user_content
                        ))

                    yield "data: [DONE]\n\n"

                return StreamingResponse(
                    openai_stream_wrapper(),
                    status_code=200,
                    headers={
                        "Content-Type": "text/plain; charset=utf-8",
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive"
                    }
                )

            except Exception as e:
                logger.error(f"❌ OpenAI流式处理失败: {e}")
                return JSONResponse(
                    status_code=500,
                    content={"error": f"流式处理失败: {str(e)}"}
                )
        else:
            # Non-streaming path: synchronous OpenAI client.
            logger.info("📝 处理非流式响应(使用OpenAI客户端)")

            try:
                response = openai_client.chat.completions.create(
                    model=body_data.get("model", OPENAI_MODEL_NAME),
                    messages=body_data.get("messages", []),
                    **{k: v for k, v in body_data.items()
                       if k not in ["model", "messages"]}
                )

                # Persist the reply without blocking the response.
                if response.choices[0].message.content and body_data:
                    asyncio.create_task(process_ai_reply_async(
                        response.choices[0].message.content,
                        user_id,
                        current_user_content
                    ))

                # Return OpenAI's native response shape.
                return JSONResponse(content=response.model_dump())

            except Exception as e:
                logger.error(f"❌ OpenAI非流式处理失败: {e}")
                return JSONResponse(
                    status_code=500,
                    content={"error": f"非流式处理失败: {str(e)}"}
                )

    except Exception as e:
        logger.error(f"❌ 处理对话请求失败: {e}")
        return JSONResponse(
            status_code=500,
            content={"error": f"处理对话请求失败: {str(e)}"}
        )
|
||||
|
||||
async def handle_other_request(request: Request, path: str, body: bytes):
    """Pass any other /v1/* call through to the upstream OpenAI API."""
    try:
        body_data = json.loads(body) if body else {}

        logger.info(f"🔄 处理其他请求: {path}")

        if path == "models":
            # Advertise only the single configured model.
            models_response = {
                "object": "list",
                "data": [
                    {
                        "id": OPENAI_MODEL_NAME,
                        "object": "model",
                        "created": int(datetime.now().timestamp()),
                        "owned_by": "bionicmemory"
                    }
                ]
            }
            return JSONResponse(content=models_response)

        else:
            # Transparent passthrough for everything else.
            try:
                # NOTE(review): this reaches into the OpenAI SDK's private
                # `_client` httpx instance — it can break on SDK upgrades,
                # and the GET branch sends no Authorization header. Confirm.
                if request.method == "GET":
                    response = openai_client._client.get(f"/v1/{path}")
                    return JSONResponse(content=response.json())
                else:
                    response = openai_client._client.post(
                        f"/v1/{path}",
                        json=body_data,
                        headers={"Authorization": f"Bearer {OPENAI_API_KEY}"}
                    )
                    return JSONResponse(content=response.json())

            except Exception as e:
                logger.error(f"❌ OpenAI客户端处理其他请求失败: {e}")
                return JSONResponse(
                    status_code=500,
                    content={"error": f"处理请求失败: {str(e)}"}
                )

    except Exception as e:
        logger.error(f"❌ 处理其他请求失败: {e}")
        return JSONResponse(
            status_code=500,
            content={"error": f"处理其他请求失败: {str(e)}"}
        )
|
||||
|
||||
# ========== Entry point ==========
if __name__ == "__main__":
    # Fix: the import string must match this module's actual path
    # (bionicmemory/api/proxy_server.py). It previously pointed at
    # "bionicmemory.api.proxy_server_openai:app", which does not exist, so
    # running the module directly failed to start the server.
    uvicorn.run(
        "bionicmemory.api.proxy_server:app",
        host=API_HOST,
        port=API_PORT,
        log_level="info",
        access_log=True,
        reload=False
    )
|
||||
7
bionicmemory/core/__init__.py
Normal file
7
bionicmemory/core/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
核心模块
|
||||
|
||||
包含仿生记忆系统的核心功能:
|
||||
- 长短期记忆系统
|
||||
- ChromaDB服务封装
|
||||
"""
|
||||
552
bionicmemory/core/chroma_service.py
Normal file
552
bionicmemory/core/chroma_service.py
Normal file
@@ -0,0 +1,552 @@
|
||||
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import Optional, List, Dict, Any, Union, Callable
import json
import logging
import os
from dotenv import load_dotenv
from bionicmemory.services.chat_helper import ChatHelper

# Read configuration from the project's .env file.
load_dotenv()

# Shared logging configuration.
from bionicmemory.utils.logging_config import get_logger
logger = get_logger(__name__)

# Local embedding backend (collections are created without ChromaDB's own
# embedding functions; vectors are computed by this service instead).
from bionicmemory.services.local_embedding_service import get_embedding_service
|
||||
|
||||
|
||||
class ChromaService:
|
||||
"""
|
||||
ChromaDB向量数据库操作服务
|
||||
"""
|
||||
|
||||
    def __init__(self,
                 client_type: str = None,
                 path: Optional[str] = None,
                 host: str = None,
                 port: int = None,
                 chat_api_key: str = None,
                 chat_base_url: str = None):
        """
        Initialise the ChromaDB service.

        Any argument left as None falls back to the corresponding environment
        variable (loaded from .env).

        Args:
            client_type (str): one of 'persistent', 'ephemeral', 'http'.
            path (str): storage directory (persistent mode only).
            host (str): server address (http mode only).
            port (int): server port (http mode only).
            chat_api_key (str): chat API key.
            chat_base_url (str): chat API base URL.

        Raises:
            Exception: wraps any initialisation failure.
        """
        try:
            # Local imports mirror the module-level ones; kept for safety.
            from dotenv import load_dotenv
            import os

            # Load .env (idempotent; also done at module import time).
            load_dotenv()

            # Resolve defaults from the environment.
            client_type = client_type or os.getenv('CHROMA_CLIENT_TYPE', 'persistent')
            path = path or os.getenv('CHROMA_PATH', './memory/chroma_db')
            path = os.path.abspath(path)  # normalise to an absolute path
            host = host or os.getenv('CHROMA_HOST', 'localhost')
            port = int(port or os.getenv('CHROMA_PORT', '8001'))
            chat_api_key = chat_api_key or os.getenv('OPENAI_API_KEY')
            chat_base_url = chat_base_url or os.getenv('OPENAI_API_BASE')

            # Build the ChromaDB client for the requested mode.
            if client_type == "persistent":
                self.client = chromadb.PersistentClient(path=path)
            elif client_type == "ephemeral":
                self.client = chromadb.EphemeralClient()
            elif client_type == "http":
                self.client = chromadb.HttpClient(host=host, port=port)
            else:
                raise ValueError(f"不支持的客户端类型: {client_type}")

            # Optional chat helper (only when both credentials are present).
            if chat_api_key and chat_base_url:
                self.chat_helper = ChatHelper(chat_api_key, chat_base_url)
                logger.info("聊天助手初始化完成")
            else:
                self.chat_helper = None
                logger.info("未配置聊天API,聊天功能不可用")

            # Local embedding backend.
            self.local_embedding_service = get_embedding_service()
            logger.info("使用本地embedding服务")

            # Custom embedding-function hooks (unused in local mode).
            self._custom_embedding_func = None
            self._embedding_function = None  # local mode needs no embedding function

        except Exception as e:
            raise Exception(f"初始化ChromaDB客户端失败: {str(e)}")
|
||||
|
||||
def create_collection(self, name: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
创建新的集合
|
||||
|
||||
Args:
|
||||
name (str): 集合名称
|
||||
metadata (Dict[str, Any], optional): 集合元数据
|
||||
|
||||
Returns:
|
||||
Collection: 集合对象
|
||||
"""
|
||||
try:
|
||||
# 本地embedding模式,不使用ChromaDB的embedding函数
|
||||
embedding_function = None
|
||||
|
||||
collection = self.client.create_collection(
|
||||
name=name,
|
||||
metadata=metadata,
|
||||
embedding_function=embedding_function
|
||||
)
|
||||
logger.info(f"成功创建集合: {name}")
|
||||
return collection
|
||||
except Exception as e:
|
||||
logger.error(f"创建集合失败: {name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
    def get_or_create_collection(self, name: str, metadata: Optional[Dict[str, Any]] = None):
        """
        Get a collection, creating it when absent.

        Args:
            name (str): collection name.
            metadata (Dict[str, Any], optional): collection metadata.

        Returns:
            Collection: the existing or newly created collection.
        """
        try:
            embedding_function = None
            # Refresh the wrapper's hook when a custom embedding callable was
            # registered.  NOTE(review): assumes self._embedding_function is
            # non-None whenever _custom_embedding_func is set — __init__
            # leaves both None in local mode; confirm the setter keeps this
            # invariant.
            if self._custom_embedding_func is not None:
                self._embedding_function.custom_func = self._custom_embedding_func
                embedding_function = self._embedding_function

            collection = self.client.get_or_create_collection(
                name=name,
                metadata=metadata,
                embedding_function=embedding_function
            )
            logger.info(f"成功获取或创建集合: {name}")
            return collection
        except Exception as e:
            logger.error(f"获取或创建集合失败: {name}, 错误: {e}")
            raise
|
||||
|
||||
def list_collections(self):
|
||||
"""
|
||||
列出所有集合
|
||||
|
||||
Returns:
|
||||
List[Collection]: 集合对象列表
|
||||
"""
|
||||
try:
|
||||
collections = self.client.list_collections()
|
||||
logger.info(f"找到 {len(collections)} 个集合")
|
||||
return collections
|
||||
except Exception as e:
|
||||
logger.error(f"获取集合列表失败: {e}")
|
||||
raise
|
||||
|
||||
def delete_collection(self, name: str):
|
||||
"""
|
||||
删除集合
|
||||
|
||||
Args:
|
||||
name (str): 集合名称
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
self.client.delete_collection(name=name)
|
||||
logger.info(f"成功删除集合: {name}")
|
||||
except Exception as e:
|
||||
logger.error(f"删除集合失败: {name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
    def add_documents(self,
                      collection_name: str,
                      documents: List[str],
                      embeddings: List[List[float]] = None,
                      ids: Optional[List[str]] = None,
                      metadatas: Optional[List[Dict[str, Any]]] = None) -> List[str]:
        """
        Add documents to a collection.

        Args:
            collection_name (str): collection name.
            documents (List[str]): document bodies.
            embeddings (List[List[float]], optional): precomputed vectors.
            ids (List[str], optional): document IDs.
            metadatas (List[Dict[str, Any]], optional): per-document metadata.

        Returns:
            List[str]: IDs of the added documents.

        Raises:
            ValueError: if documents and embeddings lengths differ.
        """
        try:
            # Make sure the target collection exists.
            collection = self.client.get_or_create_collection(
                name=collection_name,
                embedding_function=self._embedding_function
            )

            # Auto-generate IDs when none were supplied.
            # NOTE(review): "doc_{i}" restarts at 0 on every call, so repeated
            # calls without explicit ids will collide — confirm intended.
            if ids is None:
                ids = [f"doc_{i}" for i in range(len(documents))]

            # When precomputed vectors are supplied, lengths must line up.
            if embeddings is not None:
                if len(documents) != len(embeddings):
                    raise ValueError(f"文档数量({len(documents)})与embedding数量({len(embeddings)})不匹配")

                collection.add(
                    documents=documents,
                    embeddings=embeddings,
                    ids=ids,
                    metadatas=metadatas
                )
            else:
                # Let ChromaDB compute the embeddings itself.
                collection.add(
                    documents=documents,
                    ids=ids,
                    metadatas=metadatas
                )

            return ids
        except Exception as e:
            logger.error(f"添加文档失败: {e}")
            raise
|
||||
|
||||
    def query_documents(self,
                        collection_name: str,
                        query_texts: List[str] = None,
                        query_embeddings: List[List[float]] = None,
                        n_results: int = 10,
                        where: Optional[Dict[str, Any]] = None,
                        include: Optional[List[str]] = None) -> Dict:
        """
        Similarity-query a collection.

        Args:
            collection_name (str): collection name.
            query_texts (List[str], optional): query strings.
            query_embeddings (List[List[float]], optional): precomputed query vectors.
            n_results (int): number of hits to return.
            where (Dict[str, Any], optional): metadata filter.
            include (List[str], optional): which fields to return.

        Returns:
            Dict: ChromaDB query result; `embeddings` normalised to plain lists.
        """
        try:
            # Ensure the collection exists.
            collection = self.client.get_or_create_collection(
                name=collection_name,
                embedding_function=self._embedding_function
            )

            # Default to returning everything callers typically need.
            if include is None:
                include = ["documents", "metadatas", "distances", "embeddings"]

            # Prefer precomputed query vectors to avoid re-embedding.
            if query_embeddings is not None:
                results = collection.query(
                    query_embeddings=query_embeddings,
                    n_results=n_results,
                    where=where,
                    include=include
                )
            else:
                results = collection.query(
                    query_texts=query_texts,
                    n_results=n_results,
                    where=where,
                    include=include
                )

            # Normalise embeddings (possibly numpy arrays) to plain lists.
            if 'embeddings' in results and results.get('embeddings') is not None:
                embeddings_data = results['embeddings']
                processed_embeddings = []
                for embedding_list in embeddings_data:
                    processed_embedding_list = []
                    for embedding in embedding_list:
                        if embedding is not None and hasattr(embedding, 'tolist'):
                            processed_embedding_list.append(embedding.tolist())
                        else:
                            processed_embedding_list.append(embedding)
                    processed_embeddings.append(processed_embedding_list)
                results['embeddings'] = processed_embeddings

            return results
        except Exception as e:
            logger.error(f"查询文档失败: {e}")
            raise
|
||||
|
||||
def get_documents(self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
include: Optional[List[str]] = None) -> Dict:
|
||||
"""
|
||||
获取文档
|
||||
|
||||
Args:
|
||||
collection_name (str): 集合名称
|
||||
ids (List[str], optional): 文档ID列表
|
||||
limit (int, optional): 限制返回数量
|
||||
where (Dict[str, Any], optional): 元数据过滤条件
|
||||
include (List[str], optional): 需要返回的数据类型
|
||||
|
||||
Returns:
|
||||
Dict: 文档结果字典
|
||||
"""
|
||||
try:
|
||||
# 使用self.client确保集合存在
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self._embedding_function
|
||||
)
|
||||
|
||||
# 设置默认的include参数
|
||||
if include is None:
|
||||
include = ["documents", "metadatas"]
|
||||
|
||||
results = collection.get(
|
||||
ids=ids,
|
||||
limit=limit,
|
||||
where=where,
|
||||
include=include
|
||||
)
|
||||
|
||||
# 统一处理embeddings,确保返回list格式
|
||||
if 'embeddings' in results and results.get('embeddings') is not None:
|
||||
embeddings_data = results['embeddings']
|
||||
processed_embeddings = []
|
||||
for embedding_list in embeddings_data:
|
||||
processed_embedding_list = []
|
||||
for embedding in embedding_list:
|
||||
if embedding is not None and hasattr(embedding, 'tolist'):
|
||||
processed_embedding_list.append(embedding.tolist())
|
||||
else:
|
||||
processed_embedding_list.append(embedding)
|
||||
processed_embeddings.append(processed_embedding_list)
|
||||
results['embeddings'] = processed_embeddings
|
||||
|
||||
return results # ✅ 返回实际数据
|
||||
except Exception as e:
|
||||
logger.error(f"获取文档失败: {e}")
|
||||
raise # ✅ 抛出异常
|
||||
|
||||
def update_documents(self,
|
||||
collection_name: str,
|
||||
ids: List[str],
|
||||
documents: Optional[List[str]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None) -> Dict:
|
||||
"""
|
||||
更新文档
|
||||
|
||||
Args:
|
||||
collection_name (str): 集合名称
|
||||
ids (List[str]): 文档ID列表
|
||||
documents (List[str], optional): 新的文档内容
|
||||
metadatas (List[Dict[str, Any]], optional): 新的元数据
|
||||
|
||||
Returns:
|
||||
Dict: 更新后的文档数据
|
||||
"""
|
||||
try:
|
||||
# 使用self.client确保集合存在
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self._embedding_function
|
||||
)
|
||||
|
||||
collection.update(
|
||||
ids=ids,
|
||||
documents=documents,
|
||||
metadatas=metadatas
|
||||
)
|
||||
|
||||
# 返回更新后的文档数据
|
||||
return collection.get(ids=ids) # ✅ 返回实际数据
|
||||
except Exception as e:
|
||||
logger.error(f"更新文档失败: {e}")
|
||||
raise # ✅ 抛出异常
|
||||
|
||||
def delete_documents(self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None) -> List[str]:
|
||||
"""
|
||||
删除文档
|
||||
|
||||
Args:
|
||||
collection_name (str): 集合名称
|
||||
ids (List[str], optional): 文档ID列表
|
||||
where (Dict[str, Any], optional): 元数据过滤条件
|
||||
|
||||
Returns:
|
||||
List[str]: 删除的文档ID列表
|
||||
"""
|
||||
try:
|
||||
# 使用self.client确保集合存在
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self._embedding_function
|
||||
)
|
||||
|
||||
# 如果提供了ids,直接删除
|
||||
if ids:
|
||||
collection.delete(ids=ids)
|
||||
return ids # ✅ 返回实际数据
|
||||
else:
|
||||
# 如果使用where条件,先查询要删除的文档
|
||||
if where:
|
||||
results = collection.get(where=where)
|
||||
deleted_ids = results.get('ids', [])
|
||||
if deleted_ids:
|
||||
collection.delete(ids=deleted_ids)
|
||||
return deleted_ids # ✅ 返回实际数据
|
||||
else:
|
||||
# 删除所有文档
|
||||
all_results = collection.get()
|
||||
all_ids = all_results.get('ids', [])
|
||||
if all_ids:
|
||||
collection.delete(ids=all_ids)
|
||||
return all_ids # ✅ 返回实际数据
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除文档失败: {e}")
|
||||
raise # ✅ 抛出异常
|
||||
|
||||
def count_documents(self, collection_name: str) -> int:
|
||||
"""
|
||||
统计集合中的文档数量
|
||||
|
||||
Args:
|
||||
collection_name (str): 集合名称
|
||||
|
||||
Returns:
|
||||
int: 文档数量
|
||||
"""
|
||||
try:
|
||||
# 使用self.client确保集合存在
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self._embedding_function
|
||||
)
|
||||
count = collection.count()
|
||||
return count # ✅ 返回实际数据
|
||||
except Exception as e:
|
||||
logger.error(f"统计文档数量失败: {e}")
|
||||
raise # ✅ 抛出异常
|
||||
|
||||
def peek_documents(self, collection_name: str, limit: int = 10) -> Dict:
|
||||
"""
|
||||
预览集合中的文档
|
||||
|
||||
Args:
|
||||
collection_name (str): 集合名称
|
||||
limit (int): 预览数量限制
|
||||
|
||||
Returns:
|
||||
Dict: 预览结果数据
|
||||
"""
|
||||
try:
|
||||
# 使用self.client确保集合存在
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=self._embedding_function
|
||||
)
|
||||
results = collection.peek(limit=limit)
|
||||
return results # ✅ 返回实际数据
|
||||
except Exception as e:
|
||||
logger.error(f"预览文档失败: {e}")
|
||||
raise # ✅ 抛出异常
|
||||
|
||||
def custom_embedding(self, texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
自定义嵌入函数(预留接口)
|
||||
|
||||
Args:
|
||||
texts (List[str]): 待嵌入的文本列表
|
||||
|
||||
Returns:
|
||||
List[List[float]]: 嵌入向量列表
|
||||
"""
|
||||
# 函数体为pass,后续手动实现
|
||||
pass
|
||||
|
||||
def set_custom_embedding_function(self, embedding_func: Callable[[List[str]], List[List[float]]]) -> None:
|
||||
"""
|
||||
设置自定义嵌入函数
|
||||
|
||||
Args:
|
||||
embedding_func: 自定义嵌入函数,接受文本列表,返回向量列表
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
self._custom_embedding_func = embedding_func
|
||||
# ✅ 不返回值,成功就成功
|
||||
except Exception as e:
|
||||
logger.error(f"设置自定义嵌入函数失败: {e}")
|
||||
raise # ✅ 抛出异常
|
||||
|
||||
def get_custom_embedding_function(self) -> Optional[Callable]:
|
||||
"""
|
||||
获取当前设置的自定义嵌入函数
|
||||
|
||||
Returns:
|
||||
Optional[Callable]: 当前的自定义嵌入函数,如果未设置则返回None
|
||||
"""
|
||||
return self._custom_embedding_func
|
||||
|
||||
def create_embeddings(self, texts: List[str], model: str = None) -> List[List[float]]:
|
||||
"""
|
||||
使用本地服务生成文本的embedding向量
|
||||
"""
|
||||
# 使用本地embedding服务
|
||||
embeddings = self.local_embedding_service.encode_texts(texts)
|
||||
return embeddings.tolist()
|
||||
|
||||
def get_embedding_dimension(self) -> int:
|
||||
"""
|
||||
获取embedding维度(从embedding服务动态获取)
|
||||
"""
|
||||
# 从 embedding 服务获取实际维度
|
||||
model_info = self.local_embedding_service.get_model_info()
|
||||
return model_info.get('embedding_dim', 1024)
|
||||
|
||||
def get_collection(self, name: str):
|
||||
"""
|
||||
获取集合对象
|
||||
|
||||
Args:
|
||||
name (str): 集合名称
|
||||
|
||||
Returns:
|
||||
Collection: 集合对象
|
||||
"""
|
||||
try:
|
||||
collection = self.client.get_collection(name)
|
||||
logger.info(f"成功获取集合: {name}")
|
||||
return collection
|
||||
except Exception as e:
|
||||
logger.error(f"获取集合失败: {name}, 错误: {e}")
|
||||
raise
|
||||
1488
bionicmemory/core/memory_system.py
Normal file
1488
bionicmemory/core/memory_system.py
Normal file
File diff suppressed because it is too large
Load Diff
10
bionicmemory/services/__init__.py
Normal file
10
bionicmemory/services/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
服务模块
|
||||
|
||||
包含仿生记忆系统的各种服务:
|
||||
- 摘要生成服务
|
||||
- 话题摘要服务
|
||||
- 本地Embedding服务
|
||||
- 聊天助手服务
|
||||
- 记忆清理调度器
|
||||
"""
|
||||
218
bionicmemory/services/api_embedding_service.py
Normal file
218
bionicmemory/services/api_embedding_service.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
import requests
|
||||
from typing import List, Optional
|
||||
import threading
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
try:
|
||||
import utils.config_util as cfg
|
||||
CONFIG_UTIL_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
CONFIG_UTIL_AVAILABLE = False
|
||||
cfg = None
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if not CONFIG_UTIL_AVAILABLE:
|
||||
logger.warning("无法导入 config_util,将使用环境变量配置")
|
||||
|
||||
class ApiEmbeddingService:
    """Singleton embedding service that calls an OpenAI-compatible /embeddings API."""

    # Singleton state shared across the process.
    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Double-checked locking so only one instance is ever created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Run configuration exactly once, even under concurrent construction.
        if not self._initialized:
            with self._lock:
                if not self._initialized:
                    self._initialize_config()
                    ApiEmbeddingService._initialized = True

    def _initialize_config(self):
        """Load API configuration; executed only once per process."""
        try:
            # Prefer configuration from system.conf, falling back to env vars.
            api_base_url = None
            api_key = None
            model_name = None

            if CONFIG_UTIL_AVAILABLE and cfg:
                try:
                    # Make sure the config file has been loaded.
                    if cfg.config is None:
                        cfg.load_config()

                    # Read from config_util (reuses the LLM configuration).
                    api_base_url = cfg.embedding_api_base_url
                    api_key = cfg.embedding_api_key
                    model_name = cfg.embedding_api_model

                    logger.info(f"从 system.conf 读取配置:")
                    logger.info(f" - embedding_api_model: {model_name}")
                    logger.info(f" - embedding_api_base_url: {api_base_url}")
                    logger.info(f" - embedding_api_key: {'已配置' if api_key else '未配置'}")
                except Exception as e:
                    logger.warning(f"从 system.conf 读取配置失败: {e}")

            # Validate required settings, with helpful errors and env fallbacks.
            if not api_base_url:
                api_base_url = os.getenv('EMBEDDING_API_BASE_URL')
                if not api_base_url:
                    error_msg = ("未配置 embedding_api_base_url!\n"
                                 "请确保 system.conf 中配置了 gpt_base_url,"
                                 "或设置环境变量 EMBEDDING_API_BASE_URL")
                    logger.error(error_msg)
                    raise ValueError(error_msg)
                logger.warning(f"使用环境变量配置: base_url={api_base_url}")

            if not api_key:
                api_key = os.getenv('EMBEDDING_API_KEY')
                if not api_key:
                    error_msg = ("未配置 embedding_api_key!\n"
                                 "请确保 system.conf 中配置了 gpt_api_key,"
                                 "或设置环境变量 EMBEDDING_API_KEY")
                    logger.error(error_msg)
                    raise ValueError(error_msg)
                logger.warning("使用环境变量配置: api_key")

            if not model_name:
                model_name = os.getenv('EMBEDDING_API_MODEL', 'text-embedding-ada-002')
                logger.warning(f"未配置 embedding_api_model,使用默认值: {model_name}")

            # Persist the resolved configuration on the singleton.
            self.api_base_url = api_base_url.rstrip('/')  # strip trailing slash
            self.api_key = api_key
            self.model_name = model_name
            self.embedding_dim = None  # resolved lazily on the first API call
            self.timeout = 60  # per-request timeout in seconds
            self.max_retries = 2  # extra attempts after the first failure

            logger.info(f"API Embedding 服务初始化完成")
            logger.info(f"模型: {self.model_name}")
            logger.info(f"API 地址: {self.api_base_url}")

        except Exception as e:
            logger.error(f"API Embedding 服务初始化失败: {e}")
            raise

    def encode_text(self, text: str) -> List[float]:
        """Encode a single text into an embedding vector, with retries.

        Retries up to ``max_retries`` times with exponential backoff on
        timeouts and other request errors; re-raises after the last attempt.
        """
        import time

        # NOTE(review): last_error is assigned but never read afterwards —
        # the final failure is surfaced by re-raising instead.
        last_error = None
        for attempt in range(self.max_retries + 1):
            try:
                # Call the OpenAI-compatible /embeddings endpoint.
                url = f"{self.api_base_url}/embeddings"
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.api_key}"
                }
                payload = {
                    "model": self.model_name,
                    "input": text
                }

                # Log a truncated preview of the input for traceability.
                text_preview = text[:50] + "..." if len(text) > 50 else text
                logger.info(f"发送 embedding 请求 (尝试 {attempt + 1}/{self.max_retries + 1}): 文本长度={len(text)}, 预览='{text_preview}'")

                response = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
                response.raise_for_status()

                result = response.json()
                embedding = result['data'][0]['embedding']

                # Cache the embedding dimension on first success.
                if self.embedding_dim is None:
                    self.embedding_dim = len(embedding)
                    logger.info(f"动态获取 embedding 维度: {self.embedding_dim}")

                logger.info(f"embedding 生成成功")
                return embedding

            except requests.exceptions.Timeout as e:
                last_error = e
                logger.warning(f"请求超时 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
                if attempt < self.max_retries:
                    wait_time = 2 ** attempt  # exponential backoff: 1s, 2s, 4s
                    logger.info(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    logger.error(f"所有重试均失败,文本长度: {len(text)}")
                    raise

            except Exception as e:
                last_error = e
                logger.error(f"文本编码失败 (尝试 {attempt + 1}/{self.max_retries + 1}): {e}")
                if attempt < self.max_retries:
                    wait_time = 2 ** attempt
                    logger.info(f"等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                else:
                    raise

    def encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode a batch of texts in a single API request.

        NOTE(review): unlike encode_text, this path has no retry logic —
        confirm whether that asymmetry is intentional.
        """
        try:
            # Call the OpenAI-compatible /embeddings endpoint with a list input.
            url = f"{self.api_base_url}/embeddings"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}"
            }
            payload = {
                "model": self.model_name,
                "input": texts
            }

            # Batch requests get double the single-request timeout.
            batch_timeout = self.timeout * 2
            logger.info(f"发送批量 embedding 请求: 文本数={len(texts)}, 超时={batch_timeout}秒")
            response = requests.post(url, json=payload, headers=headers, timeout=batch_timeout)
            response.raise_for_status()

            result = response.json()
            embeddings = [item['embedding'] for item in result['data']]

            return embeddings
        except Exception as e:
            logger.error(f"批量文本编码失败: {e}")
            raise

    def get_model_info(self) -> dict:
        """Return static information about the configured API model.

        ``embedding_dim`` is None until the first successful encode call.
        """
        return {
            "model_name": self.model_name,
            "embedding_dim": self.embedding_dim,
            "api_base_url": self.api_base_url,
            "initialized": self._initialized,
            "service_type": "api"
        }
|
||||
|
||||
# 全局实例
|
||||
_global_embedding_service = None
|
||||
|
||||
def get_embedding_service() -> ApiEmbeddingService:
    """Return the process-wide embedding service, creating it on first use."""
    global _global_embedding_service
    if _global_embedding_service is None:
        _global_embedding_service = ApiEmbeddingService()
    return _global_embedding_service
|
||||
109
bionicmemory/services/chat_helper.py
Normal file
109
bionicmemory/services/chat_helper.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import logging
|
||||
import openai
|
||||
from typing import List
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
|
||||
class ChatHelper:
    """Chat helper dedicated to LLM chat-completion calls via an OpenAI-compatible client."""

    def __init__(self, api_key: str, base_url: str):
        """
        Initialize the chat helper.

        Args:
            api_key: API key (required).
            base_url: API base URL (required).

        Raises:
            ValueError: If either argument is missing or empty.
        """
        if not api_key or not base_url:
            raise ValueError("api_key和base_url是必须参数")

        self.api_key = api_key
        self.base_url = base_url

        # OpenAI-compatible client bound to the configured endpoint.
        self.client = openai.OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

        self.logger = get_logger(__name__)
        self.logger.info("聊天助手初始化完成")

    def create_chat_completions(self, model: str, messages: List[dict], stream: bool = False,
                                top_p: float = 0.5, temperature: float = 0.2, user: str = None):
        """
        Create a chat completion.

        Args:
            model: Model name (required).
            messages: Chat messages in OpenAI format (required).
            stream: Whether to stream the response.
            top_p: Nucleus-sampling parameter.
            temperature: Sampling temperature.
            user: Optional end-user identifier forwarded to the API.

        Returns:
            The completion object (or stream) returned by the client.

        Raises:
            ValueError: If model or messages is missing.
        """
        if not model or not messages:
            raise ValueError("model和messages参数是必须的")

        kwargs = {
            "model": model,
            "messages": messages,
            "top_p": top_p,
            "temperature": temperature,
            "stream": stream
        }

        # Only forward "user" when provided so the API default applies otherwise.
        if user:
            kwargs["user"] = user

        completion = self.client.chat.completions.create(**kwargs)
        return completion

    def generate_text(self, prompt: str, model: str, max_tokens: int = 500,
                      temperature: float = 0.2, top_p: float = 0.5) -> str:
        """
        Generate text for a single prompt.

        Args:
            prompt: Prompt text (required).
            model: Model name (required).
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling parameter.

        Returns:
            str: The generated text content.

        Raises:
            ValueError: If prompt or model is missing.
            Exception: Wrapped error when the API call fails.
        """
        if not prompt or not model:
            raise ValueError("prompt和model参数是必须的")

        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p
            )

            generated_text = response.choices[0].message.content
            self.logger.debug(f"成功生成文本,长度: {len(generated_text)}")
            return generated_text

        except Exception as e:
            # Wrap with a descriptive message; the original cause is in str(e).
            error_msg = f"生成文本失败: {str(e)}"
            self.logger.error(error_msg)
            raise Exception(error_msg)

    def get_models(self):
        """Return the ids of all models available from the API."""
        models = self.client.models.list()
        return [model.id for model in models.data]

    def get_model(self, model_id):
        """Retrieve details for a specific model by id."""
        model = self.client.models.retrieve(model_id)
        return model
|
||||
199
bionicmemory/services/local_embedding_service.py
Normal file
199
bionicmemory/services/local_embedding_service.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import List, Optional
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import torch
|
||||
import hashlib
|
||||
import threading
|
||||
import os
|
||||
import sys
|
||||
from dotenv import load_dotenv,find_dotenv
|
||||
|
||||
# 设置离线模式,避免访问Hugging Face
|
||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||
os.environ['HF_HUB_OFFLINE'] = '1'
|
||||
os.environ['HF_DATASETS_OFFLINE'] = '1'
|
||||
|
||||
# 设置国内 Hugging Face 镜像站点(作为备用)
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 导入配置工具
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
try:
|
||||
import utils.config_util as cfg
|
||||
CONFIG_UTIL_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
CONFIG_UTIL_AVAILABLE = False
|
||||
cfg = None
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if not CONFIG_UTIL_AVAILABLE:
|
||||
logger.warning("无法导入 config_util,将使用 .env 配置")
|
||||
|
||||
class LocalEmbeddingService:
    """Local embedding service - singleton; keeps the SentenceTransformer model resident in memory."""

    # Singleton state shared across the process.
    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Double-checked locking so only one instance is ever created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Load the model exactly once, even under concurrent construction.
        if not self._initialized:
            with self._lock:
                if not self._initialized:
                    self._initialize_model()
                    LocalEmbeddingService._initialized = True

    def _initialize_model(self):
        """Resolve configuration and load the embedding model; runs only once."""
        # Bug fix: defined up front so the except block can reference it
        # safely. The original interpolated model_name_abs in the error log,
        # which raised NameError (masking the real error) when the failure
        # happened before the variable was assigned.
        model_name_abs = None
        try:
            # Prefer configuration from system.conf.
            user_model_name = None
            cache_dir_config = None

            if CONFIG_UTIL_AVAILABLE and cfg:
                try:
                    # Make sure the config file has been loaded.
                    if cfg.config is None:
                        cfg.load_config()

                    user_model_name = cfg.embedding_model
                    cache_dir_config = cfg.embedding_cache_dir

                    if user_model_name:
                        logger.info(f"从 system.conf 读取配置: embedding_model={user_model_name}")
                    if cache_dir_config:
                        logger.info(f"从 system.conf 读取配置: embedding_cache_dir={cache_dir_config}")
                except Exception as e:
                    logger.warning(f"从 system.conf 读取配置失败: {e}")

            # Fall back to .env values or built-in defaults.
            if not user_model_name:
                user_model_name = os.getenv('LOCAL_EMBEDDING_MODEL', 'Qwen/Qwen3-Embedding-0.6B')
                logger.info(f"使用 .env 或默认配置: embedding_model={user_model_name}")

            if not cache_dir_config:
                cache_dir_config = os.getenv('LOCAL_EMBEDDING_CACHE_DIR', 'models/embeddings')
                logger.info(f"使用 .env 或默认配置: embedding_cache_dir={cache_dir_config}")

            # Resolve a relative cache dir against the current working directory.
            if not os.path.isabs(cache_dir_config):
                cache_dir = os.path.join(os.getcwd(), cache_dir_config)
            else:
                cache_dir = cache_dir_config

            cache_dir_abs = os.path.abspath(cache_dir)

            # Build the Hugging Face cache layout path.
            # NOTE(review): the snapshot hash is hard-coded, so a model
            # revision other than this one will not be found locally — confirm.
            model_path = os.path.join(cache_dir_abs, f"models--{user_model_name.replace('/', '--')}", "snapshots",
                                      "c54f2e6e80b2d7b7de06f51cec4959f6b3e03418")

            model_name_abs = os.path.abspath(model_path)

            logger.info(f"用户设置的模型名称: {user_model_name}")
            logger.info(f"按规则拼成的模型路径: {model_path}")
            logger.info(f"程序实际使用的模型绝对路径: {model_name_abs}")
            logger.info(f"程序实际使用的缓存绝对路径: {cache_dir_abs}")
            logger.info(f"模型路径是否存在: {os.path.exists(model_name_abs)}")
            logger.info(f"缓存路径是否存在: {os.path.exists(cache_dir_abs)}")

            # Download the model automatically when the local snapshot is missing.
            if not os.path.exists(model_name_abs):
                logger.info(f"模型路径不存在: {model_name_abs}")
                logger.info("开始自动下载模型...")

                os.makedirs(cache_dir_abs, exist_ok=True)

                logger.info(f"正在下载模型: {user_model_name}")
                self.model = SentenceTransformer(user_model_name, cache_folder=cache_dir_abs)
                logger.info("模型下载完成!")
            else:
                logger.info(f"使用本地模型: {model_name_abs}")
                self.model = SentenceTransformer(model_name_abs, cache_folder=cache_dir_abs)

            # Inference only - switch to eval mode.
            self.model.eval()

            # Use the GPU when one is available.
            if torch.cuda.is_available():
                self.model = self.model.cuda()
                logger.info("使用GPU加速")
            else:
                logger.info("使用CPU")

            logger.info(f"{model_name_abs}模型加载完成")
            logger.info(f"模型缓存路径: {cache_dir_abs}")

            # Persist the resolved configuration on the singleton.
            self.model_name = user_model_name
            self.cache_dir = cache_dir

        except Exception as e:
            logger.error(f"模型加载失败: {e} (模型路径: {model_name_abs})")
            raise

    def encode_text(self, text: str) -> List[float]:
        """Encode a single text into an embedding vector (as a plain list)."""
        try:
            embedding = self.model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"文本编码失败: {e}")
            raise

    def encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode a batch of texts into embedding vectors (as plain lists)."""
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"批量文本编码失败: {e}")
            raise

    def get_model_info(self) -> dict:
        """Return static information about the loaded model."""
        return {
            "model_name": getattr(self, 'model_name', 'Qwen/Qwen3-Embedding-0.6B'),
            # NOTE(review): dimension is hard-coded; presumably matches
            # Qwen3-Embedding-0.6B — verify when other models are configured.
            "embedding_dim": 1024,
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "initialized": self._initialized,
            "cache_dir": getattr(self, 'cache_dir', os.path.join(os.getcwd(), "ChromaWithForgetting", "models", "embeddings"))
        }
|
||||
|
||||
# 导入 API Embedding 服务
|
||||
from bionicmemory.services.api_embedding_service import ApiEmbeddingService
|
||||
|
||||
# 全局实例
|
||||
_global_embedding_service = None
|
||||
|
||||
def get_embedding_service() -> ApiEmbeddingService:
    """Return the process-wide embedding service (an ApiEmbeddingService), creating it lazily."""
    global _global_embedding_service
    if _global_embedding_service is None:
        _global_embedding_service = ApiEmbeddingService()
    return _global_embedding_service
|
||||
312
bionicmemory/services/memory_cleanup_scheduler.py
Normal file
312
bionicmemory/services/memory_cleanup_scheduler.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
记忆库定时清理服务
|
||||
使用 apscheduler 定期清理长短期记忆库
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
from bionicmemory.core.memory_system import LongShortTermMemorySystem
|
||||
from bionicmemory.algorithms.newton_cooling_helper import CoolingRate
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class MemoryCleanupScheduler:
    """
    Scheduler that periodically purges expired records from the short- and
    long-term memory collections.
    """

    def __init__(self, memory_system: LongShortTermMemorySystem):
        """
        Initialize the cleanup scheduler.

        Args:
            memory_system: The long/short-term memory system whose
                collections this scheduler cleans.
        """
        self.memory_system = memory_system
        # Background scheduler; jobs run on its own thread pool.
        self.scheduler = BackgroundScheduler()
        self.is_running = False

        logger.info("记忆库清理调度器初始化完成")

    def start(self):
        """Start the periodic cleanup service (no-op if already running)."""
        try:
            if self.is_running:
                logger.warning("清理调度器已经在运行")
                return

            # Register the recurring cleanup jobs before starting.
            self._add_cleanup_jobs()

            self.scheduler.start()
            self.is_running = True

            logger.info("记忆库清理调度器启动成功")

        except Exception as e:
            logger.error(f"启动清理调度器失败: {e}")
            raise

    def stop(self):
        """Stop the periodic cleanup service (no-op if not running)."""
        try:
            if not self.is_running:
                logger.warning("清理调度器未在运行")
                return

            self.scheduler.shutdown()
            self.is_running = False

            logger.info("记忆库清理调度器已停止")

        except Exception as e:
            logger.error(f"停止清理调度器失败: {e}")
            raise

    def _add_cleanup_jobs(self):
        """Register the built-in short-term and long-term cleanup jobs."""
        try:
            # 1. Short-term memory cleanup — every 10 minutes.
            #    Short-term memory uses the MINUTES_20 cooling rate, so it
            #    needs frequent cleanup.
            short_term_trigger = IntervalTrigger(minutes=10)
            self.scheduler.add_job(
                func=self._cleanup_short_term_memory,
                trigger=short_term_trigger,
                id="short_term_cleanup",
                name="短期记忆库清理",
                max_instances=1,
                coalesce=True
            )

            # 2. Long-term memory cleanup — daily at 04:00.
            #    Long-term memory uses the DAYS_31 cooling rate, so one pass
            #    per day is enough.
            long_term_trigger = CronTrigger(hour=4, minute=0)
            self.scheduler.add_job(
                func=self._cleanup_long_term_memory,
                trigger=long_term_trigger,
                id="long_term_cleanup",
                name="长期记忆库清理",
                max_instances=1,
                coalesce=True
            )

            logger.info("定时清理任务添加完成")

        except Exception as e:
            logger.error(f"添加定时清理任务失败: {e}")
            raise

    def _cleanup_short_term_memory(self):
        """Purge expired short-term memory records; runs every 10 minutes."""
        try:
            logger.info("开始执行短期记忆库定时清理")

            # NOTE: scheduled cleanup is a system-level operation that purges
            # expired records for ALL users — intentional, to keep overall
            # performance in check.
            self.memory_system._cleanup_collection(
                self.memory_system.short_term_collection_name,
                CoolingRate.MINUTES_20,
                self.memory_system.short_term_threshold
            )

            logger.info(f"短期记忆库定时清理完成")

        except Exception as e:
            # Swallow errors deliberately so one failed run does not kill
            # the recurring job.
            logger.error(f"短期记忆库定时清理失败: {e}")

    def _cleanup_long_term_memory(self):
        """Purge expired long-term memory records; runs daily at 04:00."""
        try:
            logger.info("开始执行长期记忆库定时清理")

            # NOTE: system-level operation across all users, as above.
            self.memory_system._cleanup_collection(
                self.memory_system.long_term_collection_name,
                CoolingRate.DAYS_31,
                self.memory_system.long_term_threshold
            )

            logger.info(f"长期记忆库定时清理完成")

        except Exception as e:
            # Swallow errors so the recurring job keeps running.
            logger.error(f"长期记忆库定时清理失败: {e}")

    def get_scheduler_status(self) -> dict:
        """
        Report the scheduler state and its registered jobs.

        Returns:
            dict: {"status", "jobs", "message"} where status is one of
            "running", "stopped", or "error".
        """
        try:
            if not self.is_running:
                return {
                    "status": "stopped",
                    "jobs": [],
                    "message": "调度器未运行"
                }

            # Collect a serializable snapshot of every registered job.
            jobs = []
            for job in self.scheduler.get_jobs():
                jobs.append({
                    "id": job.id,
                    "name": job.name,
                    "next_run_time": str(job.next_run_time) if job.next_run_time else "None",
                    "trigger": str(job.trigger)
                })

            return {
                "status": "running",
                "jobs": jobs,
                "message": "调度器运行正常"
            }

        except Exception as e:
            logger.error(f"获取调度器状态失败: {e}")
            return {
                "status": "error",
                "jobs": [],
                "message": f"获取状态失败: {e}"
            }

    def add_custom_cleanup_job(self,
                               func,
                               trigger,
                               job_id: str,
                               name: str = None):
        """
        Register an additional cleanup job.

        Args:
            func: Callable to execute.
            trigger: APScheduler trigger controlling when it runs.
            job_id: Unique job id.
            name: Optional human-readable name (defaults to job_id).

        Returns:
            bool: True if the job was added, False otherwise.
        """
        try:
            if not self.is_running:
                logger.warning("调度器未运行,无法添加任务")
                return False

            self.scheduler.add_job(
                func=func,
                trigger=trigger,
                id=job_id,
                name=name or job_id,
                max_instances=1,
                coalesce=True
            )

            logger.info(f"自定义清理任务添加成功: {job_id}")
            return True

        except Exception as e:
            logger.error(f"添加自定义清理任务失败: {e}")
            return False

    def remove_job(self, job_id: str) -> bool:
        """
        Remove a scheduled job.

        Args:
            job_id: Id of the job to remove.

        Returns:
            bool: True on success, False otherwise.
        """
        try:
            if not self.is_running:
                logger.warning("调度器未运行,无法移除任务")
                return False

            self.scheduler.remove_job(job_id)
            logger.info(f"任务移除成功: {job_id}")
            return True

        except Exception as e:
            logger.error(f"移除任务失败: {e}")
            return False

    def pause_job(self, job_id: str) -> bool:
        """
        Pause a scheduled job.

        Args:
            job_id: Id of the job to pause.

        Returns:
            bool: True on success, False otherwise.
        """
        try:
            if not self.is_running:
                logger.warning("调度器未运行,无法暂停任务")
                return False

            self.scheduler.pause_job(job_id)
            logger.info(f"任务暂停成功: {job_id}")
            return True

        except Exception as e:
            logger.error(f"暂停任务失败: {e}")
            return False

    def resume_job(self, job_id: str) -> bool:
        """
        Resume a paused job.

        Args:
            job_id: Id of the job to resume.

        Returns:
            bool: True on success, False otherwise.
        """
        try:
            if not self.is_running:
                logger.warning("调度器未运行,无法恢复任务")
                return False

            self.scheduler.resume_job(job_id)
            logger.info(f"任务恢复成功: {job_id}")
            return True

        except Exception as e:
            logger.error(f"恢复任务失败: {e}")
            return False

    def run_cleanup_now(self):
        """Run one immediate cleanup pass over both memory collections."""
        try:
            logger.info("开始执行立即清理任务")

            # Clean both the short-term and long-term collections.
            self.memory_system._cleanup_collection(
                self.memory_system.short_term_collection_name,
                CoolingRate.MINUTES_20,
                self.memory_system.short_term_threshold
            )
            self.memory_system._cleanup_collection(
                self.memory_system.long_term_collection_name,
                CoolingRate.DAYS_31,
                self.memory_system.long_term_threshold
            )

            logger.info("立即清理任务执行完成")

        except Exception as e:
            logger.error(f"立即清理任务执行失败: {e}")
            raise
|
||||
169
bionicmemory/services/summary_service.py
Normal file
169
bionicmemory/services/summary_service.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
摘要生成服务
|
||||
基于 ChatHelper 实现长内容摘要功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from bionicmemory.services.chat_helper import ChatHelper
|
||||
|
||||
# 使用统一日志配置
|
||||
from bionicmemory.utils.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class SummaryService:
    """LLM-backed summary generation service.

    Configuration (API key, endpoint, model name and default maximum
    summary length) is read from environment variables at construction
    time. Summarization is delegated to ``ChatHelper``; when the LLM
    call fails the service degrades to simple truncation.
    """

    def __init__(self):
        """Read configuration from the environment and build the LLM client.

        Raises:
            ValueError: If OPENAI_API_KEY, OPENAI_API_BASE or
                OPENAI_MODEL_NAME is missing from the environment.
        """
        self.api_key = os.getenv('OPENAI_API_KEY')
        self.base_url = os.getenv('OPENAI_API_BASE')
        self.model_name = os.getenv('OPENAI_MODEL_NAME')
        # Default cap on summary length; may be overridden per call.
        self.summary_max_length = int(os.getenv('SUMMARY_MAX_LENGTH', '500'))

        # Fail fast on missing mandatory configuration.
        if not self.api_key:
            raise ValueError("缺少必需的环境变量: OPENAI_API_KEY")
        if not self.base_url:
            raise ValueError("缺少必需的环境变量: OPENAI_API_BASE")
        if not self.model_name:
            raise ValueError("缺少必需的环境变量: OPENAI_MODEL_NAME")

        # LLM client used for actual summary generation.
        self.chat_helper = ChatHelper(
            api_key=self.api_key,
            base_url=self.base_url
        )

        logger.info("摘要服务初始化完成")
        logger.info(f"使用模型: {self.model_name}")
        logger.info(f"摘要最大长度: {self.summary_max_length}")

    def generate_summary(self, content: str, max_length: Optional[int] = None) -> str:
        """Generate a summary of *content*.

        Args:
            content: Original text to summarize.
            max_length: Maximum summary length for this call; falls back
                to the SUMMARY_MAX_LENGTH environment setting.

        Returns:
            str: The generated summary, or the original content when it
            already fits within the effective length limit.
        """
        if not content:
            return ""

        # BUGFIX: the pass-through check previously consulted only the
        # instance default (self.summary_max_length), silently ignoring a
        # smaller per-call max_length. Compute one effective limit and use
        # it everywhere below.
        limit = max_length or self.summary_max_length
        if len(content) <= limit:
            return content

        try:
            # Build the prompt and ask the LLM for a summary.
            prompt = self._build_summary_prompt(content, limit)

            summary = self.chat_helper.generate_text(
                prompt=prompt,
                model=self.model_name,
                max_tokens=limit,
                temperature=0.3,  # low temperature keeps the summary faithful
                top_p=0.8
            )

            # Strip prompt residue / stray quotes from the model output.
            summary = self._clean_summary(summary)

            logger.info(f"摘要生成成功: {len(content)} -> {len(summary)} 字符")
            return summary

        except Exception as e:
            logger.error(f"摘要生成失败: {e}")
            # Degrade gracefully to simple truncation.
            return self._fallback_summary(content, limit)

    def _build_summary_prompt(self, content: str, max_length: int) -> str:
        """Build the summarization prompt sent to the LLM.

        Args:
            content: Original text to summarize.
            max_length: Maximum summary length stated in the prompt.

        Returns:
            str: The fully assembled prompt.
        """
        prompt = f"""请为以下内容生成一个简洁的摘要,要求:

1. 摘要长度控制在 {max_length} 字符以内
2. 保留核心信息和关键要点
3. 使用简洁明了的语言
4. 确保摘要的完整性和准确性

原始内容:
{content}

请生成摘要:"""

        return prompt

    def _clean_summary(self, summary: str) -> str:
        """Normalize raw LLM output into a clean summary string.

        Removes surrounding whitespace, leftover prompt labels such as
        "摘要:"/"总结:", and one pair of wrapping quotes.

        Args:
            summary: Raw summary text returned by the LLM.

        Returns:
            str: The cleaned summary (empty string for empty input).
        """
        if not summary:
            return ""

        summary = summary.strip()

        # Remove prompt-label residue the model sometimes echoes back
        # (both ASCII and full-width colon variants).
        summary = summary.replace("摘要:", "").replace("摘要:", "")
        summary = summary.replace("总结:", "").replace("总结:", "")

        # Drop one pair of wrapping quotes, if present.
        if summary.startswith('"') and summary.endswith('"'):
            summary = summary[1:-1]
        if summary.startswith("'") and summary.endswith("'"):
            summary = summary[1:-1]

        return summary.strip()

    def _fallback_summary(self, content: str, max_length: int) -> str:
        """Fallback summary strategy: simple truncation.

        Truncates at *max_length*, preferring to cut at the last Chinese
        full stop when it lies in the final 20% of the window, and appends
        an ellipsis when content was actually dropped.

        Args:
            content: Original text.
            max_length: Maximum allowed length of the truncated result.

        Returns:
            str: The truncated content.
        """
        logger.warning("使用降级摘要方案:简单截断")

        summary = content[:max_length]

        # Prefer ending on a sentence boundary when one is close enough.
        last_period = summary.rfind('。')
        if last_period > max_length * 0.8:  # period lies past the 80% mark
            summary = summary[:last_period + 1]

        # Signal that content was dropped.
        if len(content) > max_length:
            summary += "..."

        return summary
|
||||
7
bionicmemory/utils/__init__.py
Normal file
7
bionicmemory/utils/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
工具模块
|
||||
|
||||
包含仿生记忆系统的工具函数:
|
||||
- 授权验证
|
||||
- 其他辅助工具
|
||||
"""
|
||||
88
bionicmemory/utils/logging_config.py
Normal file
88
bionicmemory/utils/logging_config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
统一日志配置模块
|
||||
提供统一的日志格式配置,确保所有模块使用相同的日志输出格式
|
||||
支持环境变量配置日志级别和输出文件,支持按日期的多日志文件
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
def setup_logging():
    """Configure root logging once for the whole application.

    Reads LOG_LEVEL and LOG_DIR from the environment, writes a dated log
    file under LOG_DIR (rotated at midnight, 30 backups kept, UTF-8) and
    mirrors output to stdout, all using one shared format. Also quiets
    chatty third-party libraries down to WARNING.
    """
    # Environment-driven configuration with sensible defaults.
    level_name = os.getenv('LOG_LEVEL', 'INFO').upper()
    numeric_level = getattr(logging, level_name, logging.INFO)

    # Make sure the log directory exists before attaching a file handler.
    log_path = Path(os.getenv('LOG_DIR', './logs/'))
    log_path.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the base file name already embeds today's date while
    # TimedRotatingFileHandler appends its own date suffix on rollover, so
    # rotated files end up double-dated — confirm whether that is intended.
    today = datetime.now().strftime('%Y-%m-%d')
    log_file = log_path / f'bionicmemory-{today}.log'

    # Shared format: time - level - file:line - message.
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
        '%Y-%m-%d %H:%M:%S',
    )

    # Drop handlers installed earlier so repeated calls don't duplicate output.
    root_logger = logging.getLogger()
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)

    # Console handler; force UTF-8 where the stream supports it to avoid
    # mojibake on some terminals.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    console_handler.setLevel(numeric_level)
    if hasattr(console_handler.stream, 'reconfigure'):
        console_handler.stream.reconfigure(encoding='utf-8')

    # File handler rotating nightly, keeping 30 days of history.
    file_handler = logging.handlers.TimedRotatingFileHandler(
        log_file,
        when='midnight',
        interval=1,
        backupCount=30,
        encoding='utf-8',
    )
    file_handler.setFormatter(formatter)
    file_handler.setLevel(numeric_level)

    root_logger.setLevel(numeric_level)
    root_logger.addHandler(console_handler)
    root_logger.addHandler(file_handler)

    # Silence noisy third-party libraries.
    for noisy in ('chromadb', 'httpx', 'httpcore', 'urllib3',
                  'transformers', 'sentence_transformers'):
        logging.getLogger(noisy).setLevel(logging.WARNING)
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under *name*.

    Args:
        name: Logger name, conventionally the caller's ``__name__``.

    Returns:
        logging.Logger: The shared logger instance for that name, which
        inherits the handlers configured by :func:`setup_logging`.
    """
    return logging.getLogger(name)
|
||||
|
||||
# Apply the default logging configuration as soon as this module is imported,
# so every importer gets consistent handlers and formatting without an
# explicit setup call.
setup_logging()
|
||||
Reference in New Issue
Block a user