mirror of
https://github.com/xszyou/Fay.git
synced 2026-03-12 17:51:28 +08:00
649 lines
26 KiB
Python
649 lines
26 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import sys
|
||
import threading
|
||
import inspect
|
||
import time
|
||
from contextlib import AsyncExitStack
|
||
from typing import Optional, Dict, Any, Tuple, List, Callable
|
||
|
||
from mcp import ClientSession
|
||
from mcp.client.sse import sse_client
|
||
from faymcp import tool_registry
|
||
|
||
# 尝试导入本地 stdio 传输
|
||
try:
|
||
from mcp.client.stdio import stdio_client, StdioServerParameters
|
||
HAS_STDIO = True
|
||
except Exception:
|
||
stdio_client = None
|
||
StdioServerParameters = None
|
||
HAS_STDIO = False
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _runtime_root_dir() -> str:
    """Return the absolute directory used as the application root.

    When running as a frozen executable (``sys.frozen`` is truthy) this is
    the directory containing ``sys.executable``; otherwise it is the parent
    of the directory that holds this module.
    """
    if not getattr(sys, "frozen", False):
        module_dir = os.path.dirname(__file__)
        return os.path.abspath(os.path.join(module_dir, ".."))
    executable_dir = os.path.dirname(sys.executable)
    return os.path.abspath(executable_dir)
|
||
|
||
|
||
def _normalize_rel_path(path_value: Optional[str]) -> str:
    """Canonicalise a path for comparison.

    Falsy input becomes ``""``. Otherwise the value is stringified,
    backslashes are turned into forward slashes, surrounding whitespace is
    removed and the result is lower-cased.
    """
    text = str(path_value) if path_value else ""
    forward_slashed = text.replace("\\", "/")
    return forward_slashed.strip().lower()
|
||
|
||
|
||
def _path_matches(path_value: Optional[str], expected_suffix: str) -> bool:
    """Tell whether ``path_value`` equals or ends with ``expected_suffix``.

    Both sides are normalised with ``_normalize_rel_path`` first. A suffix
    only counts when it starts on a path-component boundary (the match must
    be preceded by ``/``). Returns ``False`` if either side normalises to an
    empty string.
    """
    candidate = _normalize_rel_path(path_value)
    tail = _normalize_rel_path(expected_suffix)
    if candidate and tail:
        return candidate == tail or candidate.endswith("/" + tail)
    return False
|
||
|
||
|
||
def _is_python_command(command: Any) -> bool:
    """Return True when ``command`` names a plain Python interpreter.

    The comparison is case-insensitive and looks only at the final path
    component, which must be one of ``python``, ``python.exe``, ``pythonw``
    or ``pythonw.exe``. Empty or falsy input yields ``False``.
    """
    # A bare name is its own basename, so one basename lookup covers both
    # "python" and "/some/dir/python"; "" never matches.
    executable = os.path.basename(str(command or "").strip().lower())
    return executable in ("python", "python.exe", "pythonw", "pythonw.exe")
|
||
|
||
|
||
def _resolve_existing_path(path_value: Optional[str], *base_dirs: Optional[str]) -> Optional[str]:
    """Resolve ``path_value`` to a path that exists on disk.

    Args:
        path_value: Absolute or relative path; non-strings and blank
            strings yield ``None``.
        *base_dirs: Directories tried, in order, as the base for a relative
            ``path_value``. Falsy entries are skipped.

    Returns:
        The stripped absolute path itself if it exists, or the first
        existing ``abspath(join(base_dir, path_value))``; ``None`` when
        nothing exists.
    """
    if not isinstance(path_value, str):
        return None
    target = path_value.strip()
    if not target:
        return None

    if os.path.isabs(target):
        return target if os.path.exists(target) else None

    resolved = (
        os.path.abspath(os.path.join(root, target))
        for root in base_dirs
        if root
    )
    return next((path for path in resolved if os.path.exists(path)), None)
