mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-12 18:01:30 +08:00
feat: improve the memory system
This commit is contained in:
@@ -6,5 +6,6 @@ Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
|
||||
|
||||
@@ -41,6 +41,10 @@ class MemoryConfig:
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
# Memory flush config (独立于模型 context window)
|
||||
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
|
||||
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
return Path(self.workspace_root)
|
||||
|
||||
@@ -4,20 +4,19 @@ Embedding providers for memory
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
import hashlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
@@ -31,7 +30,7 @@ class EmbeddingProvider(ABC):
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider"""
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
"""
|
||||
@@ -45,87 +44,58 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
|
||||
# Lazy import to avoid dependency issues
|
||||
try:
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI(api_key=api_key, base_url=api_base)
|
||||
except ImportError:
|
||||
raise ImportError("OpenAI package not installed. Install with: pip install openai")
|
||||
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
|
||||
# Set dimensions based on model
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI embedding API using requests"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
response = self.client.embeddings.create(
|
||||
input=text,
|
||||
model=self.model
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
result = self._call_api(text)
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model
|
||||
)
|
||||
return [item.embedding for item in response.data]
|
||||
|
||||
|
||||
result = self._call_api(texts)
|
||||
return [item["embedding"] for item in result["data"]]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
|
||||
"""Local embedding provider using sentence-transformers"""
|
||||
|
||||
def __init__(self, model: str = "all-MiniLM-L6-v2"):
|
||||
"""
|
||||
Initialize local embedding provider
|
||||
|
||||
Args:
|
||||
model: Model name from sentence-transformers
|
||||
"""
|
||||
self.model_name = model
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = SentenceTransformer(model)
|
||||
self._dimensions = self.model.get_sentence_embedding_dimension()
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"sentence-transformers not installed. "
|
||||
"Install with: pip install sentence-transformers"
|
||||
)
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
embedding = self.model.encode(text, convert_to_numpy=True)
|
||||
return embedding.tolist()
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
embeddings = self.model.encode(texts, convert_to_numpy=True)
|
||||
return embeddings.tolist()
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embeddings to avoid recomputation"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
"""Get cached embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
@@ -156,20 +126,23 @@ def create_embedding_provider(
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Only supports OpenAI embedding via REST API.
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
Args:
|
||||
provider: Provider name ("openai" or "local")
|
||||
model: Model name (provider-specific)
|
||||
api_key: API key for remote providers
|
||||
api_base: API base URL for remote providers
|
||||
provider: Provider name (only "openai" is supported)
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: OpenAI API key (required)
|
||||
api_base: API base URL (default: https://api.openai.com/v1)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not "openai" or api_key is missing
|
||||
"""
|
||||
if provider == "openai":
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
elif provider == "local":
|
||||
model = model or "all-MiniLM-L6-v2"
|
||||
return LocalEmbeddingProvider(model=model)
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding provider: {provider}")
|
||||
if provider != "openai":
|
||||
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
|
||||
@@ -70,8 +70,9 @@ class MemoryManager:
|
||||
except Exception as e:
|
||||
# Embedding provider failed, but that's OK
|
||||
# We can still use keyword search and file operations
|
||||
print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
|
||||
print(f"ℹ️ Memory will work with keyword search only (no semantic search)")
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
@@ -135,13 +136,19 @@ class MemoryManager:
|
||||
# Perform vector search (if embedding provider available)
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2 # Get more candidates for merging
|
||||
)
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2 # Get more candidates for merging
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
keyword_results = self.storage.search_keyword(
|
||||
@@ -150,6 +157,8 @@ class MemoryManager:
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
@@ -356,30 +365,30 @@ class MemoryManager:
|
||||
|
||||
def should_flush_memory(
|
||||
self,
|
||||
current_tokens: int,
|
||||
context_window: int = 128000,
|
||||
reserve_tokens: int = 20000,
|
||||
soft_threshold: int = 4000
|
||||
current_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window。
|
||||
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
context_window: Model's context window size (default: 128K)
|
||||
reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
|
||||
soft_threshold: Trigger N tokens before threshold (default: 4K)
|
||||
|
||||
Returns:
|
||||
True if memory flush should run
|
||||
"""
|
||||
return self.flush_manager.should_flush(
|
||||
current_tokens=current_tokens,
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens,
|
||||
soft_threshold=soft_threshold
|
||||
token_threshold=self.config.flush_token_threshold,
|
||||
turn_threshold=self.config.flush_turn_threshold
|
||||
)
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数(每次用户消息+AI回复算一轮)"""
|
||||
self.flush_manager.increment_turn()
|
||||
|
||||
async def execute_memory_flush(
|
||||
self,
|
||||
agent_executor,
|
||||
|
||||
@@ -41,46 +41,42 @@ class MemoryFlushManager:
|
||||
# Tracking
|
||||
self.last_flush_token_count: Optional[int] = None
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
self.turn_count: int = 0 # 对话轮数计数器
|
||||
|
||||
def should_flush(
|
||||
self,
|
||||
current_tokens: int,
|
||||
context_window: int,
|
||||
reserve_tokens: int = 20000,
|
||||
soft_threshold: int = 4000
|
||||
current_tokens: int = 0,
|
||||
token_threshold: int = 50000,
|
||||
turn_threshold: int = 20
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if memory flush should be triggered
|
||||
|
||||
Similar to clawdbot's shouldRunMemoryFlush logic:
|
||||
threshold = contextWindow - reserveTokens - softThreshold
|
||||
独立的 flush 触发机制,不依赖模型 context window:
|
||||
- Token 阈值: 达到 50K tokens 时触发
|
||||
- 轮次阈值: 达到 20 轮对话时触发
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
context_window: Model's context window size
|
||||
reserve_tokens: Reserve tokens for compaction overhead
|
||||
soft_threshold: Trigger flush N tokens before threshold
|
||||
token_threshold: Token threshold to trigger flush (default: 50K)
|
||||
turn_threshold: Turn threshold to trigger flush (default: 20)
|
||||
|
||||
Returns:
|
||||
True if flush should run
|
||||
"""
|
||||
if current_tokens <= 0:
|
||||
return False
|
||||
# 检查 token 阈值
|
||||
if current_tokens > 0 and current_tokens >= token_threshold:
|
||||
# 避免重复 flush
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + 5000:
|
||||
return False
|
||||
return True
|
||||
|
||||
threshold = max(0, context_window - reserve_tokens - soft_threshold)
|
||||
if threshold <= 0:
|
||||
return False
|
||||
# 检查轮次阈值
|
||||
if self.turn_count >= turn_threshold:
|
||||
return True
|
||||
|
||||
# Check if we've crossed the threshold
|
||||
if current_tokens < threshold:
|
||||
return False
|
||||
|
||||
# Avoid duplicate flush in same compaction cycle
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + soft_threshold:
|
||||
return False
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
@@ -130,7 +126,12 @@ class MemoryFlushManager:
|
||||
f"Pre-compaction memory flush. "
|
||||
f"Store durable memories now (use memory/{today}.md for daily notes; "
|
||||
f"create memory/ if needed). "
|
||||
f"If nothing to store, reply with NO_REPLY."
|
||||
f"\n\n"
|
||||
f"重要提示:\n"
|
||||
f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
|
||||
f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
|
||||
f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
|
||||
f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
|
||||
)
|
||||
|
||||
def create_flush_system_prompt(self) -> str:
|
||||
@@ -142,6 +143,20 @@ class MemoryFlushManager:
|
||||
return (
|
||||
"Pre-compaction memory flush turn. "
|
||||
"The session is near auto-compaction; capture durable memories to disk. "
|
||||
"\n\n"
|
||||
"记忆写入原则:\n"
|
||||
"1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n"
|
||||
" - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
|
||||
" - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
|
||||
"\n"
|
||||
"2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
|
||||
" - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
|
||||
"\n"
|
||||
"3. 判断标准:\n"
|
||||
" - 这个信息未来会经常用到吗?→ MEMORY.md\n"
|
||||
" - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
|
||||
" - 这是临时性的、不重要的内容吗?→ 不记录\n"
|
||||
"\n"
|
||||
"You may reply, but usually NO_REPLY is correct."
|
||||
)
|
||||
|
||||
@@ -180,6 +195,7 @@ class MemoryFlushManager:
|
||||
# Track flush
|
||||
self.last_flush_token_count = current_tokens
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
self.turn_count = 0 # 重置轮数计数器
|
||||
|
||||
return True
|
||||
|
||||
@@ -187,6 +203,10 @@ class MemoryFlushManager:
|
||||
print(f"Memory flush failed: {e}")
|
||||
return False
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数"""
|
||||
self.turn_count += 1
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get memory flush status"""
|
||||
return {
|
||||
|
||||
@@ -179,8 +179,8 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
tool_map = {}
|
||||
tool_descriptions = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建或覆盖文件",
|
||||
"edit": "精确编辑文件内容",
|
||||
"write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit)",
|
||||
"edit": "精确编辑文件(追加、修改、删除部分内容)",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "在文件中搜索内容",
|
||||
"find": "按照模式查找文件",
|
||||
@@ -305,17 +305,18 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"",
|
||||
"1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索",
|
||||
"2. 然后使用 `memory_get` 只拉取需要的行",
|
||||
"3. 如果搜索后仍然信心不足,告诉用户你已经检查过了",
|
||||
"1. 不确定信息位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件和大致位置 → 直接用 `memory_get` 读取相应的行",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取最近两天的记忆文件",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆,包含重要的背景信息",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件",
|
||||
"- `MEMORY.md`: 长期记忆(已自动加载,无需主动读取)",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
|
||||
"",
|
||||
"**使用原则**:",
|
||||
"- 自然使用记忆,就像你本来就知道",
|
||||
"- 不要主动提起或列举记忆,除非用户明确询问",
|
||||
"- 自然使用记忆,就像你本来就知道; 不用刻意提起或列举记忆,除非用户提起相关内容",
|
||||
"- 追加内容到现有记忆文件 → 必须用 `edit` 工具(先 read 读取,再 edit 追加)",
|
||||
"- 创建新的记忆文件 → 可以用 `write` 工具(已有记忆文件不可直接write,会覆盖删除)",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@ import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class AgentStreamExecutor:
|
||||
@@ -164,30 +164,24 @@ class AgentStreamExecutor:
|
||||
self._emit_event("turn_start", {"turn": turn})
|
||||
|
||||
# Check if memory flush is needed (before calling LLM)
|
||||
# 使用独立的 flush 阈值(50K tokens 或 20 轮)
|
||||
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
|
||||
usage = self.agent.last_usage
|
||||
if usage and 'input_tokens' in usage:
|
||||
current_tokens = usage.get('input_tokens', 0)
|
||||
context_window = self.agent._get_model_context_window()
|
||||
# Use configured reserve_tokens or calculate based on context window
|
||||
reserve_tokens = self.agent._get_context_reserve_tokens()
|
||||
# Use smaller soft_threshold to trigger flush earlier (e.g., at 50K tokens)
|
||||
soft_threshold = 10000 # Trigger 10K tokens before limit
|
||||
|
||||
if self.agent.memory_manager.should_flush_memory(
|
||||
current_tokens=current_tokens,
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens,
|
||||
soft_threshold=soft_threshold
|
||||
current_tokens=current_tokens
|
||||
):
|
||||
self._emit_event("memory_flush_start", {
|
||||
"current_tokens": current_tokens,
|
||||
"threshold": context_window - reserve_tokens - soft_threshold
|
||||
"turn_count": self.agent.memory_manager.flush_manager.turn_count
|
||||
})
|
||||
|
||||
# TODO: Execute memory flush in background
|
||||
# This would require async support
|
||||
logger.info(f"Memory flush recommended at {current_tokens} tokens")
|
||||
logger.info(
|
||||
f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")
|
||||
|
||||
# Call LLM
|
||||
assistant_msg, tool_calls = self._call_llm_stream()
|
||||
@@ -321,6 +315,10 @@ class AgentStreamExecutor:
|
||||
logger.info(f"🏁 完成({turn}轮)")
|
||||
self._emit_event("agent_end", {"final_response": final_response})
|
||||
|
||||
# 每轮对话结束后增加计数(用户消息+AI回复=1轮)
|
||||
if self.agent.memory_manager:
|
||||
self.agent.memory_manager.increment_turn()
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
|
||||
@@ -664,9 +662,11 @@ class AgentStreamExecutor:
|
||||
if not self.messages or not self.agent:
|
||||
return
|
||||
|
||||
# Get context window and reserve tokens from agent
|
||||
# Get context window from agent (based on model)
|
||||
context_window = self.agent._get_model_context_window()
|
||||
reserve_tokens = self.agent._get_context_reserve_tokens()
|
||||
|
||||
# Reserve 10% for response generation
|
||||
reserve_tokens = int(context_window * 0.1)
|
||||
max_tokens = context_window - reserve_tokens
|
||||
|
||||
# Estimate current tokens
|
||||
|
||||
@@ -2,25 +2,17 @@
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
|
||||
# Import basic tools (no external dependencies)
|
||||
from agent.tools.calculator.calculator import Calculator
|
||||
|
||||
# Import file operation tools
|
||||
from agent.tools.read.read import Read
|
||||
from agent.tools.write.write import Write
|
||||
from agent.tools.edit.edit import Edit
|
||||
from agent.tools.bash.bash import Bash
|
||||
from agent.tools.grep.grep import Grep
|
||||
from agent.tools.find.find import Find
|
||||
from agent.tools.ls.ls import Ls
|
||||
|
||||
# Import memory tools
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
# Import web tools
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
|
||||
# Import tools with optional dependencies
|
||||
def _import_optional_tools():
|
||||
"""Import tools that have optional dependencies"""
|
||||
@@ -80,17 +72,13 @@ BrowserTool = _import_browser_tool()
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'ToolManager',
|
||||
'Calculator',
|
||||
'Read',
|
||||
'Write',
|
||||
'Edit',
|
||||
'Bash',
|
||||
'Grep',
|
||||
'Find',
|
||||
'Ls',
|
||||
'MemorySearchTool',
|
||||
'MemoryGetTool',
|
||||
'WebFetch',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
'GoogleSearch',
|
||||
'FileSave',
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
import math
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class Calculator(BaseTool):
    """Evaluate a basic mathematical expression supplied by the caller.

    The expression is evaluated with Python's eval() against a whitelist of
    math helpers, with builtins stripped out.  NOTE(review): eval on
    model-provided input is still not fully sandboxed (attribute access on
    literals can reach object internals) — consider an ast-based evaluator
    if this tool is ever exposed to untrusted input.
    """

    name: str = "calculator"
    description: str = "A tool to perform basic mathematical calculations."
    params: dict = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
                               "Ensure your input is a valid Python expression, it will be evaluated directly."
            }
        },
        "required": ["expression"]
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """Evaluate args["expression"] and return the numeric result.

        Errors are reported inside a *successful* ToolResult payload (under
        an "error" key) so the calling agent can read them as tool output.
        """
        try:
            expression = args["expression"]

            # Whitelist of names the expression may reference: selected
            # functions/constants from math plus a few safe builtins.
            math_names = ("sqrt", "sin", "cos", "tan", "pi", "e",
                          "log", "log10", "exp", "floor", "ceil")
            allowed = {math_name: getattr(math, math_name) for math_name in math_names}
            allowed.update({
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "pow": pow,
            })

            # Evaluate with builtins disabled so only whitelisted names resolve.
            value = eval(expression, {"__builtins__": {}}, allowed)

            return ToolResult.success({
                "result": value,
                "expression": expression
            })
        except Exception as e:
            return ToolResult.success({
                "error": str(e),
                "expression": args.get("expression", "")
            })
|
||||
@@ -33,7 +33,7 @@ class Edit(BaseTool):
|
||||
},
|
||||
"oldText": {
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)"
|
||||
"description": "Exact text to find and replace (must match exactly, cannot be empty). To append to end of file, include the last few lines as oldText."
|
||||
},
|
||||
"newText": {
|
||||
"type": "string",
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .find import Find
|
||||
|
||||
__all__ = ['Find']
|
||||
@@ -1,177 +0,0 @@
|
||||
"""
|
||||
Find tool - Search for files by glob pattern
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob as glob_module
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 1000
|
||||
|
||||
|
||||
class Find(BaseTool):
    """Tool for finding files by glob pattern.

    Searches a directory with glob (recursive ``**`` supported), filters out
    paths matching .gitignore-style patterns plus a built-in ignore list, and
    returns sorted relative paths (directories get a trailing slash).
    Output is capped by both a result-count limit and a byte limit.
    """

    name: str = "find"
    description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
            },
            "path": {
                "type": "string",
                "description": "Directory to search in (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        self.config = config or {}
        # Base directory used to resolve relative search paths.
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file search.

        :param args: Search parameters ("pattern" required; "path" and
                     "limit" optional)
        :return: ToolResult with matching paths, or a failure message
        """
        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path relative to the configured cwd
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {search_path}")

        try:
            # Load .gitignore patterns (plus built-in common ignores)
            ignore_patterns = self._load_gitignore(absolute_path)

            # Search for files
            results = []
            search_pattern = os.path.join(absolute_path, pattern)

            # Use glob with recursive support ('**' matches nested dirs)
            for file_path in glob_module.glob(search_pattern, recursive=True):
                # Skip if matches ignore patterns
                if self._should_ignore(file_path, absolute_path, ignore_patterns):
                    continue

                # Report paths relative to the search root
                relative_path = os.path.relpath(file_path, absolute_path)

                # Add trailing slash for directories
                if os.path.isdir(file_path):
                    relative_path += '/'

                results.append(relative_path)

                # Stop early once the result limit is reached
                if len(results) >= limit:
                    break

            if not results:
                return ToolResult.success({"message": "No files found matching pattern", "files": []})

            # Sort results for deterministic output
            results.sort()

            # Format output; byte-truncate only (line limit effectively off)
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            result_limit_reached = len(results) >= limit
            if result_limit_reached:
                notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["result_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "file_count": len(results),
                "details": details if details else None
            })

        except Exception as e:
            return ToolResult.fail(f"Error executing find: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to an absolute path (relative paths join self.cwd)."""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _load_gitignore(self, directory: str) -> List[str]:
        """Load .gitignore patterns from directory, plus common defaults."""
        patterns = []
        gitignore_path = os.path.join(directory, '.gitignore')

        if os.path.exists(gitignore_path):
            try:
                with open(gitignore_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        # Skip blanks and comment lines
                        if line and not line.startswith('#'):
                            patterns.append(line)
            except (OSError, UnicodeDecodeError):
                # Best-effort: an unreadable .gitignore just means no
                # project-specific ignore rules. (Was a bare `except:`,
                # which also swallowed KeyboardInterrupt/SystemExit.)
                pass

        # Add common ignore patterns
        patterns.extend([
            '.git',
            '__pycache__',
            '*.pyc',
            'node_modules',
            '.DS_Store'
        ])

        return patterns

    def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
        """Check if file should be ignored based on patterns.

        NOTE: this is a simplified matcher (substring / prefix), not full
        gitignore semantics — e.g. '*.pyc' matches only as a literal
        substring, and negation ('!') is not supported.
        """
        relative_path = os.path.relpath(file_path, base_path)

        for pattern in patterns:
            # Simple substring matching
            if pattern in relative_path:
                return True

            # Directory pattern: 'dir/' matches anything under that prefix
            if pattern.endswith('/'):
                if relative_path.startswith(pattern.rstrip('/')):
                    return True

        return False
|
||||
@@ -1,3 +0,0 @@
|
||||
from .grep import Grep
|
||||
|
||||
__all__ = ['Grep']
|
||||
@@ -1,248 +0,0 @@
|
||||
"""
|
||||
Grep tool - Search file contents for patterns
|
||||
Uses ripgrep (rg) for fast searching
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import (
|
||||
truncate_head, truncate_line, format_size,
|
||||
DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 100
|
||||
|
||||
|
||||
class Grep(BaseTool):
    """Tool for searching file contents.

    Shells out to ripgrep (``rg --json``), parses the JSON event stream,
    then re-reads the matched files to render each hit (optionally with
    surrounding context lines). Output is bounded by a match limit and a
    byte limit, and over-long lines are truncated.
    """

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        """Initialize the tool.

        :param config: Optional dict; recognizes "cwd" as the working
                       directory used to resolve relative search paths.
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Locate the ripgrep executable on PATH; return None if absent."""
        # shutil.which is cross-platform (the previous implementation
        # shelled out to `which`, which does not exist on Windows and
        # required a bare `except:` to recover) and never raises for a
        # missing binary.
        import shutil
        return shutil.which('rg')

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search

        :param args: Search parameters
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        glob = args.get("glob")
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = args.get("context", 0)
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Defend against malformed caller input (negative / non-int values).
        if not isinstance(limit, int) or limit < 1:
            limit = DEFAULT_LIMIT
        if not isinstance(context, int) or context < 0:
            context = 0

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        # Build ripgrep command
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]

        if ignore_case:
            cmd.append('--ignore-case')

        if literal:
            cmd.append('--fixed-strings')

        if glob:
            cmd.extend(['--glob', glob])

        cmd.extend([pattern, absolute_path])

        try:
            # Execute ripgrep
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                timeout=30
            )

            # Parse the JSON event stream; we only care about 'match' events.
            matches = []
            match_count = 0

            for line in result.stdout.splitlines():
                if not line.strip():
                    continue

                try:
                    event = json.loads(line)
                    if event.get('type') == 'match':
                        data = event.get('data', {})
                        file_path = data.get('path', {}).get('text')
                        line_number = data.get('line_number')

                        if file_path and line_number:
                            matches.append({
                                'file': file_path,
                                'line': line_number
                            })
                            match_count += 1

                            if match_count >= limit:
                                break
                except json.JSONDecodeError:
                    continue

            if match_count == 0:
                return ToolResult.success({"message": "No matches found", "matches": []})

            # Format output with context
            output_lines = []
            lines_truncated = False
            is_directory = os.path.isdir(absolute_path)

            # Cache file contents so a file with many matches is read once,
            # not once per match.
            file_cache: Dict[str, list] = {}

            for match in matches:
                file_path = match['file']
                line_number = match['line']

                # Format file path relative to the search root (or just the
                # basename when a single file was searched).
                if is_directory:
                    relative_path = os.path.relpath(file_path, absolute_path)
                else:
                    relative_path = os.path.basename(file_path)

                # Read file (via cache) and render context
                try:
                    if file_path not in file_cache:
                        with open(file_path, 'r', encoding='utf-8') as f:
                            file_cache[file_path] = f.read().split('\n')
                    file_lines = file_cache[file_path]

                    # Calculate context range (0-based, end-exclusive)
                    start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
                    end = min(len(file_lines), line_number + context) if context > 0 else line_number

                    # Format lines with context
                    for i in range(start, end):
                        line_text = file_lines[i].replace('\r', '')

                        # Truncate long lines
                        truncated_text, was_truncated = truncate_line(line_text)
                        if was_truncated:
                            lines_truncated = True

                        # Matched line uses ':' separators, context lines use
                        # '-' (mirrors classic grep -n -C output).
                        current_line = i + 1
                        if current_line == line_number:
                            output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                        else:
                            output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")

                except Exception:
                    # Binary / unreadable / vanished file: still report the hit.
                    output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")

            # Apply byte truncation
            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve a possibly-relative path to an absolute path."""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
@@ -4,8 +4,6 @@ Memory get tool
|
||||
Allows agents to read specific sections from memory files
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
@@ -22,7 +20,7 @@ class MemoryGetTool(BaseTool):
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2026-01-01.md')"
|
||||
"description": "Relative path to the memory file (e.g. 'memory/2026-01-01.md')"
|
||||
},
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
@@ -70,7 +68,8 @@ class MemoryGetTool(BaseTool):
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
if not path.startswith('memory/') and not path.startswith('/'):
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
# WebFetch Tool
|
||||
|
||||
免费的网页抓取工具,无需 API Key,可直接抓取网页内容并提取可读文本。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- ✅ **完全免费** - 无需任何 API Key
|
||||
- 🌐 **智能提取** - 自动提取网页主要内容
|
||||
- 📝 **格式转换** - 支持 HTML → Markdown/Text
|
||||
- 🚀 **高性能** - 内置请求重试和超时控制
|
||||
- 🎯 **智能降级** - 优先使用 Readability,可降级到基础提取
|
||||
|
||||
## 安装依赖
|
||||
|
||||
### 基础功能(必需)
|
||||
```bash
|
||||
pip install requests
|
||||
```
|
||||
|
||||
### 增强功能(推荐)
|
||||
```bash
|
||||
# 安装 readability-lxml 以获得更好的内容提取效果
|
||||
pip install readability-lxml
|
||||
|
||||
# 安装 html2text 以获得更好的 Markdown 转换
|
||||
pip install html2text
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 在代码中使用
|
||||
|
||||
```python
|
||||
from agent.tools.web_fetch import WebFetch
|
||||
|
||||
# 创建工具实例
|
||||
tool = WebFetch()
|
||||
|
||||
# 抓取网页(默认返回 Markdown 格式)
|
||||
result = tool.execute({
|
||||
"url": "https://example.com"
|
||||
})
|
||||
|
||||
# 抓取并转换为纯文本
|
||||
result = tool.execute({
|
||||
"url": "https://example.com",
|
||||
"extract_mode": "text",
|
||||
"max_chars": 5000
|
||||
})
|
||||
|
||||
if result.status == "success":
|
||||
data = result.result
|
||||
print(f"标题: {data['title']}")
|
||||
print(f"内容: {data['text']}")
|
||||
```
|
||||
|
||||
### 2. 在 Agent 中使用
|
||||
|
||||
工具会自动加载到 Agent 的工具列表中:
|
||||
|
||||
```python
|
||||
from agent.tools import WebFetch
|
||||
|
||||
tools = [
|
||||
WebFetch(),
|
||||
# ... 其他工具
|
||||
]
|
||||
|
||||
agent = create_agent(tools=tools)
|
||||
```
|
||||
|
||||
### 3. 通过 Skills 使用
|
||||
|
||||
创建一个 skill 文件 `skills/web-fetch/SKILL.md`:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: web-fetch
|
||||
emoji: 🌐
|
||||
always: true
|
||||
---
|
||||
|
||||
# 网页内容获取
|
||||
|
||||
使用 web_fetch 工具获取网页内容。
|
||||
|
||||
## 使用场景
|
||||
|
||||
- 需要读取某个网页的内容
|
||||
- 需要提取文章正文
|
||||
- 需要获取网页信息
|
||||
|
||||
## 示例
|
||||
|
||||
<example>
|
||||
用户: 帮我看看 https://example.com 这个网页讲了什么
|
||||
助手: <tool_use name="web_fetch">
|
||||
<url>https://example.com</url>
|
||||
<extract_mode>markdown</extract_mode>
|
||||
</tool_use>
|
||||
</example>
|
||||
```
|
||||
|
||||
## 参数说明
|
||||
|
||||
| 参数 | 类型 | 必需 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| `url` | string | ✅ | - | 要抓取的 URL(http/https) |
|
||||
| `extract_mode` | string | ❌ | `markdown` | 提取模式:`markdown` 或 `text` |
|
||||
| `max_chars` | integer | ❌ | `50000` | 最大返回字符数(最小 100) |
|
||||
|
||||
## 返回结果
|
||||
|
||||
```python
|
||||
{
|
||||
"url": "https://example.com", # 最终 URL(处理重定向后)
|
||||
"status": 200, # HTTP 状态码
|
||||
"content_type": "text/html", # 内容类型
|
||||
"title": "Example Domain", # 页面标题
|
||||
"extractor": "readability", # 提取器:readability/basic/raw
|
||||
"extract_mode": "markdown", # 提取模式
|
||||
"text": "# Example Domain\n\n...", # 提取的文本内容
|
||||
"length": 1234, # 文本长度
|
||||
"truncated": false, # 是否被截断
|
||||
"warning": "..." # 警告信息(如果有)
|
||||
}
|
||||
```
|
||||
|
||||
## 与其他搜索工具的对比
|
||||
|
||||
| 工具 | 需要 API Key | 功能 | 成本 |
|
||||
|------|-------------|------|------|
|
||||
| `web_fetch` | ❌ 不需要 | 抓取指定 URL 的内容 | 免费 |
|
||||
| `web_search` (Brave) | ✅ 需要 | 搜索引擎查询 | 有免费额度 |
|
||||
| `web_search` (Perplexity) | ✅ 需要 | AI 搜索 + 引用 | 付费 |
|
||||
| `browser` | ❌ 不需要 | 完整浏览器自动化 | 免费但资源占用大 |
|
||||
| `google_search` | ✅ 需要 | Google 搜索 API | 付费 |
|
||||
|
||||
## 技术细节
|
||||
|
||||
### 内容提取策略
|
||||
|
||||
1. **Readability 模式**(推荐)
|
||||
- 使用 Mozilla 的 Readability 算法
|
||||
- 自动识别文章主体内容
|
||||
- 过滤广告、导航栏等噪音
|
||||
|
||||
2. **Basic 模式**(降级)
|
||||
- 简单的 HTML 标签清理
|
||||
- 正则表达式提取文本
|
||||
- 适用于简单页面
|
||||
|
||||
3. **Raw 模式**
|
||||
- 用于非 HTML 内容
|
||||
- 直接返回原始内容
|
||||
|
||||
### 错误处理
|
||||
|
||||
工具会自动处理以下情况:
|
||||
- ✅ HTTP 重定向(最多 3 次)
|
||||
- ✅ 请求超时(默认 30 秒)
|
||||
- ✅ 网络错误自动重试
|
||||
- ✅ 内容提取失败降级
|
||||
|
||||
## 测试
|
||||
|
||||
运行测试脚本:
|
||||
|
||||
```bash
|
||||
cd agent/tools/web_fetch
|
||||
python test_web_fetch.py
|
||||
```
|
||||
|
||||
## 配置选项
|
||||
|
||||
在创建工具时可以传入配置:
|
||||
|
||||
```python
|
||||
tool = WebFetch(config={
|
||||
"timeout": 30, # 请求超时时间(秒)
|
||||
"max_redirects": 3, # 最大重定向次数
|
||||
"user_agent": "..." # 自定义 User-Agent
|
||||
})
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 为什么推荐安装 readability-lxml?
|
||||
|
||||
A: readability-lxml 提供更好的内容提取质量,能够:
|
||||
- 自动识别文章主体
|
||||
- 过滤广告和导航栏
|
||||
- 保留文章结构
|
||||
|
||||
没有它也能工作,但提取质量会下降。
|
||||
|
||||
### Q: 与 clawdbot 的 web_fetch 有什么区别?
|
||||
|
||||
A: 本实现参考了 clawdbot 的设计,主要区别:
|
||||
- Python 实现(clawdbot 是 TypeScript)
|
||||
- 简化了一些高级特性(如 Firecrawl 集成)
|
||||
- 保留了核心的免费功能
|
||||
- 更容易集成到现有项目
|
||||
|
||||
### Q: 可以抓取需要登录的页面吗?
|
||||
|
||||
A: 当前版本不支持。如需抓取需要登录的页面,请使用 `browser` 工具。
|
||||
|
||||
## 参考
|
||||
|
||||
- [Mozilla Readability](https://github.com/mozilla/readability)
|
||||
- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)
|
||||
@@ -1,3 +0,0 @@
|
||||
from .web_fetch import WebFetch
|
||||
|
||||
__all__ = ['WebFetch']
|
||||
@@ -1,47 +0,0 @@
|
||||
#!/bin/bash

# WebFetch 工具依赖安装脚本
# Installs the required dependency (requests), then the optional content
# extraction helpers (readability-lxml, html2text). A failure of the
# optional step is non-fatal; a failure of the required step aborts.

echo "=================================="
echo "WebFetch 工具依赖安装"
echo "=================================="
echo ""

# 检查 Python 版本
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "✓ Python 版本: $python_version"
echo ""

# 安装基础依赖
echo "📦 安装基础依赖..."
# Test the command directly rather than inspecting $? afterwards — the
# separate "$?" check breaks silently if a line is ever inserted between
# the command and the test.
if python3 -m pip install requests; then
    echo "✅ requests 安装成功"
else
    echo "❌ requests 安装失败"
    exit 1
fi

echo ""

# 安装推荐依赖
echo "📦 安装推荐依赖(提升内容提取质量)..."
if python3 -m pip install readability-lxml html2text; then
    echo "✅ readability-lxml 和 html2text 安装成功"
else
    echo "⚠️ 推荐依赖安装失败,但不影响基础功能"
fi

echo ""
echo "=================================="
echo "安装完成!"
echo "=================================="
echo ""
echo "运行测试:"
echo " python3 agent/tools/web_fetch/test_web_fetch.py"
echo ""
|
||||
@@ -1,365 +0,0 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from URLs
|
||||
Supports HTML to Markdown/Text conversion using Mozilla's Readability
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
    """Tool for fetching and extracting readable content from web pages.

    Fetches a URL with a retrying requests session, then extracts readable
    text via readability-lxml when available, falling back to a basic
    regex-based HTML stripper. Non-HTML responses are returned raw.
    """

    name: str = "web_fetch"
    description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."

    params: dict = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "HTTP or HTTPS URL to fetch"
            },
            "extract_mode": {
                "type": "string",
                "description": "Extraction mode: 'markdown' (default) or 'text'",
                "enum": ["markdown", "text"],
                "default": "markdown"
            },
            "max_chars": {
                "type": "integer",
                "description": "Maximum characters to return (default: 50000)",
                "minimum": 100,
                "default": 50000
            }
        },
        "required": ["url"]
    }

    def __init__(self, config: dict = None):
        """Initialize the tool.

        :param config: Optional dict; recognizes "timeout" (seconds),
                       "max_redirects", and "user_agent".
        """
        self.config = config or {}
        self.timeout = self.config.get("timeout", 20)
        self.max_redirects = self.config.get("max_redirects", 3)
        self.user_agent = self.config.get(
            "user_agent",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
        )

        # Setup session with retry strategy
        self.session = self._create_session()

        # Check if readability-lxml is available (extraction quality hint)
        self.readability_available = self._check_readability()

    def _create_session(self) -> requests.Session:
        """Create a requests session with retry strategy."""
        session = requests.Session()

        # Retry strategy - handles failed requests (connection errors and
        # the listed HTTP status codes), not redirects
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"]
        )

        # HTTPAdapter handles retries; redirects are followed by requests
        # itself (allow_redirects) and bounded by session.max_redirects below
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        # Set max redirects on session
        session.max_redirects = self.max_redirects

        return session

    def _check_readability(self) -> bool:
        """Check if readability-lxml is available."""
        try:
            from readability import Document
            return True
        except ImportError:
            logger.warning(
                "readability-lxml not installed. Install with: pip install readability-lxml\n"
                "Falling back to basic HTML extraction."
            )
            return False

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute web fetch operation

        :param args: Contains url, extract_mode, and max_chars parameters
        :return: Extracted content or error message
        """
        url = args.get("url", "").strip()
        extract_mode = args.get("extract_mode", "markdown").lower()
        max_chars = args.get("max_chars", 50000)

        if not url:
            return ToolResult.fail("Error: url parameter is required")

        # Validate URL
        if not self._is_valid_url(url):
            return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")

        # Validate extract_mode (invalid values silently fall back to markdown)
        if extract_mode not in ["markdown", "text"]:
            extract_mode = "markdown"

        # Validate max_chars (invalid values silently fall back to the default)
        if not isinstance(max_chars, int) or max_chars < 100:
            max_chars = 50000

        try:
            # Fetch the URL
            response = self._fetch_url(url)

            # Extract content (response.url reflects the post-redirect URL)
            result = self._extract_content(
                html=response.text,
                url=response.url,
                status_code=response.status_code,
                content_type=response.headers.get("content-type", ""),
                extract_mode=extract_mode,
                max_chars=max_chars
            )

            return ToolResult.success(result)

        except requests.exceptions.Timeout:
            return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
        except requests.exceptions.TooManyRedirects:
            return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
        except requests.exceptions.RequestException as e:
            return ToolResult.fail(f"Error fetching URL: {str(e)}")
        except Exception as e:
            logger.error(f"Web fetch error: {e}", exc_info=True)
            return ToolResult.fail(f"Error: {str(e)}")

    def _is_valid_url(self, url: str) -> bool:
        """Validate URL format (http/https scheme with a non-empty host)."""
        try:
            result = urlparse(url)
            return result.scheme in ["http", "https"] and bool(result.netloc)
        except Exception:
            return False

    def _fetch_url(self, url: str) -> requests.Response:
        """
        Fetch URL with proper headers and error handling

        :param url: URL to fetch
        :return: Response object
        :raises requests.exceptions.RequestException: on network / HTTP errors
        """
        headers = {
            "User-Agent": self.user_agent,
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
            "Accept-Encoding": "gzip, deflate",
            "Connection": "keep-alive",
        }

        # Note: requests follows redirects automatically; the redirect cap
        # is session.max_redirects, set in _create_session
        response = self.session.get(
            url,
            headers=headers,
            timeout=self.timeout,
            allow_redirects=True
        )

        # Raise for 4xx/5xx so execute() can report a clean error
        response.raise_for_status()
        return response

    def _extract_content(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """
        Extract readable content from HTML

        :param html: HTML content
        :param url: Original URL
        :param status_code: HTTP status code
        :param content_type: Content type header
        :param extract_mode: 'markdown' or 'text'
        :param max_chars: Maximum characters to return
        :return: Extracted content and metadata
        """
        # Check content type; non-HTML payloads are returned as-is ("raw")
        if "text/html" not in content_type.lower():
            # Non-HTML content
            text = html[:max_chars]
            truncated = len(html) > max_chars

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "extractor": "raw",
                "text": text,
                "length": len(text),
                "truncated": truncated,
                "message": f"Non-HTML content (type: {content_type})"
            }

        # Extract readable content from HTML, preferring Readability
        if self.readability_available:
            return self._extract_with_readability(
                html, url, status_code, content_type, extract_mode, max_chars
            )
        else:
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_with_readability(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Extract content using Mozilla's Readability (readability-lxml)."""
        try:
            from readability import Document

            # Parse with Readability: isolates the main article body
            doc = Document(html)
            title = doc.title()
            content_html = doc.summary()

            # Convert to markdown or text
            if extract_mode == "markdown":
                text = self._html_to_markdown(content_html)
            else:
                text = self._html_to_text(content_html)

            # Truncate if needed
            truncated = len(text) > max_chars
            if truncated:
                text = text[:max_chars]

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "title": title,
                "extractor": "readability",
                "extract_mode": extract_mode,
                "text": text,
                "length": len(text),
                "truncated": truncated
            }

        except Exception as e:
            logger.warning(f"Readability extraction failed: {e}")
            # Fallback to basic extraction so the caller still gets content
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_basic(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Basic HTML extraction without Readability (regex tag stripping)."""
        # Extract title
        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
        title = title_match.group(1).strip() if title_match else "Untitled"

        # Remove script and style tags (contents included)
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Remove HTML tags
        text = re.sub(r'<[^>]+>', ' ', text)

        # Clean up whitespace (collapses all runs, including newlines)
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        # Truncate if needed
        truncated = len(text) > max_chars
        if truncated:
            text = text[:max_chars]

        return {
            "url": url,
            "status": status_code,
            "content_type": content_type,
            "title": title,
            "extractor": "basic",
            "extract_mode": extract_mode,
            "text": text,
            "length": len(text),
            "truncated": truncated,
            "warning": "Using basic extraction. Install readability-lxml for better results."
        }

    def _html_to_markdown(self, html: str) -> str:
        """Convert HTML to Markdown (html2text if installed, else plain text)."""
        try:
            # Try to use html2text if available
            import html2text
            h = html2text.HTML2Text()
            h.ignore_links = False
            h.ignore_images = False
            h.body_width = 0  # Don't wrap lines
            return h.handle(html)
        except ImportError:
            # Fallback to basic conversion
            return self._html_to_text(html)

    def _html_to_text(self, html: str) -> str:
        """Convert HTML to plain text, preserving paragraph breaks."""
        # Remove script and style tags
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Convert common tags to text equivalents
        text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
        text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
        text = re.sub(r'<h[1-6][^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</h[1-6]>', '\n', text, flags=re.IGNORECASE)

        # Remove all other HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Decode HTML entities (&amp;, &lt;, ...)
        import html
        text = html.unescape(text)

        # Clean up whitespace (collapse blank-line runs, squeeze spaces)
        text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
        text = re.sub(r' +', ' ', text)
        text = text.strip()

        return text

    def close(self):
        """Close the underlying HTTP session."""
        if hasattr(self, 'session'):
            self.session.close()
|
||||
@@ -283,14 +283,53 @@ class AgentBridge:
|
||||
from agent.memory import MemoryManager, MemoryConfig
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
workspace_root=workspace_root,
|
||||
embedding_provider="local", # Use local embedding (no API key needed)
|
||||
embedding_model="all-MiniLM-L6-v2"
|
||||
)
|
||||
# 从 config.json 读取 OpenAI 配置
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
|
||||
# Create memory manager with the config
|
||||
memory_manager = MemoryManager(memory_config)
|
||||
# 尝试初始化 OpenAI embedding provider
|
||||
embedding_provider = None
|
||||
if openai_api_key:
|
||||
try:
|
||||
from agent.memory import create_embedding_provider
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
logger.info(f"[AgentBridge] OpenAI embedding initialized")
|
||||
except Exception as embed_error:
|
||||
logger.warning(f"[AgentBridge] OpenAI embedding failed: {embed_error}")
|
||||
logger.info(f"[AgentBridge] Using keyword-only search")
|
||||
else:
|
||||
logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search")
|
||||
|
||||
# 创建 memory config
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
# 创建 memory manager
|
||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
|
||||
# 初始化时执行一次 sync,确保数据库有数据
|
||||
import asyncio
|
||||
try:
|
||||
# 尝试在当前事件循环中执行
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建任务
|
||||
asyncio.create_task(memory_manager.sync())
|
||||
logger.info("[AgentBridge] Memory sync scheduled")
|
||||
else:
|
||||
# 如果没有运行的循环,直接执行
|
||||
loop.run_until_complete(memory_manager.sync())
|
||||
logger.info("[AgentBridge] Memory synced successfully")
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建新的
|
||||
asyncio.run(memory_manager.sync())
|
||||
logger.info("[AgentBridge] Memory synced successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Memory sync failed: {e}")
|
||||
|
||||
# Create memory tools
|
||||
memory_tools = [
|
||||
@@ -420,23 +459,65 @@ class AgentBridge:
|
||||
memory_tools = []
|
||||
|
||||
try:
|
||||
from agent.memory import MemoryManager, MemoryConfig
|
||||
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
workspace_root=workspace_root,
|
||||
embedding_provider="local",
|
||||
embedding_model="all-MiniLM-L6-v2"
|
||||
)
|
||||
# 从 config.json 读取 OpenAI 配置
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
|
||||
# 尝试初始化 OpenAI embedding provider
|
||||
embedding_provider = None
|
||||
if openai_api_key:
|
||||
try:
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
logger.info(f"[AgentBridge] OpenAI embedding initialized for session {session_id}")
|
||||
except Exception as embed_error:
|
||||
logger.warning(f"[AgentBridge] OpenAI embedding failed for session {session_id}: {embed_error}")
|
||||
logger.info(f"[AgentBridge] Using keyword-only search for session {session_id}")
|
||||
else:
|
||||
logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search for session {session_id}")
|
||||
|
||||
# 创建 memory config
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
# 创建 memory manager
|
||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
|
||||
# 初始化时执行一次 sync,确保数据库有数据
|
||||
import asyncio
|
||||
try:
|
||||
# 尝试在当前事件循环中执行
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建任务
|
||||
asyncio.create_task(memory_manager.sync())
|
||||
logger.info(f"[AgentBridge] Memory sync scheduled for session {session_id}")
|
||||
else:
|
||||
# 如果没有运行的循环,直接执行
|
||||
loop.run_until_complete(memory_manager.sync())
|
||||
logger.info(f"[AgentBridge] Memory synced successfully for session {session_id}")
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建新的
|
||||
asyncio.run(memory_manager.sync())
|
||||
logger.info(f"[AgentBridge] Memory synced successfully for session {session_id}")
|
||||
except Exception as sync_error:
|
||||
logger.warning(f"[AgentBridge] Memory sync failed for session {session_id}: {sync_error}")
|
||||
|
||||
memory_manager = MemoryManager(memory_config)
|
||||
memory_tools = [
|
||||
MemorySearchTool(memory_manager),
|
||||
MemoryGetTool(memory_manager)
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
|
||||
logger.warning(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
|
||||
import traceback
|
||||
logger.warning(f"[AgentBridge] Memory init traceback: {traceback.format_exc()}")
|
||||
|
||||
# Load tools
|
||||
from agent.tools import ToolManager
|
||||
|
||||
@@ -158,7 +158,7 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
|
||||
# Build request parameters for ChatCompletion
|
||||
request_params = {
|
||||
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
|
||||
"model": kwargs.get("model", conf().get("model") or "gpt-4.1"),
|
||||
"messages": messages,
|
||||
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
|
||||
"top_p": kwargs.get("top_p", 1),
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
## 插件说明
|
||||
|
||||
利用百度UNIT实现智能对话
|
||||
|
||||
- 1.解决问题:chatgpt无法处理的指令,交给百度UNIT处理如:天气,日期时间,数学运算等
|
||||
- 2.如问时间:现在几点钟,今天几号
|
||||
- 3.如问天气:明天广州天气怎么样,这个周末深圳会不会下雨
|
||||
- 4.如问数学运算:23+45=多少,100-23=多少,35转化为二进制是多少?
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 获取apikey
|
||||
|
||||
在百度UNIT官网上自己创建应用,申请百度机器人,可以把预先训练好的模型导入到自己的应用中,
|
||||
|
||||
参见 https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087095fc10c8377aaf ,并在百度智能云控制台( https://console.bce.baidu.com/ )的 AI 平台中申请
|
||||
|
||||
### 配置文件
|
||||
|
||||
将文件夹中`config.json.template`复制为`config.json`。
|
||||
|
||||
在其中填写百度UNIT官网上获取应用的API Key和Secret Key
|
||||
|
||||
``` json
|
||||
{
|
||||
"service_id": "s...", #"机器人ID"
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
}
|
||||
```
|
||||
@@ -1 +0,0 @@
|
||||
from .bdunit import *
|
||||
@@ -1,252 +0,0 @@
|
||||
# encoding:utf-8
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from uuid import getnode as get_mac
|
||||
|
||||
import requests
|
||||
|
||||
import plugins
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from plugins import *
|
||||
|
||||
"""利用百度UNIT实现智能对话
|
||||
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
|
||||
"""
|
||||
|
||||
|
||||
@plugins.register(
    name="BDunit",
    desire_priority=0,
    hidden=True,
    desc="Baidu unit bot system",
    version="0.1",
    author="jackson",
)
class BDunit(Plugin):
    """Answer queries via Baidu UNIT when it recognizes an intent.

    If UNIT matches an intent for the incoming text message, its reply is
    returned and default handling is skipped; otherwise the event continues
    on to the next plugin / default logic.
    """

    def __init__(self):
        super().__init__()
        try:
            conf = super().load_config()
            if not conf:
                raise Exception("config.json not found")
            self.service_id = conf["service_id"]
            self.api_key = conf["api_key"]
            self.secret_key = conf["secret_key"]
            # Fetch the OAuth token once at startup; any failure aborts init.
            self.access_token = self.get_token()
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            logger.info("[BDunit] inited")
        except Exception as e:
            # logger.warn is a deprecated alias of logger.warning
            logger.warning("[BDunit] init failed, ignore ")
            raise e

    def on_handle_context(self, e_context: EventContext):
        """Handle a text context: reply via UNIT if an intent is matched."""
        if e_context["context"].type != ContextType.TEXT:
            return

        content = e_context["context"].content
        logger.debug("[BDunit] on_handle_context. content: %s" % content)
        parsed = self.getUnit2(content)
        intent = self.getIntent(parsed)
        if intent:  # an intent was matched
            logger.debug("[BDunit] Baidu_AI Intent= %s", intent)
            reply = Reply()
            reply.type = ReplyType.TEXT
            reply.content = self.getSay(parsed)
            e_context["reply"] = reply
            # Stop the event chain and skip the default context handling.
            e_context.action = EventAction.BREAK_PASS
        else:
            # Let the next plugin (or the default logic) handle it.
            e_context.action = EventAction.CONTINUE

    def get_help_text(self, **kwargs):
        """Return the user-facing help text for this plugin."""
        help_text = "本插件会处理询问实时日期时间,天气,数学运算等问题,这些技能由您的百度智能对话UNIT决定\n"
        return help_text

    def get_token(self):
        """Obtain an access_token for the Baidu UNIT API.

        Uses the configured api_key / secret_key with the
        ``client_credentials`` OAuth grant.

        Returns:
            string: access_token

        Raises:
            ValueError: if the OAuth response contains no access_token
                (e.g. wrong credentials), instead of an opaque KeyError.
        """
        url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
        payload = ""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}

        response = requests.request("POST", url, headers=headers, data=payload)

        data = response.json()
        if "access_token" not in data:
            raise ValueError("[BDunit] failed to obtain access_token: %s" % data)
        return data["access_token"]

    def getUnit(self, query):
        """NLU parse, API version 3.0.

        :param query: the user's query string
        :returns: the UNIT parse result dict, or None on failure
        """
        url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
        request = {
            "query": query,
            # MAC address truncated to UNIT's 32-char user_id limit
            "user_id": str(get_mac())[:32],
            "terminal_id": "88888",
        }
        body = {
            "log_id": str(uuid.uuid1()),
            "version": "3.0",
            "service_id": self.service_id,
            "session_id": str(uuid.uuid1()),
            "request": request,
        }
        try:
            headers = {"Content-Type": "application/json"}
            response = requests.post(url, json=body, headers=headers)
            return json.loads(response.text)
        except Exception as e:
            # Best-effort: callers treat None as "no parse", but log the cause
            # instead of swallowing it silently.
            logger.warning("[BDunit] getUnit failed: %s" % e)
            return None

    def getUnit2(self, query):
        """NLU parse, API version 2.0.

        :param query: the user's query string
        :returns: the UNIT parse result dict, or None on failure
        """
        url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
        request = {"query": query, "user_id": str(get_mac())[:32]}
        body = {
            "log_id": str(uuid.uuid1()),
            "version": "2.0",
            "service_id": self.service_id,
            "session_id": str(uuid.uuid1()),
            "request": request,
        }
        try:
            headers = {"Content-Type": "application/json"}
            response = requests.post(url, json=body, headers=headers)
            return json.loads(response.text)
        except Exception as e:
            logger.warning("[BDunit] getUnit2 failed: %s" % e)
            return None

    def getIntent(self, parsed):
        """Extract the matched intent.

        :param parsed: the UNIT parse result
        :returns: the intent name, or "" when none was matched
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            try:
                return parsed["result"]["response_list"][0]["schema"]["intent"]
            except Exception as e:
                logger.warning(e)
                return ""
        else:
            return ""

    def hasIntent(self, parsed, intent):
        """Check whether the parse result contains a given intent.

        :param parsed: the UNIT parse result
        :param intent: the intent name to look for
        :returns: True if present, False otherwise
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                    return True
            return False
        else:
            return False

    def getSlots(self, parsed, intent=""):
        """Extract all slots of an intent.

        :param parsed: the UNIT parse result
        :param intent: the intent name; "" means "the first response"
        :returns: list of slots. Filter by the ``name`` attribute and read
            the value from the ``normalized_word`` attribute.
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            if intent == "":
                try:
                    return parsed["result"]["response_list"][0]["schema"]["slots"]
                except Exception as e:
                    logger.warning(e)
                    return []
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
                    return response["schema"]["slots"]
            return []
        else:
            return []

    def getSlotWords(self, parsed, intent, name):
        """Collect the values that hit a given slot.

        :param parsed: the UNIT parse result
        :param intent: the intent name
        :param name: the slot name
        :returns: list of normalized values matching that slot
        """
        slots = self.getSlots(parsed, intent)
        words = []
        for slot in slots:
            if slot["name"] == name:
                words.append(slot["normalized_word"])
        return words

    def getSayByConfidence(self, parsed):
        """Return the reply text of the highest-confidence response.

        :param parsed: the UNIT parse result
        :returns: the UNIT reply text, or "" when unavailable
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            answer = {}
            for response in response_list:
                if (
                    "schema" in response
                    and "intent_confidence" in response["schema"]
                    and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
                ):
                    answer = response
            # Bug fix: previously answer["action_list"] raised KeyError when
            # no response carried an intent_confidence (answer stayed {}).
            if not answer:
                return ""
            try:
                return answer["action_list"][0]["say"]
            except Exception as e:
                logger.warning(e)
                return ""
        else:
            return ""

    def getSay(self, parsed, intent=""):
        """Extract the UNIT reply text.

        :param parsed: the UNIT parse result
        :param intent: the intent name; "" means "the first response"
        :returns: the UNIT reply text, or "" when unavailable
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            if intent == "":
                try:
                    return response_list[0]["action_list"][0]["say"]
                except Exception as e:
                    logger.warning(e)
                    return ""
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                    try:
                        return response["action_list"][0]["say"]
                    except Exception as e:
                        logger.warning(e)
                        return ""
            return ""
        else:
            return ""
|
||||
@@ -1,5 +0,0 @@
|
||||
{
|
||||
"service_id": "s...",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
}
|
||||
@@ -8,17 +8,6 @@
|
||||
"reply_filter": true,
|
||||
"reply_action": "ignore"
|
||||
},
|
||||
"tool": {
|
||||
"tools": [
|
||||
"url-get",
|
||||
"meteo-weather"
|
||||
],
|
||||
"kwargs": {
|
||||
"top_k_results": 2,
|
||||
"no_default": false,
|
||||
"model_name": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
"linkai": {
|
||||
"group_app_map": {
|
||||
"测试群1": "default",
|
||||
|
||||
@@ -214,6 +214,19 @@ These files contain established best practices for effective skill design.
|
||||
|
||||
To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.
|
||||
|
||||
**Available Base Tools**:
|
||||
|
||||
The agent has access to these core tools that you can leverage in your skill:
|
||||
- **bash**: Execute shell commands (use for curl, ls, grep, sed, awk, bc for calculations, etc.)
|
||||
- **read**: Read file contents
|
||||
- **write**: Write files
|
||||
- **edit**: Edit files with search/replace
|
||||
|
||||
**Minimize Dependencies**:
|
||||
- ✅ **Prefer bash + curl** for HTTP API calls (no Python dependencies)
|
||||
- ✅ **Use bash tools** (grep, sed, awk) for text processing
|
||||
- ✅ **Keep scripts simple** - if bash can do it, no need for Python (document packages/versions if Python is used)
|
||||
|
||||
**Important Guidelines**:
|
||||
- **scripts/**: Only create scripts that will be executed. Test all scripts before including.
|
||||
- **references/**: ONLY create if documentation is too large for SKILL.md (>500 lines). Most skills don't need this.
|
||||
|
||||
49
skills/web-fetch/SKILL.md
Normal file
49
skills/web-fetch/SKILL.md
Normal file
@@ -0,0 +1,49 @@
|
||||
---
|
||||
name: web-fetch
|
||||
description: Fetch and extract readable content from web pages
|
||||
homepage: https://github.com/zhayujie/chatgpt-on-wechat
|
||||
metadata:
|
||||
emoji: 🌐
|
||||
requires:
|
||||
bins: ["curl"]
|
||||
---
|
||||
|
||||
# Web Fetch
|
||||
|
||||
Fetch and extract readable content from web pages using curl and basic text processing.
|
||||
|
||||
## Usage
|
||||
|
||||
Use the provided script to fetch a URL and extract its content:
|
||||
|
||||
```bash
|
||||
bash scripts/fetch.sh <url> [output_file]
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `url`: The HTTP/HTTPS URL to fetch (required)
|
||||
- `output_file`: Optional file to save the output (default: stdout)
|
||||
|
||||
**Returns:**
|
||||
- Extracted page content with title and text
|
||||
|
||||
## Examples
|
||||
|
||||
### Fetch a web page
|
||||
```bash
|
||||
bash scripts/fetch.sh "https://example.com"
|
||||
```
|
||||
|
||||
### Save to file
|
||||
```bash
|
||||
bash scripts/fetch.sh "https://example.com" output.txt
|
||||
cat output.txt
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Uses curl for HTTP requests (timeout: 20s)
|
||||
- Extracts title and basic text content
|
||||
- Removes HTML tags and scripts
|
||||
- Works with any standard web page
|
||||
- No external dependencies beyond curl
|
||||
54
skills/web-fetch/scripts/fetch.sh
Executable file
54
skills/web-fetch/scripts/fetch.sh
Executable file
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env bash
# Fetch and extract readable content from a web page.
#
# Usage: bash fetch.sh <url> [output_file]
#   url          HTTP/HTTPS URL to fetch (required)
#   output_file  optional path to save the result (default: stdout)

set -euo pipefail

url="${1:-}"
output_file="${2:-}"

if [ -z "$url" ]; then
    echo "Error: URL is required"
    echo "Usage: bash fetch.sh <url> [output_file]"
    exit 1
fi

# Validate the URL scheme before handing it to curl
if [[ ! "$url" =~ ^https?:// ]]; then
    echo "Error: Invalid URL (must start with http:// or https://)"
    exit 1
fi

# Fetch the page with curl.
# Bug fix: do NOT merge stderr into the captured output (2>&1) — that
# polluted the extracted content with curl diagnostics on partial failures.
html=$(curl -sS -L --max-time 20 \
    -H "User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" \
    -H "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" \
    "$url") || {
    echo "Error: Failed to fetch URL: $url"
    exit 1
}

# Extract the title. Flatten newlines first so a <title> pair that spans
# multiple lines is still matched by the single-line grep below.
title=$(echo "$html" | tr '\n' ' ' | grep -oP '(?<=<title>).*?(?=</title>)' | head -1 || echo "Untitled")

# Remove script and style blocks line-wise with sed address ranges.
# Bug fix: the previous single-line substitutions left multi-line
# <script>/<style> blocks untouched.
text=$(echo "$html" | sed -e '/<script/I,/<\/script>/Id' -e '/<style/I,/<\/style>/Id')

# Remove remaining HTML tags
text=$(echo "$text" | sed 's/<[^>]*>//g')

# Clean up whitespace (squeeze runs of spaces, trim each line)
text=$(echo "$text" | tr -s ' ' | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')

# Format output
result="Title: $title

Content:
$text"

# Output to file or stdout
if [ -n "$output_file" ]; then
    echo "$result" > "$output_file"
    echo "Content saved to: $output_file"
else
    echo "$result"
fi
|
||||
Reference in New Issue
Block a user