feat: improve the memory system

This commit is contained in:
zhayujie
2026-02-01 17:04:46 +08:00
parent 4a1fae3cb4
commit c693e39196
29 changed files with 373 additions and 1596 deletions

View File

@@ -6,5 +6,6 @@ Provides long-term memory capabilities with hybrid search (vector + keyword)
from agent.memory.manager import MemoryManager
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
from agent.memory.embedding import create_embedding_provider
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']

View File

@@ -41,6 +41,10 @@ class MemoryConfig:
enable_auto_sync: bool = True
sync_on_search: bool = True
# Memory flush config (独立于模型 context window)
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
def get_workspace(self) -> Path:
"""Get workspace root directory"""
return Path(self.workspace_root)

View File

@@ -4,20 +4,19 @@ Embedding providers for memory
Supports OpenAI and local embedding models
"""
from typing import List, Optional
from abc import ABC, abstractmethod
import hashlib
import json
from abc import ABC, abstractmethod
from typing import List, Optional
class EmbeddingProvider(ABC):
"""Base class for embedding providers"""
@abstractmethod
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
pass
@abstractmethod
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
@@ -31,7 +30,7 @@ class EmbeddingProvider(ABC):
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI embedding provider"""
"""OpenAI embedding provider using REST API"""
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
"""
@@ -45,87 +44,58 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
self.model = model
self.api_key = api_key
self.api_base = api_base or "https://api.openai.com/v1"
# Lazy import to avoid dependency issues
try:
from openai import OpenAI
self.client = OpenAI(api_key=api_key, base_url=api_base)
except ImportError:
raise ImportError("OpenAI package not installed. Install with: pip install openai")
if not self.api_key:
raise ValueError("OpenAI API key is required")
# Set dimensions based on model
self._dimensions = 1536 if "small" in model else 3072
def _call_api(self, input_data):
"""Call OpenAI embedding API using requests"""
import requests
url = f"{self.api_base}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
data = {
"input": input_data,
"model": self.model
}
response = requests.post(url, headers=headers, json=data, timeout=30)
response.raise_for_status()
return response.json()
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
response = self.client.embeddings.create(
input=text,
model=self.model
)
return response.data[0].embedding
result = self._call_api(text)
return result["data"][0]["embedding"]
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
if not texts:
return []
response = self.client.embeddings.create(
input=texts,
model=self.model
)
return [item.embedding for item in response.data]
result = self._call_api(texts)
return [item["embedding"] for item in result["data"]]
@property
def dimensions(self) -> int:
return self._dimensions
class LocalEmbeddingProvider(EmbeddingProvider):
    """Local embedding provider using sentence-transformers.

    Runs entirely on the local machine (no API key or network needed).
    Requires the optional ``sentence-transformers`` package; the import is
    deferred to ``__init__`` so the module can be loaded without it.
    """

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        """
        Initialize local embedding provider

        Args:
            model: Model name from sentence-transformers

        Raises:
            ImportError: if sentence-transformers is not installed.
        """
        self.model_name = model
        try:
            # Lazy import: sentence-transformers is an optional, heavy dependency.
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model)
            # Dimension is model-specific; query it once instead of hard-coding.
            self._dimensions = self.model.get_sentence_embedding_dimension()
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )

    def embed(self, text: str) -> List[float]:
        """Generate embedding for text"""
        # convert_to_numpy then tolist() yields plain Python floats for JSON/storage.
        embedding = self.model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts"""
        if not texts:
            return []
        embeddings = self.model.encode(texts, convert_to_numpy=True)
        return embeddings.tolist()

    @property
    def dimensions(self) -> int:
        # Embedding vector length, as reported by the loaded model.
        return self._dimensions
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
class EmbeddingCache:
"""Cache for embeddings to avoid recomputation"""
def __init__(self):
self.cache = {}
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
"""Get cached embedding"""
key = self._compute_key(text, provider, model)
@@ -156,20 +126,23 @@ def create_embedding_provider(
"""
Factory function to create embedding provider
Only supports OpenAI embedding via REST API.
If initialization fails, caller should fall back to keyword-only search.
Args:
provider: Provider name ("openai" or "local")
model: Model name (provider-specific)
api_key: API key for remote providers
api_base: API base URL for remote providers
provider: Provider name (only "openai" is supported)
model: Model name (default: text-embedding-3-small)
api_key: OpenAI API key (required)
api_base: API base URL (default: https://api.openai.com/v1)
Returns:
EmbeddingProvider instance
Raises:
ValueError: If provider is not "openai" or api_key is missing
"""
if provider == "openai":
model = model or "text-embedding-3-small"
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
elif provider == "local":
model = model or "all-MiniLM-L6-v2"
return LocalEmbeddingProvider(model=model)
else:
raise ValueError(f"Unknown embedding provider: {provider}")
if provider != "openai":
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
model = model or "text-embedding-3-small"
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)

View File