|
||
|
||
|
||
# Table of stdio MCP servers that have a prebuilt executable. Each entry maps
# the server's script path and working directory (both relative) to the
# executable's location relative to the runtime root. Every executable lives
# at mcp_bin/<name>/<name>.exe, so only <name> is listed per row.
_PACKAGED_STDIO_SERVERS = [
    {
        "script": script,
        "cwd": cwd,
        "exe_relpath": os.path.join("mcp_bin", name, f"{name}.exe"),
    }
    for script, cwd, name in (
        ("test/mcp_stdio_example.py", "", "mcp_stdio_example"),
        ("mcp_servers/schedule_manager/server.py", "mcp_servers/schedule_manager", "schedule_manager_mcp"),
        ("mcp_servers/logseq/server.py", "mcp_servers/logseq", "logseq_mcp"),
        ("mcp_servers/yueshen_rag/server.py", "mcp_servers/yueshen_rag", "yueshen_rag_mcp"),
        ("mcp_servers/window_capture/server.py", "mcp_servers/window_capture", "window_capture_mcp"),
        ("mcp_servers/mcp-todo-server/server.py", "mcp_servers/mcp-todo-server", "todo_server_mcp"),
        ("mcp_servers/elderly_mcp/server.py", "mcp_servers/elderly_mcp", "elderly_mcp_server"),
    )
]
|
||
|
||
|
||
def _resolve_packaged_stdio_binary(
    runtime_root: str, command: Any, args: List[Any], cwd: Optional[str]
) -> Optional[Tuple[str, List[Any], str]]:
    """Swap a script-based stdio launch for its prebuilt executable.

    Only active in a frozen build (``sys.frozen``); otherwise returns
    ``None`` immediately.

    Args:
        runtime_root: Directory the ``exe_relpath`` entries are relative to.
        command: Configured launch command.
        args: Configured launch arguments.
        cwd: Configured working directory, if any.

    Returns:
        ``(exe_path, [], exe_dir)`` for the first entry of
        ``_PACKAGED_STDIO_SERVERS`` that matches the launch and whose
        executable exists under ``runtime_root``; ``None`` if there is none.
    """
    if not getattr(sys, "frozen", False):
        return None

    # Positional (non-option) string arguments are the script candidates.
    script_args = [
        str(item)
        for item in args
        if isinstance(item, str) and item and not str(item).startswith("-")
    ]
    # A cwd match only counts when no script is named at all or the script
    # being launched is called server.py.
    launches_server_py = not script_args or any(
        os.path.basename(item).lower() == "server.py" for item in script_args
    )
    command_str = str(command or "")

    for entry in _PACKAGED_STDIO_SERVERS:
        hit = (
            any(_path_matches(item, entry["script"]) for item in script_args)
            or (entry["cwd"] and _path_matches(cwd, entry["cwd"]) and launches_server_py)
            or _path_matches(command_str, entry["exe_relpath"])
        )
        if not hit:
            continue

        binary = os.path.join(runtime_root, entry["exe_relpath"])
        # A matching entry with no executable on disk does not end the
        # search; later entries are still considered.
        if os.path.exists(binary):
            return binary, [], os.path.dirname(binary)
    return None
|
||
|
||
|
||
def _is_awaitable(obj: Any) -> bool:
    """Return ``inspect.isawaitable(obj)``, or ``False`` if that check raises."""
    try:
        verdict = inspect.isawaitable(obj)
    except Exception:
        verdict = False
    return verdict
|
||
|
||
|
||
async def _await_or_value(obj, timeout: Optional[float] = None):
    """Await ``obj`` if it is awaitable, otherwise return it unchanged.

    Args:
        obj: A plain value or an awaitable.
        timeout: Optional limit in seconds, enforced with
            ``asyncio.wait_for`` when ``obj`` is awaitable.

    Returns:
        The awaited result, or ``obj`` itself for non-awaitables.
    """
    if not _is_awaitable(obj):
        return obj
    pending = obj if timeout is None else asyncio.wait_for(obj, timeout=timeout)
    return await pending