@@ -70,8 +70,9 @@ class MemoryManager:
except Exception as e:
# Embedding provider failed, but that's OK
# We can still use keyword search and file operations
print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
print(f" Memory will work with keyword search only (no semantic search)")
from common.log import logger
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
# Initialize memory flush manager
workspace_dir = self.config.get_workspace()
@@ -135,13 +136,19 @@ class MemoryManager:
# Perform vector search (if embedding provider available)
vector_results = []
if self.embedding_provider:
query_embedding = self.embedding_provider.embed(query)
vector_results = self.storage.search_vector(
query_embedding=query_embedding,
user_id=user_id,
scopes=scopes,
limit=max_results * 2 # Get more candidates for merging
)
try:
from common.log import logger
query_embedding = self.embedding_provider.embed(query)
vector_results = self.storage.search_vector(
query_embedding=query_embedding,
user_id=user_id,
scopes=scopes,
limit=max_results * 2 # Get more candidates for merging
)
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
except Exception as e:
from common.log import logger
logger.warning(f"[MemoryManager] Vector search failed: {e}")
# Perform keyword search
keyword_results = self.storage.search_keyword(
@@ -150,6 +157,8 @@ class MemoryManager:
scopes=scopes,
limit=max_results * 2
)
from common.log import logger
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
# Merge results
merged = self._merge_results(
@@ -356,30 +365,30 @@ class MemoryManager:
def should_flush_memory(
self,
current_tokens: int,
context_window: int = 128000,
reserve_tokens: int = 20000,
soft_threshold: int = 4000
current_tokens: int = 0
) -> bool:
"""
Check if memory flush should be triggered
独立的 flush 触发机制,不依赖模型 context window。
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
Args:
current_tokens: Current session token count
context_window: Model's context window size (default: 128K)
reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
soft_threshold: Trigger N tokens before threshold (default: 4K)
Returns:
True if memory flush should run
"""
return self.flush_manager.should_flush(
current_tokens=current_tokens,
context_window=context_window,
reserve_tokens=reserve_tokens,
soft_threshold=soft_threshold
token_threshold=self.config.flush_token_threshold,
turn_threshold=self.config.flush_turn_threshold
)
def increment_turn(self):
"""增加对话轮数计数(每次用户消息+AI回复算一轮"""
self.flush_manager.increment_turn()
async def execute_memory_flush(
self,
agent_executor,

View File

@@ -41,46 +41,42 @@ class MemoryFlushManager:
# Tracking
self.last_flush_token_count: Optional[int] = None
self.last_flush_timestamp: Optional[datetime] = None
self.turn_count: int = 0 # 对话轮数计数器
def should_flush(
self,
current_tokens: int,
context_window: int,
reserve_tokens: int = 20000,
soft_threshold: int = 4000
current_tokens: int = 0,
token_threshold: int = 50000,
turn_threshold: int = 20
) -> bool:
"""
Determine if memory flush should be triggered
Similar to clawdbot's shouldRunMemoryFlush logic:
threshold = contextWindow - reserveTokens - softThreshold
独立的 flush 触发机制,不依赖模型 context window:
- Token 阈值: 达到 50K tokens 时触发
- 轮次阈值: 达到 20 轮对话时触发
Args:
current_tokens: Current session token count
context_window: Model's context window size
reserve_tokens: Reserve tokens for compaction overhead
soft_threshold: Trigger flush N tokens before threshold
token_threshold: Token threshold to trigger flush (default: 50K)
turn_threshold: Turn threshold to trigger flush (default: 20)
Returns:
True if flush should run
"""
if current_tokens <= 0:
return False
# 检查 token 阈值
if current_tokens > 0 and current_tokens >= token_threshold:
# 避免重复 flush
if self.last_flush_token_count is not None:
if current_tokens <= self.last_flush_token_count + 5000:
return False
return True
threshold = max(0, context_window - reserve_tokens - soft_threshold)
if threshold <= 0:
return False
# 检查轮次阈值
if self.turn_count >= turn_threshold:
return True
# Check if we've crossed the threshold
if current_tokens < threshold:
return False
# Avoid duplicate flush in same compaction cycle
if self.last_flush_token_count is not None:
if current_tokens <= self.last_flush_token_count + soft_threshold:
return False
return True
return False
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
"""
@@ -130,7 +126,12 @@ class MemoryFlushManager:
f"Pre-compaction memory flush. "
f"Store durable memories now (use memory/{today}.md for daily notes; "
f"create memory/ if needed). "
f"If nothing to store, reply with NO_REPLY."
f"\n\n"
f"重要提示:\n"
f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
)
def create_flush_system_prompt(self) -> str:
@@ -142,6 +143,20 @@ class MemoryFlushManager:
return (
"Pre-compaction memory flush turn. "
"The session is near auto-compaction; capture durable memories to disk. "
"\n\n"
"记忆写入原则:\n"
"1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens\n"
" - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
" - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
"\n"
"2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
" - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
"\n"
"3. 判断标准:\n"
" - 这个信息未来会经常用到吗?→ MEMORY.md\n"
" - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
" - 这是临时性的、不重要的内容吗?→ 不记录\n"
"\n"
"You may reply, but usually NO_REPLY is correct."
)
@@ -180,6 +195,7 @@ class MemoryFlushManager:
# Track flush
self.last_flush_token_count = current_tokens
self.last_flush_timestamp = datetime.now()
self.turn_count = 0 # 重置轮数计数器
return True
@@ -187,6 +203,10 @@ class MemoryFlushManager:
print(f"Memory flush failed: {e}")
return False
def increment_turn(self):
"""增加对话轮数计数"""
self.turn_count += 1
def get_status(self) -> dict:
"""Get memory flush status"""
return {

View File

@@ -179,8 +179,8 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
tool_map = {}
tool_descriptions = {
"read": "读取文件内容",
"write": "创建或覆盖文件",
"edit": "精确编辑文件内容",
"write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit",
"edit": "精确编辑文件(追加、修改、删除部分内容",
"ls": "列出目录内容",
"grep": "在文件中搜索内容",
"find": "按照模式查找文件",
@@ -305,17 +305,18 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
"",
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
"",
"1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索",
"2. 然后使用 `memory_get` 只拉取需要的行",
"3. 如果搜索后仍然信心不足,告诉用户你已经检查过了",
"1. 不确定信息位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
"2. 已知文件和大致位置 → 直接用 `memory_get` 读取相应的行",
"3. search 无结果 → 尝试用 `memory_get` 读取最近两天的记忆文件",
"",
"**记忆文件结构**:",
"- `MEMORY.md`: 长期记忆,包含重要的背景信息",
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件",
"- `MEMORY.md`: 长期记忆(已自动加载,无需主动读取)",
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
"",
"**使用原则**:",
"- 自然使用记忆,就像你本来就知道",
"- 不要主动提起或列举记忆,除非用户明确询问",
"- 自然使用记忆,就像你本来就知道; 不用刻意提起或列举记忆,除非用户提起相关内容",
"- 追加内容到现有记忆文件 → 必须用 `edit` 工具(先 read 读取,再 edit 追加)",
"- 创建新的记忆文件 → 可以用 `write` 工具已有记忆文件不可直接write会覆盖删除",
"",
]

View File

@@ -7,9 +7,9 @@ import json
import time
from typing import List, Dict, Any, Optional, Callable
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
class AgentStreamExecutor:
@@ -164,30 +164,24 @@ class AgentStreamExecutor:
self._emit_event("turn_start", {"turn": turn})
# Check if memory flush is needed (before calling LLM)
# 使用独立的 flush 阈值50K tokens 或 20 轮)
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
usage = self.agent.last_usage
if usage and 'input_tokens' in usage:
current_tokens = usage.get('input_tokens', 0)
context_window = self.agent._get_model_context_window()
# Use configured reserve_tokens or calculate based on context window
reserve_tokens = self.agent._get_context_reserve_tokens()
# Use smaller soft_threshold to trigger flush earlier (e.g., at 50K tokens)
soft_threshold = 10000 # Trigger 10K tokens before limit
if self.agent.memory_manager.should_flush_memory(
current_tokens=current_tokens,
context_window=context_window,
reserve_tokens=reserve_tokens,
soft_threshold=soft_threshold
current_tokens=current_tokens
):
self._emit_event("memory_flush_start", {
"current_tokens": current_tokens,
"threshold": context_window - reserve_tokens - soft_threshold
"turn_count": self.agent.memory_manager.flush_manager.turn_count
})
# TODO: Execute memory flush in background
# This would require async support
logger.info(f"Memory flush recommended at {current_tokens} tokens")
logger.info(
f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")
# Call LLM
assistant_msg, tool_calls = self._call_llm_stream()
@@ -321,6 +315,10 @@ class AgentStreamExecutor:
logger.info(f"🏁 完成({turn}轮)")
self._emit_event("agent_end", {"final_response": final_response})
# 每轮对话结束后增加计数(用户消息+AI回复=1轮
if self.agent.memory_manager:
self.agent.memory_manager.increment_turn()
return final_response
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
@@ -664,9 +662,11 @@ class AgentStreamExecutor:
if not self.messages or not self.agent:
return
# Get context window and reserve tokens from agent
# Get context window from agent (based on model)
context_window = self.agent._get_model_context_window()
reserve_tokens = self.agent._get_context_reserve_tokens()
# Reserve 10% for response generation
reserve_tokens = int(context_window * 0.1)
max_tokens = context_window - reserve_tokens
# Estimate current tokens

View File

@@ -2,25 +2,17 @@
from agent.tools.base_tool import BaseTool
from agent.tools.tool_manager import ToolManager
# Import basic tools (no external dependencies)
from agent.tools.calculator.calculator import Calculator
# Import file operation tools
from agent.tools.read.read import Read
from agent.tools.write.write import Write
from agent.tools.edit.edit import Edit
from agent.tools.bash.bash import Bash
from agent.tools.grep.grep import Grep
from agent.tools.find.find import Find
from agent.tools.ls.ls import Ls
# Import memory tools
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
# Import web tools
from agent.tools.web_fetch.web_fetch import WebFetch
# Import tools with optional dependencies
def _import_optional_tools():
"""Import tools that have optional dependencies"""
@@ -80,17 +72,13 @@ BrowserTool = _import_browser_tool()
__all__ = [
'BaseTool',
'ToolManager',
'Calculator',
'Read',
'Write',
'Edit',
'Bash',
'Grep',
'Find',
'Ls',
'MemorySearchTool',
'MemoryGetTool',
'WebFetch',
# Optional tools (may be None if dependencies not available)
'GoogleSearch',
'FileSave',

View File

@@ -1,58 +0,0 @@
import math
from agent.tools.base_tool import BaseTool, ToolResult
class Calculator(BaseTool):
    """Evaluate basic mathematical expressions in a restricted namespace."""

    name: str = "calculator"
    description: str = "A tool to perform basic mathematical calculations."
    params: dict = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": (
                    "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
                    "Ensure your input is a valid Python expression, it will be evaluated directly."
                ),
            }
        },
        "required": ["expression"],
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """Evaluate ``args['expression']`` and return the result.

        Failures (bad expression, missing key, math errors) are reported
        inside a *successful* ToolResult so the calling agent can read the
        error message rather than receiving a hard failure.
        """
        # Whitelist of names visible to the expression; builtins are blanked
        # out so only these callables/constants are reachable.
        # NOTE(review): eval on model-supplied text is still inherently risky;
        # keep the {"__builtins__": {}} blackout in place.
        allowed = {
            "abs": abs, "round": round, "max": max, "min": min, "pow": pow,
            "sqrt": math.sqrt, "sin": math.sin, "cos": math.cos, "tan": math.tan,
            "pi": math.pi, "e": math.e, "log": math.log, "log10": math.log10,
            "exp": math.exp, "floor": math.floor, "ceil": math.ceil,
        }
        try:
            value = eval(args["expression"], {"__builtins__": {}}, allowed)
        except Exception as exc:
            # Deliberately a "success" payload: the agent inspects the error text.
            return ToolResult.success({
                "error": str(exc),
                "expression": args.get("expression", ""),
            })
        return ToolResult.success({
            "result": value,
            "expression": args["expression"],
        })

View File

@@ -33,7 +33,7 @@ class Edit(BaseTool):
},
"oldText": {
"type": "string",
"description": "Exact text to find and replace (must match exactly)"
"description": "Exact text to find and replace (must match exactly, cannot be empty). To append to end of file, include the last few lines as oldText."
},
"newText": {
"type": "string",

View File

@@ -1,3 +0,0 @@
from .find import Find
__all__ = ['Find']

View File

@@ -1,177 +0,0 @@
"""
Find tool - Search for files by glob pattern
"""
import os
import glob as glob_module
from typing import Dict, Any, List
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
DEFAULT_LIMIT = 1000
class Find(BaseTool):
    """Tool for finding files by glob pattern.

    Paths in the result are relative to the search directory. Entries matching
    .gitignore patterns (plus a built-in ignore list such as .git and
    __pycache__) are skipped.
    """

    name: str = "find"
    description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
            },
            "path": {
                "type": "string",
                "description": "Directory to search in (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        # A "cwd" entry in config overrides the base directory used to
        # resolve relative search paths.
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file search

        :param args: Search parameters (pattern, optional path and limit)
        :return: Search results or error
        """
        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)
        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")
        # Resolve and validate the search path before globbing.
        absolute_path = self._resolve_path(search_path)
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")
        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {search_path}")
        try:
            # Load .gitignore patterns (plus built-in defaults).
            ignore_patterns = self._load_gitignore(absolute_path)
            # Search for files; recursive=True enables '**' patterns.
            results = []
            search_pattern = os.path.join(absolute_path, pattern)
            for file_path in glob_module.glob(search_pattern, recursive=True):
                # Skip entries matching ignore patterns.
                if self._should_ignore(file_path, absolute_path, ignore_patterns):
                    continue
                relative_path = os.path.relpath(file_path, absolute_path)
                # Trailing slash marks directories in the output.
                if os.path.isdir(file_path):
                    relative_path += '/'
                results.append(relative_path)
                if len(results) >= limit:
                    break
            if not results:
                return ToolResult.success({"message": "No files found matching pattern", "files": []})
            results.sort()
            # Byte-truncate the joined output (line count is effectively unlimited).
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes
            output = truncation.content
            details = {}
            notices = []
            result_limit_reached = len(results) >= limit
            if result_limit_reached:
                notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["result_limit_reached"] = limit
            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()
            if notices:
                output += f"\n\n[{'. '.join(notices)}]"
            return ToolResult.success({
                "output": output,
                "file_count": len(results),
                "details": details if details else None
            })
        except Exception as e:
            return ToolResult.fail(f"Error executing find: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path (expanding ~, relative to self.cwd)."""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _load_gitignore(self, directory: str) -> List[str]:
        """Load .gitignore patterns from directory, plus built-in defaults."""
        patterns = []
        gitignore_path = os.path.join(directory, '.gitignore')
        if os.path.exists(gitignore_path):
            try:
                with open(gitignore_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        # Skip blanks and comment lines.
                        if line and not line.startswith('#'):
                            patterns.append(line)
            except OSError:
                # Unreadable .gitignore: fall back to the defaults only.
                pass
        # Always-ignored entries, regardless of .gitignore contents.
        patterns.extend([
            '.git',
            '__pycache__',
            '*.pyc',
            'node_modules',
            '.DS_Store'
        ])
        return patterns

    def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
        """Check if file should be ignored based on patterns.

        Uses glob-style matching (fnmatch) against the full relative path and
        against each path component, so patterns like '*.pyc' or '.git'
        behave as they do in .gitignore. (The previous substring test could
        never match wildcard patterns such as '*.pyc'.)
        """
        import fnmatch
        relative_path = os.path.relpath(file_path, base_path)
        parts = relative_path.split(os.sep)
        for pattern in patterns:
            if pattern.endswith('/'):
                # Directory pattern: ignore the directory and everything under it.
                dir_name = pattern.rstrip('/')
                if relative_path == dir_name or relative_path.startswith(dir_name + os.sep):
                    return True
                if any(fnmatch.fnmatch(part, dir_name) for part in parts):
                    return True
                continue
            # Match against the whole relative path or any single component.
            if fnmatch.fnmatch(relative_path, pattern):
                return True
            if any(fnmatch.fnmatch(part, pattern) for part in parts):
                return True
        return False

View File

@@ -1,3 +0,0 @@
from .grep import Grep
__all__ = ['Grep']

View File

@@ -1,248 +0,0 @@
"""
Grep tool - Search file contents for patterns
Uses ripgrep (rg) for fast searching
"""
import os
import re
import subprocess
import json
from typing import Dict, Any, List, Optional
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import (
truncate_head, truncate_line, format_size,
DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
)
DEFAULT_LIMIT = 100
class Grep(BaseTool):
    """Tool for searching file contents.

    Delegates the search to ripgrep (rg) in --json mode, then re-reads the
    matched files in Python to render optional context lines around each hit.
    """

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."
    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        # A "cwd" entry in config overrides the working directory used for
        # resolving relative paths and for running rg.
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Located once at construction; execute() fails fast when rg is absent.
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Find ripgrep executable via `which`; None if not on PATH."""
        # NOTE(review): `which` is POSIX-only — presumably this tool targets
        # Unix-like hosts; confirm before relying on it on Windows.
        try:
            result = subprocess.run(['which', 'rg'], capture_output=True, text=True)
            if result.returncode == 0:
                return result.stdout.strip()
        except:
            pass
        return None

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search

        :param args: Search parameters
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")
        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        glob = args.get("glob")
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = args.get("context", 0)
        limit = args.get("limit", DEFAULT_LIMIT)
        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")
        # Resolve search path (may be a file or a directory).
        absolute_path = self._resolve_path(search_path)
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")
        # Build ripgrep command. --json gives structured per-match events;
        # --hidden includes dotfiles (rg still honors .gitignore by default).
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]
        if ignore_case:
            cmd.append('--ignore-case')
        if literal:
            cmd.append('--fixed-strings')
        if glob:
            cmd.extend(['--glob', glob])
        cmd.extend([pattern, absolute_path])
        try:
            # Execute ripgrep with a hard 30s wall-clock cap.
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                timeout=30
            )
            # Parse rg's newline-delimited JSON events; only 'match' events
            # carry hits (begin/end/summary events are skipped).
            matches = []
            match_count = 0
            for line in result.stdout.splitlines():
                if not line.strip():
                    continue
                try:
                    event = json.loads(line)
                    if event.get('type') == 'match':
                        data = event.get('data', {})
                        file_path = data.get('path', {}).get('text')
                        line_number = data.get('line_number')
                        if file_path and line_number:
                            matches.append({
                                'file': file_path,
                                'line': line_number
                            })
                            match_count += 1
                            # Stop as soon as the match cap is reached.
                            if match_count >= limit:
                                break
                except json.JSONDecodeError:
                    # Tolerate malformed event lines rather than failing the search.
                    continue
            if match_count == 0:
                return ToolResult.success({"message": "No matches found", "matches": []})
            # Format output with context by re-reading each matched file.
            output_lines = []
            lines_truncated = False
            is_directory = os.path.isdir(absolute_path)
            for match in matches:
                file_path = match['file']
                line_number = match['line']
                # Paths are shown relative to the search dir (or basename for
                # a single-file search).
                if is_directory:
                    relative_path = os.path.relpath(file_path, absolute_path)
                else:
                    relative_path = os.path.basename(file_path)
                # Read file and extract the context window around the match.
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        file_lines = f.read().split('\n')
                    # Context range is clamped to the file bounds; with
                    # context == 0 only the matching line itself is emitted.
                    start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
                    end = min(len(file_lines), line_number + context) if context > 0 else line_number
                    for i in range(start, end):
                        line_text = file_lines[i].replace('\r', '')
                        # Cap very long lines to keep output readable.
                        truncated_text, was_truncated = truncate_line(line_text)
                        if was_truncated:
                            lines_truncated = True
                        # grep-style markers: ':' for the hit line, '-' for context.
                        current_line = i + 1
                        if current_line == line_number:
                            output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                        else:
                            output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")
                except Exception:
                    # File vanished or not UTF-8 decodable; still report the hit.
                    output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")
            # Apply byte truncation to the assembled output.
            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes
            output = truncation.content
            details = {}
            notices = []
            # Surface every truncation that occurred as a trailing notice.
            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit
            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()
            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True
            if notices:
                output += f"\n\n[{'. '.join(notices)}]"
            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })
        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path (expanding ~, relative to self.cwd)."""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -4,8 +4,6 @@ Memory get tool
Allows agents to read specific sections from memory files
"""
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool
@@ -22,7 +20,7 @@ class MemoryGetTool(BaseTool):
"properties": {
"path": {
"type": "string",
"description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2026-01-01.md')"
"description": "Relative path to the memory file (e.g. 'memory/2026-01-01.md')"
},
"start_line": {
"type": "integer",
@@ -70,7 +68,8 @@ class MemoryGetTool(BaseTool):
workspace_dir = self.memory_manager.config.get_workspace()
# Auto-prepend memory/ if not present and not absolute path
if not path.startswith('memory/') and not path.startswith('/'):
# Exception: MEMORY.md is in the root directory
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
path = f'memory/{path}'
file_path = workspace_dir / path

View File

@@ -1,212 +0,0 @@
# WebFetch Tool
免费的网页抓取工具,无需 API Key可直接抓取网页内容并提取可读文本。
## 功能特性
- **完全免费** - 无需任何 API Key
- 🌐 **智能提取** - 自动提取网页主要内容
- 📝 **格式转换** - 支持 HTML → Markdown/Text
- 🚀 **高性能** - 内置请求重试和超时控制
- 🎯 **智能降级** - 优先使用 Readability可降级到基础提取
## 安装依赖
### 基础功能(必需)
```bash
pip install requests
```
### 增强功能(推荐)
```bash
# 安装 readability-lxml 以获得更好的内容提取效果
pip install readability-lxml
# 安装 html2text 以获得更好的 Markdown 转换
pip install html2text
```
## 使用方法
### 1. 在代码中使用
```python
from agent.tools.web_fetch import WebFetch
# 创建工具实例
tool = WebFetch()
# 抓取网页(默认返回 Markdown 格式)
result = tool.execute({
"url": "https://example.com"
})
# 抓取并转换为纯文本
result = tool.execute({
"url": "https://example.com",
"extract_mode": "text",
"max_chars": 5000
})
if result.status == "success":
data = result.result
print(f"标题: {data['title']}")
print(f"内容: {data['text']}")
```
### 2. 在 Agent 中使用
工具会自动加载到 Agent 的工具列表中:
```python
from agent.tools import WebFetch
tools = [
WebFetch(),
# ... 其他工具
]
agent = create_agent(tools=tools)
```
### 3. 通过 Skills 使用
创建一个 skill 文件 `skills/web-fetch/SKILL.md`
```markdown
---
name: web-fetch
emoji: 🌐
always: true
---
# 网页内容获取
使用 web_fetch 工具获取网页内容。
## 使用场景
- 需要读取某个网页的内容
- 需要提取文章正文
- 需要获取网页信息
## 示例
<example>
用户: 帮我看看 https://example.com 这个网页讲了什么
助手: <tool_use name="web_fetch">
<url>https://example.com</url>
<extract_mode>markdown</extract_mode>
</tool_use>
</example>
```
## 参数说明
| 参数 | 类型 | 必需 | 默认值 | 说明 |
|------|------|------|--------|------|
| `url` | string | ✅ | - | 要抓取的 URLhttp/https |
| `extract_mode` | string | ❌ | `markdown` | 提取模式:`markdown``text` |
| `max_chars` | integer | ❌ | `50000` | 最大返回字符数(最小 100 |
## 返回结果
```python
{
"url": "https://example.com", # 最终 URL处理重定向后
"status": 200, # HTTP 状态码
"content_type": "text/html", # 内容类型
"title": "Example Domain", # 页面标题
"extractor": "readability", # 提取器readability/basic/raw
"extract_mode": "markdown", # 提取模式
"text": "# Example Domain\n\n...", # 提取的文本内容
"length": 1234, # 文本长度
"truncated": false, # 是否被截断
"warning": "..." # 警告信息(如果有)
}
```
## 与其他搜索工具的对比
| 工具 | 需要 API Key | 功能 | 成本 |
|------|-------------|------|------|
| `web_fetch` | ❌ 不需要 | 抓取指定 URL 的内容 | 免费 |
| `web_search` (Brave) | ✅ 需要 | 搜索引擎查询 | 有免费额度 |
| `web_search` (Perplexity) | ✅ 需要 | AI 搜索 + 引用 | 付费 |
| `browser` | ❌ 不需要 | 完整浏览器自动化 | 免费但资源占用大 |
| `google_search` | ✅ 需要 | Google 搜索 API | 付费 |
## 技术细节
### 内容提取策略
1. **Readability 模式**(推荐)
- 使用 Mozilla 的 Readability 算法
- 自动识别文章主体内容
- 过滤广告、导航栏等噪音
2. **Basic 模式**(降级)
- 简单的 HTML 标签清理
- 正则表达式提取文本
- 适用于简单页面
3. **Raw 模式**
- 用于非 HTML 内容
- 直接返回原始内容
### 错误处理
工具会自动处理以下情况:
- ✅ HTTP 重定向(最多 3 次)
- ✅ 请求超时(默认 20 秒)
- ✅ 网络错误自动重试
- ✅ 内容提取失败降级
## 测试
运行测试脚本:
```bash
cd agent/tools/web_fetch
python test_web_fetch.py
```
## 配置选项
在创建工具时可以传入配置:
```python
tool = WebFetch(config={
"timeout": 30, # 请求超时时间(秒)
"max_redirects": 3, # 最大重定向次数
"user_agent": "..." # 自定义 User-Agent
})
```
## 常见问题
### Q: 为什么推荐安装 readability-lxml
A: readability-lxml 提供更好的内容提取质量,能够:
- 自动识别文章主体
- 过滤广告和导航栏
- 保留文章结构
没有它也能工作,但提取质量会下降。
### Q: 与 clawdbot 的 web_fetch 有什么区别?
A: 本实现参考了 clawdbot 的设计,主要区别:
- Python 实现clawdbot 是 TypeScript
- 简化了一些高级特性(如 Firecrawl 集成)
- 保留了核心的免费功能
- 更容易集成到现有项目
### Q: 可以抓取需要登录的页面吗?
A: 当前版本不支持。如需抓取需要登录的页面,请使用 `browser` 工具。
## 参考
- [Mozilla Readability](https://github.com/mozilla/readability)
- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)

View File

@@ -1,3 +0,0 @@
from .web_fetch import WebFetch
__all__ = ['WebFetch']

View File

@@ -1,47 +0,0 @@
#!/bin/bash
# WebFetch 工具依赖安装脚本
#
# Installs the required base dependency (requests) and the optional
# extraction dependencies (readability-lxml, html2text) for WebFetch.
# The optional step is non-fatal: basic extraction works without it.
echo "=================================="
echo "WebFetch 工具依赖安装"
echo "=================================="
echo ""
# Report the Python version in use (stderr merged: some pythons print there)
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "✓ Python 版本: $python_version"
echo ""
# Install the required base dependency.
# Test the command directly in `if`; the original checked `$?` on a separate
# line, which silently breaks if any command is ever inserted in between.
echo "📦 安装基础依赖..."
if python3 -m pip install requests; then
    echo "✅ requests 安装成功"
else
    echo "❌ requests 安装失败"
    exit 1
fi
echo ""
# Install recommended (optional) dependencies — failure here is tolerated.
echo "📦 安装推荐依赖(提升内容提取质量)..."
if python3 -m pip install readability-lxml html2text; then
    echo "✅ readability-lxml 和 html2text 安装成功"
else
    echo "⚠️ 推荐依赖安装失败,但不影响基础功能"
fi
echo ""
echo "=================================="
echo "安装完成!"
echo "=================================="
echo ""
echo "运行测试:"
echo " python3 agent/tools/web_fetch/test_web_fetch.py"
echo ""

View File

@@ -1,365 +0,0 @@
"""
Web Fetch tool - Fetch and extract readable content from URLs
Supports HTML to Markdown/Text conversion using Mozilla's Readability
"""
import os
import re
from typing import Dict, Any, Optional
from urllib.parse import urlparse
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
class WebFetch(BaseTool):
    """Tool for fetching and extracting readable content from web pages.

    Fetches an HTTP(S) URL and extracts the main readable text, preferring
    Mozilla's Readability algorithm (via readability-lxml) and degrading to
    basic regex-based tag stripping when that package is unavailable.
    Output is returned as markdown or plain text, truncated to ``max_chars``.
    """
    name: str = "web_fetch"
    description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."
    # JSON-schema style parameter declaration consumed by the tool framework.
    params: dict = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "HTTP or HTTPS URL to fetch"
            },
            "extract_mode": {
                "type": "string",
                "description": "Extraction mode: 'markdown' (default) or 'text'",
                "enum": ["markdown", "text"],
                "default": "markdown"
            },
            "max_chars": {
                "type": "integer",
                "description": "Maximum characters to return (default: 50000)",
                "minimum": 100,
                "default": 50000
            }
        },
        "required": ["url"]
    }

    def __init__(self, config: Optional[dict] = None):
        """Initialize the tool.

        :param config: optional dict with keys ``timeout`` (seconds,
                       default 20), ``max_redirects`` (default 3) and
                       ``user_agent`` (defaults to a desktop Chrome UA).
        """
        self.config = config or {}
        self.timeout = self.config.get("timeout", 20)
        self.max_redirects = self.config.get("max_redirects", 3)
        self.user_agent = self.config.get(
            "user_agent",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
        )
        # Setup session with retry strategy
        self.session = self._create_session()
        # Check if readability-lxml is available (cached once at init)
        self.readability_available = self._check_readability()

    def _create_session(self) -> requests.Session:
        """Create a requests session with retry strategy"""
        session = requests.Session()
        # Retry strategy - handles failed requests, not redirects
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"]
        )
        # HTTPAdapter handles retries; requests handles redirects via allow_redirects
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        # Set max redirects on session
        session.max_redirects = self.max_redirects
        return session

    def _check_readability(self) -> bool:
        """Check if readability-lxml is available"""
        try:
            from readability import Document
            return True
        except ImportError:
            logger.warning(
                "readability-lxml not installed. Install with: pip install readability-lxml\n"
                "Falling back to basic HTML extraction."
            )
            return False

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute web fetch operation
        :param args: Contains url, extract_mode, and max_chars parameters
        :return: Extracted content or error message
        """
        url = args.get("url", "").strip()
        extract_mode = args.get("extract_mode", "markdown").lower()
        max_chars = args.get("max_chars", 50000)
        if not url:
            return ToolResult.fail("Error: url parameter is required")
        # Validate URL
        if not self._is_valid_url(url):
            return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")
        # Validate extract_mode (silently fall back to the default)
        if extract_mode not in ["markdown", "text"]:
            extract_mode = "markdown"
        # Validate max_chars (out-of-range or non-int values reset to default)
        if not isinstance(max_chars, int) or max_chars < 100:
            max_chars = 50000
        try:
            # Fetch the URL
            response = self._fetch_url(url)
            # Extract content
            result = self._extract_content(
                html=response.text,
                url=response.url,
                status_code=response.status_code,
                content_type=response.headers.get("content-type", ""),
                extract_mode=extract_mode,
                max_chars=max_chars
            )
            return ToolResult.success(result)
        except requests.exceptions.Timeout:
            return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
        except requests.exceptions.TooManyRedirects:
            return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
        except requests.exceptions.RequestException as e:
            return ToolResult.fail(f"Error fetching URL: {str(e)}")
        except Exception as e:
            logger.error(f"Web fetch error: {e}", exc_info=True)
            return ToolResult.fail(f"Error: {str(e)}")

    def _is_valid_url(self, url: str) -> bool:
        """Validate URL format (http/https scheme with a non-empty host)"""
        try:
            result = urlparse(url)
            return result.scheme in ["http", "https"] and bool(result.netloc)
        except Exception:
            return False

    def _fetch_url(self, url: str) -> requests.Response:
        """
        Fetch URL with proper headers and error handling
        :param url: URL to fetch
        :return: Response object (raises for non-2xx status)
        """
        headers = {
            "User-Agent": self.user_agent,
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
            "Accept-Encoding": "gzip, deflate",
            "Connection": "keep-alive",
        }
        # Note: requests follows redirects automatically; the limit comes from
        # session.max_redirects (set in _create_session). The HTTPAdapter's
        # Retry handles failed requests only, not redirects.
        response = self.session.get(
            url,
            headers=headers,
            timeout=self.timeout,
            allow_redirects=True
        )
        response.raise_for_status()
        return response

    def _extract_content(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """
        Extract readable content from HTML
        :param html: HTML content
        :param url: Original URL
        :param status_code: HTTP status code
        :param content_type: Content type header
        :param extract_mode: 'markdown' or 'text'
        :param max_chars: Maximum characters to return
        :return: Extracted content and metadata
        """
        # Check content type: non-HTML responses are returned raw (truncated)
        if "text/html" not in content_type.lower():
            # Non-HTML content
            text = html[:max_chars]
            truncated = len(html) > max_chars
            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "extractor": "raw",
                "text": text,
                "length": len(text),
                "truncated": truncated,
                "message": f"Non-HTML content (type: {content_type})"
            }
        # Extract readable content from HTML; prefer Readability when present
        if self.readability_available:
            return self._extract_with_readability(
                html, url, status_code, content_type, extract_mode, max_chars
            )
        else:
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_with_readability(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Extract content using Mozilla's Readability.

        Falls back to :meth:`_extract_basic` if Readability parsing raises.
        """
        try:
            from readability import Document
            # Parse with Readability
            doc = Document(html)
            title = doc.title()
            content_html = doc.summary()
            # Convert to markdown or text
            if extract_mode == "markdown":
                text = self._html_to_markdown(content_html)
            else:
                text = self._html_to_text(content_html)
            # Truncate if needed
            truncated = len(text) > max_chars
            if truncated:
                text = text[:max_chars]
            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "title": title,
                "extractor": "readability",
                "extract_mode": extract_mode,
                "text": text,
                "length": len(text),
                "truncated": truncated
            }
        except Exception as e:
            logger.warning(f"Readability extraction failed: {e}")
            # Fallback to basic extraction
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_basic(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Basic HTML extraction without Readability (regex tag stripping)"""
        # Extract title
        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
        title = title_match.group(1).strip() if title_match else "Untitled"
        # Remove script and style tags
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
        # Remove HTML tags
        text = re.sub(r'<[^>]+>', ' ', text)
        # Clean up whitespace
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()
        # Truncate if needed
        truncated = len(text) > max_chars
        if truncated:
            text = text[:max_chars]
        return {
            "url": url,
            "status": status_code,
            "content_type": content_type,
            "title": title,
            "extractor": "basic",
            "extract_mode": extract_mode,
            "text": text,
            "length": len(text),
            "truncated": truncated,
            "warning": "Using basic extraction. Install readability-lxml for better results."
        }

    def _html_to_markdown(self, html: str) -> str:
        """Convert HTML to Markdown (uses html2text when installed)"""
        try:
            # Try to use html2text if available
            import html2text
            h = html2text.HTML2Text()
            h.ignore_links = False
            h.ignore_images = False
            h.body_width = 0  # Don't wrap lines
            return h.handle(html)
        except ImportError:
            # Fallback to basic conversion
            return self._html_to_text(html)

    def _html_to_text(self, html: str) -> str:
        """Convert HTML to plain text"""
        # Remove script and style tags
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
        # Convert common tags to text equivalents
        text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
        text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
        text = re.sub(r'<h[1-6][^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</h[1-6]>', '\n', text, flags=re.IGNORECASE)
        # Remove all other HTML tags
        text = re.sub(r'<[^>]+>', '', text)
        # Decode HTML entities
        import html
        text = html.unescape(text)
        # Clean up whitespace
        text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
        text = re.sub(r' +', ' ', text)
        text = text.strip()
        return text

    def close(self) -> None:
        """Close the underlying HTTP session (releases pooled connections)"""
        if hasattr(self, 'session'):
            self.session.close()

View File

@@ -283,14 +283,53 @@ class AgentBridge:
from agent.memory import MemoryManager, MemoryConfig
from agent.tools import MemorySearchTool, MemoryGetTool
memory_config = MemoryConfig(
workspace_root=workspace_root,
embedding_provider="local", # Use local embedding (no API key needed)
embedding_model="all-MiniLM-L6-v2"
)
# 从 config.json 读取 OpenAI 配置
openai_api_key = conf().get("open_ai_api_key", "")
openai_api_base = conf().get("open_ai_api_base", "")
# Create memory manager with the config
memory_manager = MemoryManager(memory_config)
# 尝试初始化 OpenAI embedding provider
embedding_provider = None
if openai_api_key:
try:
from agent.memory import create_embedding_provider
embedding_provider = create_embedding_provider(
provider="openai",
model="text-embedding-3-small",
api_key=openai_api_key,
api_base=openai_api_base or "https://api.openai.com/v1"
)
logger.info(f"[AgentBridge] OpenAI embedding initialized")
except Exception as embed_error:
logger.warning(f"[AgentBridge] OpenAI embedding failed: {embed_error}")
logger.info(f"[AgentBridge] Using keyword-only search")
else:
logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search")
# 创建 memory config
memory_config = MemoryConfig(workspace_root=workspace_root)
# 创建 memory manager
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
# 初始化时执行一次 sync确保数据库有数据
import asyncio
try:
# 尝试在当前事件循环中执行
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建任务
asyncio.create_task(memory_manager.sync())
logger.info("[AgentBridge] Memory sync scheduled")
else:
# 如果没有运行的循环,直接执行
loop.run_until_complete(memory_manager.sync())
logger.info("[AgentBridge] Memory synced successfully")
except RuntimeError:
# 没有事件循环,创建新的
asyncio.run(memory_manager.sync())
logger.info("[AgentBridge] Memory synced successfully")
except Exception as e:
logger.warning(f"[AgentBridge] Memory sync failed: {e}")
# Create memory tools
memory_tools = [
@@ -420,23 +459,65 @@ class AgentBridge:
memory_tools = []
try:
from agent.memory import MemoryManager, MemoryConfig
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
from agent.tools import MemorySearchTool, MemoryGetTool
memory_config = MemoryConfig(
workspace_root=workspace_root,
embedding_provider="local",
embedding_model="all-MiniLM-L6-v2"
)
# 从 config.json 读取 OpenAI 配置
openai_api_key = conf().get("open_ai_api_key", "")
openai_api_base = conf().get("open_ai_api_base", "")
# 尝试初始化 OpenAI embedding provider
embedding_provider = None
if openai_api_key:
try:
embedding_provider = create_embedding_provider(
provider="openai",
model="text-embedding-3-small",
api_key=openai_api_key,
api_base=openai_api_base or "https://api.openai.com/v1"
)
logger.info(f"[AgentBridge] OpenAI embedding initialized for session {session_id}")
except Exception as embed_error:
logger.warning(f"[AgentBridge] OpenAI embedding failed for session {session_id}: {embed_error}")
logger.info(f"[AgentBridge] Using keyword-only search for session {session_id}")
else:
logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search for session {session_id}")
# 创建 memory config
memory_config = MemoryConfig(workspace_root=workspace_root)
# 创建 memory manager
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
# 初始化时执行一次 sync确保数据库有数据
import asyncio
try:
# 尝试在当前事件循环中执行
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建任务
asyncio.create_task(memory_manager.sync())
logger.info(f"[AgentBridge] Memory sync scheduled for session {session_id}")
else:
# 如果没有运行的循环,直接执行
loop.run_until_complete(memory_manager.sync())
logger.info(f"[AgentBridge] Memory synced successfully for session {session_id}")
except RuntimeError:
# 没有事件循环,创建新的
asyncio.run(memory_manager.sync())
logger.info(f"[AgentBridge] Memory synced successfully for session {session_id}")
except Exception as sync_error:
logger.warning(f"[AgentBridge] Memory sync failed for session {session_id}: {sync_error}")
memory_manager = MemoryManager(memory_config)
memory_tools = [
MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager)
]
except Exception as e:
logger.debug(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
logger.warning(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
import traceback
logger.warning(f"[AgentBridge] Memory init traceback: {traceback.format_exc()}")
# Load tools
from agent.tools import ToolManager

View File

@@ -158,7 +158,7 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
# Build request parameters for ChatCompletion
request_params = {
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
"model": kwargs.get("model", conf().get("model") or "gpt-4.1"),
"messages": messages,
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
"top_p": kwargs.get("top_p", 1),

View File

@@ -1,30 +0,0 @@
## 插件说明
利用百度UNIT实现智能对话
- 1.解决问题chatgpt无法处理的指令交给百度UNIT处理如天气日期时间数学运算等
- 2.如问时间:现在几点钟,今天几号
- 3.如问天气:明天广州天气怎么样,这个周末深圳会不会下雨
- 4.如问数学运算23+45=多少100-23=多少35转化为二进制是多少
## 使用说明
### 获取apikey
在百度UNIT官网上自己创建应用申请百度机器人,可以把预先训练好的模型导入到自己的应用中,
详见 https://ai.baidu.com/unit/home#/home 创建应用,或在 https://console.bce.baidu.com/ai 平台申请
### 配置文件
将文件夹中`config.json.template`复制为`config.json`
在其中填写百度UNIT官网上获取应用的API Key和Secret Key
``` json
{
"service_id": "s...", #"机器人ID"
"api_key": "",
"secret_key": ""
}
```

View File

@@ -1 +0,0 @@
from .bdunit import *

View File

@@ -1,252 +0,0 @@
# encoding:utf-8
import json
import os
import uuid
from uuid import getnode as get_mac
import requests
import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from plugins import *
"""利用百度UNIT实现智能对话
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
"""
@plugins.register(
    name="BDunit",
    desire_priority=0,
    hidden=True,
    desc="Baidu unit bot system",
    version="0.1",
    author="jackson",
)
class BDunit(Plugin):
    """Baidu UNIT dialogue plugin.

    If the user's query hits a UNIT intent, the plugin replies with UNIT's
    answer; otherwise it hands the event to the next plugin in the chain.
    """

    def __init__(self):
        super().__init__()
        try:
            conf = super().load_config()
            if not conf:
                raise Exception("config.json not found")
            self.service_id = conf["service_id"]
            self.api_key = conf["api_key"]
            self.secret_key = conf["secret_key"]
            # Fetch the OAuth access token eagerly; bad credentials fail fast here
            self.access_token = self.get_token()
            self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
            logger.info("[BDunit] inited")
        except Exception as e:
            logger.warn("[BDunit] init failed, ignore ")
            raise e

    def on_handle_context(self, e_context: EventContext):
        """Handle a text message: reply via UNIT when an intent is matched."""
        if e_context["context"].type != ContextType.TEXT:
            return
        content = e_context["context"].content
        logger.debug("[BDunit] on_handle_context. content: %s" % content)
        parsed = self.getUnit2(content)
        intent = self.getIntent(parsed)
        if intent:  # an intent was matched
            logger.debug("[BDunit] Baidu_AI Intent= %s", intent)
            reply = Reply()
            reply.type = ReplyType.TEXT
            reply.content = self.getSay(parsed)
            e_context["reply"] = reply
            e_context.action = EventAction.BREAK_PASS  # stop the event and skip the default context handling
        else:
            e_context.action = EventAction.CONTINUE  # pass the event on to the next plugin or default logic

    def get_help_text(self, **kwargs):
        help_text = "本插件会处理询问实时日期时间天气数学运算等问题这些技能由您的百度智能对话UNIT决定\n"
        return help_text

    def get_token(self):
        """Get an access_token for the Baidu UNIT API.

        Uses the configured api_key / secret_key (client-credentials grant).

        Returns:
            string: access_token
        """
        url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
        payload = ""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        response = requests.request("POST", url, headers=headers, data=payload)
        # print(response.text)
        return response.json()["access_token"]

    def getUnit(self, query):
        """
        NLU parse, API version 3.0
        :param query: the user's query string
        :returns: the UNIT parse result, or None if parsing failed
        """
        url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
        request = {
            "query": query,
            "user_id": str(get_mac())[:32],
            "terminal_id": "88888",
        }
        body = {
            "log_id": str(uuid.uuid1()),
            "version": "3.0",
            "service_id": self.service_id,
            "session_id": str(uuid.uuid1()),
            "request": request,
        }
        try:
            headers = {"Content-Type": "application/json"}
            response = requests.post(url, json=body, headers=headers)
            return json.loads(response.text)
        except Exception:
            return None

    def getUnit2(self, query):
        """
        NLU parse, API version 2.0
        :param query: the user's query string
        :returns: the UNIT parse result, or None if parsing failed
        """
        url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
        request = {"query": query, "user_id": str(get_mac())[:32]}
        body = {
            "log_id": str(uuid.uuid1()),
            "version": "2.0",
            "service_id": self.service_id,
            "session_id": str(uuid.uuid1()),
            "request": request,
        }
        try:
            headers = {"Content-Type": "application/json"}
            response = requests.post(url, json=body, headers=headers)
            return json.loads(response.text)
        except Exception:
            return None

    def getIntent(self, parsed):
        """
        Extract the matched intent
        :param parsed: the UNIT parse result
        :returns: the intent, or "" when none was matched
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            try:
                return parsed["result"]["response_list"][0]["schema"]["intent"]
            except Exception as e:
                logger.warning(e)
                return ""
        else:
            return ""

    def hasIntent(self, parsed, intent):
        """
        Check whether a given intent was matched
        :param parsed: the UNIT parse result
        :param intent: the intent name
        :returns: True if matched; False otherwise
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                    return True
            return False
        else:
            return False

    def getSlots(self, parsed, intent=""):
        """
        Extract all slots for a given intent
        :param parsed: the UNIT parse result
        :param intent: the intent name ("" means the first response)
        :returns: list of slots; filter by the ``name`` attribute and read
                  the value from the ``normalized_word`` attribute
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            if intent == "":
                try:
                    return parsed["result"]["response_list"][0]["schema"]["slots"]
                except Exception as e:
                    logger.warning(e)
                    return []
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
                    return response["schema"]["slots"]
            return []
        else:
            return []

    def getSlotWords(self, parsed, intent, name):
        """
        Collect the values that filled a given slot
        :param parsed: the UNIT parse result
        :param intent: the intent name
        :param name: the slot name
        :returns: list of normalized values matched for that slot
        """
        slots = self.getSlots(parsed, intent)
        words = []
        for slot in slots:
            if slot["name"] == name:
                words.append(slot["normalized_word"])
        return words

    def getSayByConfidence(self, parsed):
        """
        Extract UNIT's reply text with the highest intent confidence
        :param parsed: the UNIT parse result
        :returns: UNIT's reply text
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            answer = {}
            for response in response_list:
                if (
                    "schema" in response
                    and "intent_confidence" in response["schema"]
                    and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
                ):
                    answer = response
            # NOTE(review): raises KeyError if no response carried an
            # intent_confidence (answer stays {}) — confirm upstream guarantees
            return answer["action_list"][0]["say"]
        else:
            return ""

    def getSay(self, parsed, intent=""):
        """
        Extract UNIT's reply text
        :param parsed: the UNIT parse result
        :param intent: the intent name ("" means the first response)
        :returns: UNIT's reply text, or "" when unavailable
        """
        if parsed and "result" in parsed and "response_list" in parsed["result"]:
            response_list = parsed["result"]["response_list"]
            if intent == "":
                try:
                    return response_list[0]["action_list"][0]["say"]
                except Exception as e:
                    logger.warning(e)
                    return ""
            for response in response_list:
                if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
                    try:
                        return response["action_list"][0]["say"]
                    except Exception as e:
                        logger.warning(e)
                        return ""
            return ""
        else:
            return ""

View File

@@ -1,5 +0,0 @@
{
"service_id": "s...",
"api_key": "",
"secret_key": ""
}

View File

@@ -8,17 +8,6 @@
"reply_filter": true,
"reply_action": "ignore"
},
"tool": {
"tools": [
"url-get",
"meteo-weather"
],
"kwargs": {
"top_k_results": 2,
"no_default": false,
"model_name": "gpt-3.5-turbo"
}
},
"linkai": {
"group_app_map": {
"测试群1": "default",

View File

@@ -214,6 +214,19 @@ These files contain established best practices for effective skill design.
To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.
**Available Base Tools**:
The agent has access to these core tools that you can leverage in your skill:
- **bash**: Execute shell commands (use for curl, ls, grep, sed, awk, bc for calculations, etc.)
- **read**: Read file contents
- **write**: Write files
- **edit**: Edit files with search/replace
**Minimize Dependencies**:
-**Prefer bash + curl** for HTTP API calls (no Python dependencies)
-**Use bash tools** (grep, sed, awk) for text processing
-**Keep scripts simple** - if bash can do it, no need for Python (document packages/versions if Python is used)
**Important Guidelines**:
- **scripts/**: Only create scripts that will be executed. Test all scripts before including.
- **references/**: ONLY create if documentation is too large for SKILL.md (>500 lines). Most skills don't need this.

49
skills/web-fetch/SKILL.md Normal file
View File

@@ -0,0 +1,49 @@
---
name: web-fetch
description: Fetch and extract readable content from web pages
homepage: https://github.com/zhayujie/chatgpt-on-wechat
metadata:
emoji: 🌐
requires:
bins: ["curl"]
---
# Web Fetch
Fetch and extract readable content from web pages using curl and basic text processing.
## Usage
Use the provided script to fetch a URL and extract its content:
```bash
bash scripts/fetch.sh <url> [output_file]
```
**Parameters:**
- `url`: The HTTP/HTTPS URL to fetch (required)
- `output_file`: Optional file to save the output (default: stdout)
**Returns:**
- Extracted page content with title and text
## Examples
### Fetch a web page
```bash
bash scripts/fetch.sh "https://example.com"
```
### Save to file
```bash
bash scripts/fetch.sh "https://example.com" output.txt
cat output.txt
```
## Notes
- Uses curl for HTTP requests (timeout: 20s)
- Extracts title and basic text content
- Removes HTML tags and scripts
- Works with any standard web page
- No external dependencies beyond curl

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env bash
# Fetch a web page and extract its title and readable text content.
#
# Usage: bash fetch.sh <url> [output_file]
#   url          HTTP/HTTPS URL to fetch (required)
#   output_file  optional file to save the output (default: stdout)
set -euo pipefail
url="${1:-}"
output_file="${2:-}"
if [ -z "$url" ]; then
    echo "Error: URL is required"
    echo "Usage: bash fetch.sh <url> [output_file]"
    exit 1
fi
# Validate URL
if [[ ! "$url" =~ ^https?:// ]]; then
    echo "Error: Invalid URL (must start with http:// or https://)"
    exit 1
fi
# Fetch the page. Keep stderr separate from the captured body: the original
# used `2>&1`, which let curl error text masquerade as page content.
html=$(curl -sS -L --max-time 20 \
    -H "User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" \
    -H "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" \
    "$url") || {
    echo "Error: Failed to fetch URL: $url"
    exit 1
}
# Flatten to a single line so tags spanning multiple lines can be matched
# (sed/grep are line-oriented; the original missed multi-line <script> blocks).
flat=$(printf '%s' "$html" | tr '\n\r\t' '   ')
# Extract the <title> text. Uses sed -nE instead of grep -oP: PCRE grep is
# GNU-only (absent on macOS/BSD). Tolerates attributes inside the title tag
# and is case-insensitive via explicit character classes (sed's I flag is
# GNU-only too).
title=$(printf '%s\n' "$flat" \
    | sed -nE 's/.*<[Tt][Ii][Tt][Ll][Ee][^>]*>[[:space:]]*([^<]*).*/\1/p' \
    | sed -E 's/[[:space:]]+$//')
[ -n "$title" ] || title="Untitled"
# strip_block TAG: remove every <TAG ...>...</TAG> block from stdin,
# case-insensitively and shortest-match. Portable awk replaces the original
# GNU-only `sed s///gI`, whose greedy `.*` could also swallow real page
# content lying between two blocks.
strip_block() {
    awk -v tag="$1" '
    {
        line = $0
        low  = tolower(line)
        out  = ""
        open = "<" tag
        shut = "</" tag
        while ((i = index(low, open)) > 0) {
            out  = out substr(line, 1, i - 1)
            line = substr(line, i)
            low  = substr(low, i)
            j = index(low, shut)
            if (j == 0) { line = ""; break }   # unterminated block: drop the rest
            k = index(substr(low, j), ">")
            if (k == 0) { line = ""; break }
            line = substr(line, j + k - 1 + 1)
            low  = substr(low, j + k - 1 + 1)
        }
        print out line
    }'
}
# Drop scripts and styles, then all remaining tags, then squeeze whitespace.
text=$(printf '%s\n' "$flat" \
    | strip_block "script" \
    | strip_block "style" \
    | sed 's/<[^>]*>/ /g' \
    | tr -s '[:space:]' ' ' \
    | sed -E 's/^ +//; s/ +$//')
# Format output
result="Title: $title

Content:
$text"
# Output to file or stdout
if [ -n "$output_file" ]; then
    printf '%s\n' "$result" > "$output_file"
    echo "Content saved to: $output_file"
else
    printf '%s\n' "$result"
fi