|
||
|
||
|
||
class McpClient:
    """MCP client that works across multiple ``mcp`` package versions.

    Supports the SSE and STDIO transports. Some ``mcp`` versions return a
    plain synchronous list from ``list_tools``; awaiting it would raise
    "object list can't be used in 'await' expression", so every session call
    goes through ``_await_or_value``.

    All session work runs on a dedicated event loop in a background daemon
    thread; the public synchronous methods (``connect``, ``call_tool``,
    ``list_tools``, ``disconnect``) submit coroutines to that loop.
    """

    def __init__(self, server_url: Optional[str] = None, api_key: Optional[str] = None,
                 transport: str = "sse", stdio_config: Optional[Dict[str, Any]] = None,
                 server_id: Optional[int] = None, tools_refresh_interval: int = 60,
                 enabled_lookup: Optional[Callable[[str], bool]] = None):
        """Create a client and start its background event loop.

        Args:
            server_url: SSE endpoint URL (used when ``transport`` is "sse").
            api_key: Optional bearer token sent in the Authorization header.
            transport: "sse" or "stdio"; any other value falls back to "sse".
            stdio_config: Launch settings for stdio servers; the keys read
                are ``command``, ``args``, ``env`` and ``cwd``.
            server_id: Identifier under which tools are published to
                ``tool_registry``; ``None`` disables registry updates.
            tools_refresh_interval: Seconds between background tool
                refreshes; values below 5 are raised to 5.
            enabled_lookup: Callback forwarded to ``tool_registry`` together
                with the tool list.
        """
        self.server_url = server_url
        self.api_key = api_key
        # Unknown or empty transports fall back to SSE.
        self.transport = transport if transport in ("sse", "stdio") else "sse"
        self.stdio_config = stdio_config or {}
        self.server_id = server_id
        self._enabled_lookup = enabled_lookup

        self.session: Optional[ClientSession] = None
        self.tools: Optional[List[Any]] = None
        self.connected = False
        self.exit_stack: Optional[AsyncExitStack] = None

        # timeouts (seconds)
        self.init_timeout_seconds = 30
        self.list_timeout_seconds = 30
        self.call_timeout_seconds = 90

        self._stdio_errlog_file = None
        self._manager_task: Optional[asyncio.Task] = None
        self._disconnect_event: Optional[asyncio.Event] = None
        self._connect_ready_future: Optional[asyncio.Future] = None
        self._last_error: Optional[str] = None
        self._resolved_stdio_config: Optional[Dict[str, Any]] = None

        # tool availability cache
        self.tools_refresh_interval = max(int(tools_refresh_interval), 5)
        self._tool_cache: List[Dict[str, Any]] = []
        self._tool_cache_timestamp: float = 0.0
        self._tools_lock = threading.RLock()
        self._tools_refresh_thread: Optional[threading.Thread] = None
        self._tools_stop_event = threading.Event()

        # Dedicated event loop in a background thread, started last so that
        # every attribute above exists before any coroutine can run.
        self.event_loop = asyncio.new_event_loop()
        t = threading.Thread(target=self._loop_runner, args=(self.event_loop,), daemon=True)
        t.start()
        self._loop_thread = t

    @staticmethod
    def _loop_runner(loop: asyncio.AbstractEventLoop):
        """Thread target: install ``loop`` as current and run it forever."""
        asyncio.set_event_loop(loop)
        loop.run_forever()

    def set_enabled_lookup(self, lookup: Optional[Callable[[str], bool]]) -> None:
        """Allow callers to update the enabled-state resolver at runtime."""
        self._enabled_lookup = lookup

    def _clone_tool_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``entry`` whose ``inputSchema`` dict is also copied.

        The copy is one level deep for the entry and one level deep for the
        schema; deeper nested objects are still shared.
        """
        clone = dict(entry)
        if isinstance(clone.get("inputSchema"), dict):
            clone["inputSchema"] = dict(clone["inputSchema"])
        return clone

    @staticmethod
    def _normalize_tool_entry(tool: Any) -> Optional[Dict[str, Any]]:
        """Convert one raw tool definition into a plain dict, or ``None`` to skip it.

        Accepts objects exposing ``name``/``description``/``inputSchema``
        attributes, dicts with the same keys (plus optional ``enabled``), or
        any other value whose ``str()`` is used as the tool name.
        """
        if hasattr(tool, "name"):
            # ``or ""`` keeps name=None from turning into a tool called "None".
            name = str(getattr(tool, "name", "") or "").strip()
            if not name:
                return None
            input_schema = getattr(tool, "inputSchema", {})
            if not isinstance(input_schema, dict):
                input_schema = {}
            return {
                "name": name,
                "description": str(getattr(tool, "description", "") or ""),
                "inputSchema": dict(input_schema),
            }

        if isinstance(tool, dict):
            name = str(tool.get("name") or "").strip()
            if not name:
                # A dict without a usable name cannot be called as a tool;
                # do not fall through and register its repr() as a name.
                return None
            schema = tool.get("inputSchema")
            entry = {
                "name": name,
                "description": str(tool.get("description", "") or ""),
                "inputSchema": dict(schema) if isinstance(schema, dict) else {},
            }
            if "enabled" in tool:
                entry["enabled"] = bool(tool["enabled"])
            return entry

        # Skip placeholder tuples or metadata fragments
        if isinstance(tool, tuple) and len(tool) == 2 and isinstance(tool[0], str):
            return None

        name = str(tool).strip()
        if not name:
            return None
        return {"name": name, "description": "", "inputSchema": {}}

    def _sanitize_tools(self, tools: Any) -> List[Dict[str, Any]]:
        """Normalise a ``list_tools`` response into a list of plain dicts.

        Args:
            tools: The raw response. It may be a list of tools, a dict or
                object carrying a ``tools`` member (nested up to three
                levels), an iterable of ``(key, value)`` pairs containing a
                ``"tools"`` key, or a single tool given as a dict or string.

        Returns:
            A list of ``{"name", "description", "inputSchema"}`` dicts (dict
            inputs may also carry ``"enabled"``). Entries that cannot be
            normalised are skipped.
        """
        sanitized: List[Dict[str, Any]] = []
        if not tools:
            return sanitized

        # Unwrap known container shapes (dict/object with .tools etc.)
        container = tools
        for _ in range(3):
            if container is None:
                break
            if isinstance(container, dict):
                inner = container.get("tools")
                if inner is not None:
                    container = inner
                    continue
            if hasattr(container, "tools"):
                try:
                    inner = getattr(container, "tools")
                except Exception:
                    inner = None
                if inner is not None:
                    container = inner
                    continue
            break

        single_tool_dict = isinstance(container, dict) and bool(container.get("name"))
        if isinstance(container, str) or single_tool_dict:
            # A lone tool must not be iterated: a string would be split into
            # characters and a dict into its keys.
            iterable = [container]
        else:
            # Handle responses expressed as iterable of key/value pairs
            try:
                iterable = list(container)
            except TypeError:
                iterable = [container]
            else:
                if iterable and all(isinstance(item, tuple) and len(item) == 2 for item in iterable):
                    for key, value in iterable:
                        if key == "tools":
                            return self._sanitize_tools(value)

        for tool in iterable:
            try:
                entry = self._normalize_tool_entry(tool)
            except Exception as exc:
                logger.debug(f"Failed to normalize MCP tool definition {tool!r}: {exc}")
                continue
            if entry is not None:
                sanitized.append(entry)
        return sanitized

    def _apply_tool_cache_update(self, tools: List[Dict[str, Any]]) -> None:
        """Replace the cached tool list and publish it to ``tool_registry``.

        The cache and the legacy ``self.tools`` attribute each receive their
        own cloned entries. The registry is only updated when ``server_id``
        is set.
        """
        with self._tools_lock:
            cloned = [self._clone_tool_entry(entry) for entry in tools]
            self._tool_cache = cloned
            self._tool_cache_timestamp = time.time()
            self.tools = [self._clone_tool_entry(entry) for entry in cloned]  # backward compatibility
        if self.server_id is not None:
            tool_registry.set_server_tools(self.server_id, tools, self._enabled_lookup)

    def _get_tool_cache_copy(self) -> List[Dict[str, Any]]:
        """Return cloned cache entries, taken under the tools lock."""
        with self._tools_lock:
            return [self._clone_tool_entry(entry) for entry in self._tool_cache]

    def get_cached_tools(self) -> List[Dict[str, Any]]:
        """Expose a copy of the cached tool metadata without refreshing remotely."""
        return self._get_tool_cache_copy()

    def _ensure_refresh_worker(self) -> None:
        """Start the background tool-refresh thread unless one is already alive."""
        if self.tools_refresh_interval <= 0:
            return
        with self._tools_lock:
            if self._tools_refresh_thread and self._tools_refresh_thread.is_alive():
                return
            # Each worker gets its own stop event so that signalling or
            # replacing it can never affect a different worker generation.
            stop_event = threading.Event()
            self._tools_stop_event = stop_event
            thread = threading.Thread(
                target=self._refresh_loop,
                args=(stop_event,),
                name=f"mcp-tools-refresh-{self.server_id or 'unknown'}",
                daemon=True,
            )
            self._tools_refresh_thread = thread
            thread.start()

    def _stop_refresh_worker(self) -> None:
        """Signal the refresh thread to stop and, when safe, wait for it.

        ``_run_session`` calls this from its ``finally`` block on the
        event-loop thread. Joining there would stall the loop while the
        worker may itself be blocked in ``_refresh_tools`` waiting on that
        same loop, so the join is skipped on the loop thread (and on the
        worker thread itself); the stop event alone ends the worker.
        """
        with self._tools_lock:
            thread = self._tools_refresh_thread
            stop_event = self._tools_stop_event
            self._tools_refresh_thread = None
            # Fresh event for the next worker; the old worker keeps its own
            # reference to the one that is set below.
            self._tools_stop_event = threading.Event()
        stop_event.set()
        if thread is None:
            return
        current = threading.current_thread()
        if current is thread or current is self._loop_thread:
            return
        if thread.is_alive():
            thread.join(timeout=self.tools_refresh_interval)

    def _refresh_loop(self, stop_event: Optional[threading.Event] = None) -> None:
        """Worker body: refresh the tool cache every ``tools_refresh_interval`` seconds.

        Args:
            stop_event: Event that ends the loop when set. Defaults to the
                event current at start-up, for callers that pass nothing.
        """
        if stop_event is None:
            stop_event = self._tools_stop_event
        while not stop_event.wait(self.tools_refresh_interval):
            if not self.connected:
                continue
            try:
                self._refresh_tools()
            except Exception as exc:
                logger.debug(f"MCP tool refresh failed: {exc}")

    async def _refresh_tools_async(self) -> bool:
        """Fetch the tool list from the live session and update the cache.

        Returns ``False`` when there is no session, ``True`` otherwise. An
        empty response is only applied when the cache currently holds tools.
        """
        if not self.session:
            return False
        tools_resp = await _await_or_value(self.session.list_tools(), self.list_timeout_seconds)
        sanitized = self._sanitize_tools(tools_resp)
        if sanitized or self._tool_cache:
            self._apply_tool_cache_update(sanitized)
        return True

    def _refresh_tools(self) -> bool:
        """Synchronously refresh the tool cache via the background loop.

        Returns ``False`` (after a debug log) on any failure, including a
        timeout of ``list_timeout_seconds + 5`` seconds.
        """
        try:
            future = asyncio.run_coroutine_threadsafe(self._refresh_tools_async(), self.event_loop)
            return future.result(timeout=self.list_timeout_seconds + 5)
        except Exception as exc:
            logger.debug(f"Failed to refresh MCP tool cache: {exc}")
            return False

    def _clear_tool_cache(self) -> None:
        """Empty the tool cache and mark this server's tools unavailable."""
        with self._tools_lock:
            self._tool_cache = []
            self._tool_cache_timestamp = 0.0
            self.tools = None
        if self.server_id is not None:
            tool_registry.mark_all_unavailable(self.server_id)

    def _resolve_stdio_launch_config(self) -> Dict[str, Any]:
        """Compute the effective command/args/env/cwd for a stdio launch.

        A relative ``cwd`` is resolved against the runtime root. If a
        packaged executable matches, it replaces the launch entirely.
        Otherwise a plain Python command becomes ``sys.executable`` and any
        other command is replaced by its existing path under ``cwd`` or the
        runtime root, when one exists. The result is also stored in
        ``_resolved_stdio_config``.
        """
        cfg = self.stdio_config or {}
        runtime_root = _runtime_root_dir()
        command = cfg.get("command") or sys.executable
        args = list(cfg.get("args") or [])
        env = cfg.get("env") or None
        cwd = cfg.get("cwd") or None
        if cwd and not os.path.isabs(cwd):
            cwd = os.path.abspath(os.path.join(runtime_root, cwd))

        packaged_launch = _resolve_packaged_stdio_binary(runtime_root, command, args, cwd)
        if packaged_launch is not None:
            command, args, cwd = packaged_launch
        elif _is_python_command(command):
            command = sys.executable
        else:
            resolved_command = _resolve_existing_path(command, cwd, runtime_root)
            if resolved_command:
                command = resolved_command

        resolved_cfg = {
            "command": command,
            "args": args,
            "env": env,
            "cwd": cwd,
        }
        self._resolved_stdio_config = resolved_cfg
        return resolved_cfg

    async def _connect_async(self) -> Tuple[bool, Any]:
        """Connect, or join a connection attempt that is already in flight.

        Returns:
            ``(True, tools)`` on success, else ``(False, error_message)``.
        """
        if self.connected and self.session:
            return True, self.get_cached_tools()

        # Reap a finished manager task so a new session can be started.
        if self._manager_task and self._manager_task.done():
            try:
                await self._manager_task
            except Exception:
                pass
            self._manager_task = None

        if self._manager_task:
            # Another caller is already connecting: share its outcome.
            if self._connect_ready_future:
                try:
                    return await self._connect_ready_future
                except Exception as exc:
                    logger.error(f"Unexpected connection error during startup wait: {exc}")
                    return False, str(exc)
            await self._manager_task
            if self.connected and self.session:
                return True, self.get_cached_tools()
            return False, self._last_error or "MCP server connection failed"

        loop = asyncio.get_running_loop()
        ready_future: asyncio.Future = loop.create_future()
        disconnect_event = asyncio.Event()
        self._disconnect_event = disconnect_event
        self._connect_ready_future = ready_future
        self._last_error = None
        self._manager_task = loop.create_task(self._run_session(ready_future, disconnect_event))

        try:
            result = await ready_future
        finally:
            self._connect_ready_future = None
        return result

    async def _run_session(self, ready_future: asyncio.Future, disconnect_event: asyncio.Event) -> None:
        """Own the full session lifecycle: open, serve until disconnect, clean up.

        Args:
            ready_future: Resolved exactly once with ``(ok, tools_or_error)``
                as soon as the connection attempt has an outcome.
            disconnect_event: Setting it ends the session and triggers cleanup.
        """
        stdio_errlog = None
        stack = AsyncExitStack()
        self.exit_stack = stack
        try:
            async with stack:
                if self.transport == "stdio":
                    if not HAS_STDIO:
                        message = "Missing stdio-capable MCP client, run: pip install -U mcp"
                        self._last_error = message
                        if not ready_future.done():
                            ready_future.set_result((False, message))
                        return
                    cfg = self._resolve_stdio_launch_config()
                    command = cfg.get("command") or sys.executable
                    args = list(cfg.get("args") or [])
                    env = cfg.get("env") or None
                    cwd = cfg.get("cwd") or None

                    # Best effort: send the child's stderr to
                    # logs/mcp_stdio_<command>.log, else fall back to sys.stderr.
                    try:
                        log_dir = os.path.join(os.getcwd(), 'logs')
                        os.makedirs(log_dir, exist_ok=True)
                        base = os.path.basename(str(command))
                        log_path = os.path.join(log_dir, f"mcp_stdio_{base}.log")
                        stdio_errlog = open(log_path, 'a', encoding='utf-8')
                    except Exception:
                        stdio_errlog = None

                    self._stdio_errlog_file = stdio_errlog
                    params = StdioServerParameters(command=command, args=args, env=env, cwd=cwd)
                    read_stream, write_stream = await stack.enter_async_context(
                        stdio_client(params, errlog=stdio_errlog or sys.stderr)
                    )
                else:
                    headers = {}
                    if self.api_key:
                        headers['Authorization'] = f'Bearer {self.api_key}'
                    read_stream, write_stream = await stack.enter_async_context(
                        sse_client(self.server_url, headers=headers)
                    )

                self.session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
                try:
                    await _await_or_value(getattr(self.session, 'initialize', lambda: None)(), self.init_timeout_seconds)
                except Exception as exc:
                    # Deliberately best effort (tool discovery is still
                    # attempted), but leave a trace instead of hiding it.
                    logger.warning(f"MCP session initialize failed, continuing with tool discovery: {exc}")

                tools_resp = await _await_or_value(self.session.list_tools(), self.list_timeout_seconds)
                sanitized_tools = self._sanitize_tools(tools_resp)
                self._apply_tool_cache_update(sanitized_tools)
                self.connected = True
                if not ready_future.done():
                    ready_future.set_result((True, self.get_cached_tools()))

                self._ensure_refresh_worker()

                # Keep the transport contexts open until disconnect is requested.
                await disconnect_event.wait()

        except asyncio.TimeoutError as e:
            self._last_error = f"Connection or tool discovery timed out: {e}"
            if not ready_future.done():
                ready_future.set_result((False, self._last_error))
        except Exception as e:
            self._last_error = str(e)
            logger.error(f"Error while handling connection lifecycle: {e}")
            if not ready_future.done():
                ready_future.set_result((False, self._last_error))
        finally:
            if stdio_errlog:
                try:
                    stdio_errlog.close()
                except Exception:
                    pass
            if self._stdio_errlog_file and self._stdio_errlog_file is not stdio_errlog:
                try:
                    self._stdio_errlog_file.close()
                except Exception:
                    pass
            self._stdio_errlog_file = None
            self._resolved_stdio_config = None
            self._stop_refresh_worker()
            self.connected = False
            self.session = None
            self._clear_tool_cache()
            # Guarantee that a waiting connect() is always released.
            if not ready_future.done():
                ready_future.set_result((False, self._last_error or "MCP server connection failed"))
            if self._disconnect_event is disconnect_event:
                self._disconnect_event = None
            self._manager_task = None
            self.exit_stack = None

    def connect(self):
        """Connect synchronously; returns ``(ok, tools_or_error)``.

        Raises whatever ``Future.result`` raises if no outcome arrives
        within ``init_timeout_seconds + list_timeout_seconds + 10`` seconds.
        """
        fut = asyncio.run_coroutine_threadsafe(self._connect_async(), self.event_loop)
        return fut.result(timeout=self.init_timeout_seconds + self.list_timeout_seconds + 10)

    async def _call_tool_async(self, method: str, params=None):
        """Invoke tool ``method`` with ``params``; returns ``(ok, result_or_error)``."""
        if not self.connected or not self.session:
            return False, "未连接到MCP服务器"
        try:
            params = params or {}
            result = await _await_or_value(self.session.call_tool(method, params), self.call_timeout_seconds)
            return True, result
        except asyncio.TimeoutError:
            return False, f"调用工具超时({self.call_timeout_seconds}s)"
        except Exception as e:
            logger.exception("调用工具失败异常堆栈")
            return False, f"调用工具失败: {type(e).__name__}: {e}"

    def call_tool(self, method, params=None):
        """Synchronous wrapper around ``_call_tool_async``.

        Raises whatever ``Future.result`` raises if the background loop does
        not answer within ``call_timeout_seconds + 5`` seconds.
        """
        future = asyncio.run_coroutine_threadsafe(self._call_tool_async(method, params), self.event_loop)
        return future.result(timeout=self.call_timeout_seconds + 5)

    def list_tools(self, refresh: bool = False):
        """Return the known tools, connecting first when necessary.

        Args:
            refresh: When already connected, re-query the server before
                returning the cached list.

        Returns:
            A list of tool dicts; empty if connecting fails.
        """
        if not self.connected:
            success, tools = self.connect()
            if not success:
                return []
            return tools or []
        if refresh:
            self._refresh_tools()
        return self.get_cached_tools()

    async def _disconnect_async(self) -> bool:
        """Signal the session task to stop and wait for it to finish.

        Returns ``False`` if awaiting the task raised, ``True`` otherwise.
        """
        task = self._manager_task
        event = self._disconnect_event
        if event and not event.is_set():
            event.set()
        if task:
            try:
                await task
            except Exception as e:
                logger.error(f"Error while closing connection: {e}")
                return False
        return True

    @staticmethod
    def _select_kill_pattern(command: Any, args: Any) -> str:
        """Pick the command-line fragment used to find the stdio child.

        The first non-empty argument that is not an option (does not start
        with ``-``) is preferred, since that is usually the script path;
        otherwise the command itself is used. Returns ``""`` if neither is
        available.
        """
        for arg in args or []:
            text = str(arg).strip()
            if text and not text.startswith("-"):
                return text
        return str(command or "").strip()

    def _kill_stdio_process(self, cfg: Optional[Dict[str, Any]] = None) -> None:
        """Force-terminate the stdio child process as a last-resort cleanup.

        Only direct children of the current process whose command line
        contains the match pattern are targeted, so unrelated processes that
        merely share a script name such as ``server.py`` are never touched.

        Args:
            cfg: Launch configuration to match against. Defaults to the
                resolved configuration, then the raw ``stdio_config``.
        """
        if self.transport != "stdio":
            return

        cfg = cfg or self._resolved_stdio_config or self.stdio_config or {}
        # Build the keyword used to match the process.
        match_pattern = self._select_kill_pattern(cfg.get("command"), cfg.get("args"))
        if not match_pattern:
            return

        import subprocess

        parent_pid = os.getpid()
        # NOTE(review): two children of this process started with the same
        # script argument (e.g. both "server.py") still cannot be told apart
        # here; tracking the real child PID would need access to the process
        # handle, which this class does not hold.
        try:
            if sys.platform == "win32":
                # Windows: look up matching children with wmic, then
                # terminate each with taskkill. Backslashes and single
                # quotes are escaped for the WQL string literal.
                escaped = match_pattern.replace("\\", "\\\\").replace("'", "\\'")
                result = subprocess.run(
                    ["wmic", "process", "where",
                     f"ParentProcessId={parent_pid} and commandline like '%{escaped}%'",
                     "get", "processid"],
                    capture_output=True, text=True, timeout=5
                )
                if result.returncode == 0:
                    for line in result.stdout.splitlines():
                        line = line.strip()
                        if not line.isdigit():
                            continue
                        pid = int(line)
                        try:
                            # /T also ends processes spawned by the child
                            # (for example an interpreter behind a launcher).
                            subprocess.run(
                                ["taskkill", "/F", "/T", "/PID", str(pid)],
                                capture_output=True, timeout=5
                            )
                            logger.info(f"强制终止 stdio 子进程 PID: {pid}")
                        except Exception as e:
                            logger.debug(f"终止进程 {pid} 失败: {e}")
            else:
                # Unix: terminate with pkill. The pattern is a regular
                # expression, so metacharacters are escaped to match the
                # text literally, and "--" stops option parsing.
                literal = "".join(
                    "\\" + ch if ch in ".^$*+?()[]{}|\\" else ch
                    for ch in match_pattern
                )
                try:
                    subprocess.run(
                        ["pkill", "-P", str(parent_pid), "-f", "--", literal],
                        capture_output=True, timeout=5
                    )
                    logger.info(f"强制终止匹配 '{match_pattern}' 的进程")
                except Exception as e:
                    logger.debug(f"pkill 执行失败: {e}")
        except Exception as e:
            logger.debug(f"强制终止 stdio 子进程失败: {e}")

    def disconnect(self):
        """Close the session and make sure any stdio child is gone.

        Always returns ``True``; failures are logged rather than raised.
        """
        if not self._manager_task and not self._disconnect_event:
            return True
        # Snapshot first: session teardown clears _resolved_stdio_config, and
        # the resolved command is what the child was actually started with.
        launch_cfg = self._resolved_stdio_config
        try:
            fut = asyncio.run_coroutine_threadsafe(self._disconnect_async(), self.event_loop)
            fut.result(timeout=5)
        except Exception as e:
            logger.error(f"Error while closing connection: {e}")

        # Make sure the stdio child process is terminated.
        self._kill_stdio_process(launch_cfg)
        return True
|
||
|