Add live tool approvals via Telegram

This commit is contained in:
mirivlad 2026-04-07 17:17:17 +08:00
parent ac7f1bd493
commit aa3154e9d7
8 changed files with 374 additions and 35 deletions

View File

@ -37,6 +37,7 @@ Qwen OAuth + OpenAI-compatible endpoint
- job-based chat polling между `bot` и `serv` - job-based chat polling между `bot` и `serv`
- persistence для chat jobs и pending OAuth flows на стороне `serv` - persistence для chat jobs и pending OAuth flows на стороне `serv`
- policy mode для инструментов: `full-access`, `workspace-write`, `read-only` - policy mode для инструментов: `full-access`, `workspace-write`, `read-only`
- live approval flow для инструментов через Telegram
## Ограничения текущей реализации ## Ограничения текущей реализации
@ -59,6 +60,10 @@ cp serv/.env.example serv/.env
`full-access` - все инструменты `full-access` - все инструменты
`workspace-write` - без `exec_command` `workspace-write` - без `exec_command`
`read-only` - только чтение и поиск `read-only` - только чтение и поиск
`ask-shell` - shell только после подтверждения
`ask-write` - shell и записи только после подтверждения
`ask-all` - любой инструмент только после подтверждения
- `NEW_QWEN_APPROVAL_TIMEOUT_SECONDS` - сколько сервер ждёт решения по approval
Бот: Бот:
@ -101,8 +106,19 @@ curl -X POST http://127.0.0.1:8080/api/v1/auth/device/start
- `POST /api/v1/auth/device/start` - `POST /api/v1/auth/device/start`
- `POST /api/v1/auth/device/poll` - `POST /api/v1/auth/device/poll`
- `GET /api/v1/sessions` - `GET /api/v1/sessions`
- `GET /api/v1/approvals`
- `POST /api/v1/session/get` - `POST /api/v1/session/get`
- `POST /api/v1/session/clear` - `POST /api/v1/session/clear`
- `POST /api/v1/chat` - `POST /api/v1/chat`
- `POST /api/v1/chat/start` - `POST /api/v1/chat/start`
- `POST /api/v1/chat/poll` - `POST /api/v1/chat/poll`
- `POST /api/v1/approval/respond`
## Telegram Approval Flow
Если политика инструментов настроена как `ask-shell`, `ask-write` или `ask-all`, бот пришлёт запрос на подтверждение с `approval_id`.
Дальше можно ответить одной из команд:
- `/approve <approval_id>`
- `/reject <approval_id>`

View File

@ -15,10 +15,18 @@ STATE_FILE = Path(__file__).resolve().parent.parent / ".new-qwen" / "telegram-st
def load_state() -> dict[str, Any]: def load_state() -> dict[str, Any]:
if not STATE_FILE.exists(): if not STATE_FILE.exists():
return {"offset": None, "sessions": {}, "auth_flows": {}} return {
"offset": None,
"sessions": {},
"auth_flows": {},
"active_jobs": {},
"pending_approvals": {},
}
state = json.loads(STATE_FILE.read_text(encoding="utf-8")) state = json.loads(STATE_FILE.read_text(encoding="utf-8"))
state.setdefault("sessions", {}) state.setdefault("sessions", {})
state.setdefault("auth_flows", {}) state.setdefault("auth_flows", {})
state.setdefault("active_jobs", {})
state.setdefault("pending_approvals", {})
return state return state
@ -66,9 +74,26 @@ def summarize_event(event: dict[str, Any]) -> str | None:
return f"Инструмент {event.get('name')} завершён" return f"Инструмент {event.get('name')} завершён"
if event_type == "error": if event_type == "error":
return f"Ошибка: {event.get('message')}" return f"Ошибка: {event.get('message')}"
if event_type == "approval_result":
status = event.get("status")
tool_name = event.get("tool_name")
if status == "approved":
return f"Подтверждение получено для {tool_name}"
return f"Подтверждение отклонено для {tool_name}"
return None return None
def format_approval_request(event: dict[str, Any]) -> str:
return (
"Нужно подтверждение инструмента.\n"
f"approval_id: {event.get('approval_id')}\n"
f"tool: {event.get('tool_name')}\n"
f"args: {json.dumps(event.get('arguments', {}), ensure_ascii=False)}\n\n"
f"/approve {event.get('approval_id')}\n"
f"/reject {event.get('approval_id')}"
)
def get_auth_flow(state: dict[str, Any], chat_id: int) -> dict[str, Any] | None: def get_auth_flow(state: dict[str, Any], chat_id: int) -> dict[str, Any] | None:
return state.setdefault("auth_flows", {}).get(str(chat_id)) return state.setdefault("auth_flows", {}).get(str(chat_id))
@ -145,7 +170,7 @@ def enqueue_pending_message(
) )
def deliver_chat_message( def start_chat_job(
api: TelegramAPI, api: TelegramAPI,
config: BotConfig, config: BotConfig,
state: dict[str, Any], state: dict[str, Any],
@ -168,31 +193,15 @@ def deliver_chat_message(
}, },
) )
state["sessions"][session_key] = start_result["session_id"] state["sessions"][session_key] = start_result["session_id"]
job_id = start_result["job_id"] state.setdefault("active_jobs", {})[start_result["job_id"]] = {
seen_seq = 0 "job_id": start_result["job_id"],
sent_statuses: set[str] = set() "chat_id": chat_id,
answer = None "user_id": user_id,
while True: "session_key": session_key,
poll_result = post_json( "session_id": start_result["session_id"],
f"{config.server_url}/api/v1/chat/poll", "seen_seq": 0,
{"job_id": job_id, "since_seq": seen_seq}, "sent_statuses": [],
) }
for event in poll_result.get("events", []):
seen_seq = max(seen_seq, int(event.get("seq", 0)))
summary = summarize_event(event)
if summary and summary not in sent_statuses:
api.send_message(chat_id, summary[:4000])
sent_statuses.add(summary)
if poll_result.get("status") == "completed":
answer = poll_result.get("answer")
state["sessions"][session_key] = poll_result["session_id"]
break
if poll_result.get("status") == "failed":
raise RuntimeError(poll_result.get("error") or "Chat job failed")
time.sleep(1.2)
answer = answer or "Пустой ответ от модели."
send_text_chunks(api, chat_id, answer)
def poll_auth_flow( def poll_auth_flow(
@ -240,7 +249,7 @@ def poll_auth_flow(
state["auth_flows"].pop(str(chat_id), None) state["auth_flows"].pop(str(chat_id), None)
api.send_message(chat_id, "Qwen OAuth успешно настроен.") api.send_message(chat_id, "Qwen OAuth успешно настроен.")
for item in flow.get("pending_messages", []): for item in flow.get("pending_messages", []):
deliver_chat_message( start_chat_job(
api, api,
config, config,
state, state,
@ -267,6 +276,58 @@ def process_auth_flows(
print(f"auth flow poll error for chat {chat_id_raw}: {exc}") print(f"auth flow poll error for chat {chat_id_raw}: {exc}")
def process_active_jobs(
api: TelegramAPI,
config: BotConfig,
state: dict[str, Any],
) -> None:
active_jobs = state.setdefault("active_jobs", {})
pending_approvals = state.setdefault("pending_approvals", {})
for job_id in list(active_jobs.keys()):
job_state = active_jobs[job_id]
poll_result = post_json(
f"{config.server_url}/api/v1/chat/poll",
{"job_id": job_id, "since_seq": job_state.get("seen_seq", 0)},
)
for event in poll_result.get("events", []):
seq = int(event.get("seq", 0))
job_state["seen_seq"] = max(job_state.get("seen_seq", 0), seq)
if event.get("type") == "approval_request":
pending_approvals[str(job_state["chat_id"])] = {
"approval_id": event["approval_id"],
"job_id": job_id,
}
send_text_chunks(
api,
int(job_state["chat_id"]),
format_approval_request(event),
)
continue
summary = summarize_event(event)
sent_statuses = set(job_state.get("sent_statuses", []))
if summary and summary not in sent_statuses:
api.send_message(int(job_state["chat_id"]), summary[:4000])
sent_statuses.add(summary)
job_state["sent_statuses"] = sorted(sent_statuses)
status = poll_result.get("status")
if status == "completed":
state["sessions"][job_state["session_key"]] = poll_result["session_id"]
send_text_chunks(
api,
int(job_state["chat_id"]),
poll_result.get("answer") or "Пустой ответ от модели.",
)
active_jobs.pop(job_id, None)
elif status == "failed":
send_text_chunks(
api,
int(job_state["chat_id"]),
f"Job завершился с ошибкой: {poll_result.get('error')}",
)
active_jobs.pop(job_id, None)
def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], message: dict[str, Any]) -> None: def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], message: dict[str, Any]) -> None:
chat_id = message["chat"]["id"] chat_id = message["chat"]["id"]
user_id = str(message.get("from", {}).get("id", chat_id)) user_id = str(message.get("from", {}).get("id", chat_id))
@ -282,7 +343,7 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
if text == "/start": if text == "/start":
api.send_message( api.send_message(
chat_id, chat_id,
"new-qwen bot готов.\nКоманды: /help, /auth, /status, /session, /clear.", "new-qwen bot готов.\nКоманды: /help, /auth, /status, /session, /clear, /approve, /reject.",
) )
return return
@ -294,10 +355,35 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
"/auth_check [flow_id] - проверить авторизацию\n" "/auth_check [flow_id] - проверить авторизацию\n"
"/status - статус OAuth и сервера\n" "/status - статус OAuth и сервера\n"
"/session - показать текущую сессию\n" "/session - показать текущую сессию\n"
"/approve [approval_id] - подтвердить инструмент\n"
"/reject [approval_id] - отклонить инструмент\n"
"/clear - очистить контекст", "/clear - очистить контекст",
) )
return return
if text.startswith("/approve") or text.startswith("/reject"):
parts = text.split(maxsplit=1)
approval = state.setdefault("pending_approvals", {}).get(str(chat_id))
approval_id = parts[1] if len(parts) == 2 else approval.get("approval_id") if approval else None
if not approval_id:
api.send_message(chat_id, "Нет pending approval для этого чата.")
return
response = post_json(
f"{config.server_url}/api/v1/approval/respond",
{
"approval_id": approval_id,
"approved": text.startswith("/approve"),
"actor": user_id,
},
)
if response.get("status") != "pending":
state["pending_approvals"].pop(str(chat_id), None)
api.send_message(
chat_id,
f"Approval {approval_id}: {response.get('status')}",
)
return
if text == "/auth": if text == "/auth":
start_auth_flow(api, config, state, chat_id, force_new=True) start_auth_flow(api, config, state, chat_id, force_new=True)
return return
@ -324,7 +410,9 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
"Сервер доступен.\n" "Сервер доступен.\n"
f"OAuth: {'configured' if status.get('authenticated') else 'not configured'}\n" f"OAuth: {'configured' if status.get('authenticated') else 'not configured'}\n"
f"resource_url: {status.get('resource_url')}\n" f"resource_url: {status.get('resource_url')}\n"
f"expires_at: {status.get('expires_at')}", f"expires_at: {status.get('expires_at')}\n"
f"tool_policy: {status.get('tool_policy')}\n"
f"pending_approvals: {status.get('pending_approvals')}",
) )
return return
@ -359,7 +447,7 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
api.send_message(chat_id, "Сообщение поставлено в очередь до завершения авторизации.") api.send_message(chat_id, "Сообщение поставлено в очередь до завершения авторизации.")
return return
deliver_chat_message(api, config, state, chat_id, user_id, session_key, text) start_chat_job(api, config, state, chat_id, user_id, session_key, text)
def main() -> None: def main() -> None:
@ -370,7 +458,11 @@ def main() -> None:
while True: while True:
try: try:
process_auth_flows(api, config, state) process_auth_flows(api, config, state)
updates = api.get_updates(state.get("offset"), config.poll_timeout) process_active_jobs(api, config, state)
timeout = config.poll_timeout
if state.get("active_jobs"):
timeout = min(timeout, 3)
updates = api.get_updates(state.get("offset"), timeout)
for update in updates: for update in updates:
state["offset"] = update["update_id"] + 1 state["offset"] = update["update_id"] + 1
message = update.get("message") message = update.get("message")

View File

@ -9,3 +9,4 @@ NEW_QWEN_MAX_TOOL_ROUNDS=8
NEW_QWEN_MAX_FILE_READ_BYTES=200000 NEW_QWEN_MAX_FILE_READ_BYTES=200000
NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES=12000 NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES=12000
NEW_QWEN_TOOL_POLICY=full-access NEW_QWEN_TOOL_POLICY=full-access
NEW_QWEN_APPROVAL_TIMEOUT_SECONDS=3600

View File

@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any from typing import Any
from config import ServerConfig from config import ServerConfig
from approvals import ApprovalStore
from jobs import JobStore from jobs import JobStore
from llm import QwenAgent from llm import QwenAgent
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
@ -25,6 +26,7 @@ class AppState:
self.tools = ToolRegistry(config) self.tools = ToolRegistry(config)
self.agent = QwenAgent(config, self.oauth, self.tools) self.agent = QwenAgent(config, self.oauth, self.tools)
self.jobs = JobStore(config.state_dir / "jobs") self.jobs = JobStore(config.state_dir / "jobs")
self.approvals = ApprovalStore(config.state_dir / "approvals")
self.pending_flows_path = config.state_dir / "oauth_flows.json" self.pending_flows_path = config.state_dir / "oauth_flows.json"
self.pending_device_flows: dict[str, DeviceAuthState] = self._load_pending_flows() self.pending_device_flows: dict[str, DeviceAuthState] = self._load_pending_flows()
self.lock = threading.Lock() self.lock = threading.Lock()
@ -83,6 +85,7 @@ class AppState:
"authenticated": False, "authenticated": False,
"tool_policy": self.config.tool_policy, "tool_policy": self.config.tool_policy,
"pending_flows": len(self.pending_device_flows), "pending_flows": len(self.pending_device_flows),
"pending_approvals": len(self.approvals.list_pending()),
} }
return { return {
"authenticated": True, "authenticated": True,
@ -90,6 +93,7 @@ class AppState:
"expires_at": creds.get("expiry_date"), "expires_at": creds.get("expiry_date"),
"tool_policy": self.config.tool_policy, "tool_policy": self.config.tool_policy,
"pending_flows": len(self.pending_device_flows), "pending_flows": len(self.pending_device_flows),
"pending_approvals": len(self.approvals.list_pending()),
} }
@ -125,6 +129,9 @@ class RequestHandler(BaseHTTPRequestHandler):
if self.path == "/api/v1/sessions": if self.path == "/api/v1/sessions":
self._send(HTTPStatus.OK, {"sessions": self.app.sessions.list_sessions()}) self._send(HTTPStatus.OK, {"sessions": self.app.sessions.list_sessions()})
return return
if self.path == "/api/v1/approvals":
self._send(HTTPStatus.OK, {"approvals": self.app.approvals.list_pending()})
return
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"}) self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None: def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None:
@ -140,6 +147,11 @@ class RequestHandler(BaseHTTPRequestHandler):
history, history,
message, message,
on_event=lambda event: self.app.jobs.append_event(job_id, event), on_event=lambda event: self.app.jobs.append_event(job_id, event),
approval_callback=lambda tool_name, arguments: self._request_approval(
job_id,
tool_name,
arguments,
),
) )
persisted_messages = result["messages"][1:] persisted_messages = result["messages"][1:]
self.app.sessions.save( self.app.sessions.save(
@ -168,6 +180,34 @@ class RequestHandler(BaseHTTPRequestHandler):
) )
self.app.jobs.fail(job_id, str(exc)) self.app.jobs.fail(job_id, str(exc))
def _request_approval(
self,
job_id: str,
tool_name: str,
arguments: dict[str, Any],
) -> dict[str, Any]:
approval = self.app.approvals.create(
job_id=job_id,
tool_name=tool_name,
arguments=arguments,
)
self.app.jobs.append_event(
job_id,
{
"type": "approval_request",
"approval_id": approval["approval_id"],
"tool_name": tool_name,
"arguments": arguments,
},
)
self.app.jobs.set_status(job_id, "waiting_approval")
decision = self.app.approvals.wait(
approval["approval_id"],
timeout_seconds=float(self.app.config.approval_timeout_seconds),
)
self.app.jobs.set_status(job_id, "running")
return decision
def do_POST(self) -> None: def do_POST(self) -> None:
try: try:
if self.path == "/api/v1/auth/device/start": if self.path == "/api/v1/auth/device/start":
@ -284,6 +324,19 @@ class RequestHandler(BaseHTTPRequestHandler):
) )
return return
if self.path == "/api/v1/approval/respond":
body = self._json_body()
approval_id = body["approval_id"]
approved = bool(body["approved"])
actor = str(body.get("actor") or "unknown")
approval = self.app.approvals.respond(
approval_id,
approved=approved,
actor=actor,
)
self._send(HTTPStatus.OK, approval)
return
if self.path == "/api/v1/session/get": if self.path == "/api/v1/session/get":
body = self._json_body() body = self._json_body()
session_id = body["session_id"] session_id = body["session_id"]

130
serv/approvals.py Normal file
View File

@ -0,0 +1,130 @@
from __future__ import annotations
import json
import threading
import time
import uuid
from pathlib import Path
from typing import Any
class ApprovalStore:
def __init__(self, base_dir: Path) -> None:
self.base_dir = base_dir
self.base_dir.mkdir(parents=True, exist_ok=True)
self._approvals: dict[str, dict[str, Any]] = {}
self._conditions: dict[str, threading.Condition] = {}
self._lock = threading.RLock()
self._load_existing()
def _path(self, approval_id: str) -> Path:
return self.base_dir / f"{approval_id}.json"
def _save(self, approval: dict[str, Any]) -> None:
self._path(approval["approval_id"]).write_text(
json.dumps(approval, ensure_ascii=False, indent=2),
encoding="utf-8",
)
def _load_existing(self) -> None:
for path in sorted(self.base_dir.glob("*.json")):
try:
approval = json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
continue
if approval.get("status") == "pending":
approval["status"] = "rejected"
approval["reason"] = "Server restarted while waiting for approval"
approval["updated_at"] = time.time()
path.write_text(
json.dumps(approval, ensure_ascii=False, indent=2),
encoding="utf-8",
)
self._approvals[approval["approval_id"]] = approval
def create(
self,
*,
job_id: str,
tool_name: str,
arguments: dict[str, Any],
) -> dict[str, Any]:
approval_id = uuid.uuid4().hex
approval = {
"approval_id": approval_id,
"job_id": job_id,
"tool_name": tool_name,
"arguments": arguments,
"status": "pending",
"created_at": time.time(),
"updated_at": time.time(),
"actor": None,
"reason": None,
}
with self._lock:
self._approvals[approval_id] = approval
self._conditions[approval_id] = threading.Condition(self._lock)
self._save(approval)
return approval.copy()
def get(self, approval_id: str) -> dict[str, Any] | None:
with self._lock:
approval = self._approvals.get(approval_id)
return approval.copy() if approval else None
def list_pending(self) -> list[dict[str, Any]]:
with self._lock:
pending = [
approval.copy()
for approval in self._approvals.values()
if approval.get("status") == "pending"
]
pending.sort(key=lambda item: item.get("created_at", 0))
return pending
def respond(
self,
approval_id: str,
*,
approved: bool,
actor: str,
) -> dict[str, Any]:
with self._lock:
approval = self._approvals.get(approval_id)
if not approval:
raise KeyError("Unknown approval_id")
if approval["status"] != "pending":
return approval.copy()
approval["status"] = "approved" if approved else "rejected"
approval["actor"] = actor
approval["updated_at"] = time.time()
approval["reason"] = None if approved else "Rejected by operator"
self._save(approval)
condition = self._conditions.get(approval_id)
if condition:
condition.notify_all()
return approval.copy()
def wait(self, approval_id: str, timeout_seconds: float = 3600.0) -> dict[str, Any]:
with self._lock:
approval = self._approvals.get(approval_id)
if not approval:
raise KeyError("Unknown approval_id")
if approval["status"] != "pending":
return approval.copy()
condition = self._conditions.setdefault(
approval_id,
threading.Condition(self._lock),
)
deadline = time.time() + timeout_seconds
while approval["status"] == "pending":
remaining = deadline - time.time()
if remaining <= 0:
approval["status"] = "rejected"
approval["reason"] = "Approval timeout"
approval["updated_at"] = time.time()
self._save(approval)
break
condition.wait(timeout=remaining)
return approval.copy()

View File

@ -29,6 +29,7 @@ class ServerConfig:
max_file_read_bytes: int max_file_read_bytes: int
max_command_output_bytes: int max_command_output_bytes: int
tool_policy: str tool_policy: str
approval_timeout_seconds: int
@classmethod @classmethod
def load(cls) -> "ServerConfig": def load(cls) -> "ServerConfig":
@ -65,4 +66,7 @@ class ServerConfig:
os.environ.get("NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES", "12000") os.environ.get("NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES", "12000")
), ),
tool_policy=os.environ.get("NEW_QWEN_TOOL_POLICY", "full-access").strip(), tool_policy=os.environ.get("NEW_QWEN_TOOL_POLICY", "full-access").strip(),
approval_timeout_seconds=int(
os.environ.get("NEW_QWEN_APPROVAL_TIMEOUT_SECONDS", "3600")
),
) )

View File

@ -54,6 +54,7 @@ class QwenAgent:
history: list[dict[str, Any]], history: list[dict[str, Any]],
user_message: str, user_message: str,
on_event: Callable[[dict[str, Any]], None] | None = None, on_event: Callable[[dict[str, Any]], None] | None = None,
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
emit = on_event or (lambda _event: None) emit = on_event or (lambda _event: None)
system_prompt = self.config.system_prompt or DEFAULT_SYSTEM_PROMPT system_prompt = self.config.system_prompt or DEFAULT_SYSTEM_PROMPT
@ -99,6 +100,30 @@ class QwenAgent:
tool_call_event = {"type": "tool_call", "name": tool_name, "arguments": arguments} tool_call_event = {"type": "tool_call", "name": tool_name, "arguments": arguments}
events.append(tool_call_event) events.append(tool_call_event)
emit(tool_call_event) emit(tool_call_event)
if approval_callback and self.tools.requires_approval(tool_name):
approval_result = approval_callback(tool_name, arguments)
approval_event = {
"type": "approval_result",
"tool_name": tool_name,
"approval_id": approval_result["approval_id"],
"status": approval_result["status"],
"actor": approval_result.get("actor"),
}
events.append(approval_event)
emit(approval_event)
if approval_result["status"] != "approved":
result = {"error": f"Tool '{tool_name}' was rejected by operator"}
tool_result_event = {"type": "tool_result", "name": tool_name, "result": result}
events.append(tool_result_event)
emit(tool_result_event)
messages.append(
{
"role": "tool",
"tool_call_id": call["id"],
"content": self.tools.encode_result(result),
}
)
continue
try: try:
result = self.tools.execute(tool_name, arguments) result = self.tools.execute(tool_name, arguments)
except ToolError as exc: except ToolError as exc:

View File

@ -182,17 +182,35 @@ class ToolRegistry:
def _check_policy(self, tool_name: str) -> None: def _check_policy(self, tool_name: str) -> None:
policy = self.config.tool_policy policy = self.config.tool_policy
read_only_tools = {"list_files", "glob_search", "grep_text", "stat_path", "read_file"} read_only_tools = {"list_files", "glob_search", "grep_text", "stat_path", "read_file"}
write_tools = {"replace_in_file", "write_file", "make_directory"}
shell_tools = {"exec_command"} shell_tools = {"exec_command"}
if policy == "full-access": if policy in {"full-access", "ask-shell", "ask-write", "ask-all"}:
return return
if policy == "read-only" and tool_name not in read_only_tools: if policy == "read-only" and tool_name not in read_only_tools:
raise ToolError(f"Tool '{tool_name}' is blocked by read-only policy") raise ToolError(f"Tool '{tool_name}' is blocked by read-only policy")
if policy == "workspace-write" and tool_name in shell_tools: if policy == "workspace-write" and tool_name in shell_tools:
raise ToolError(f"Tool '{tool_name}' is blocked by workspace-write policy") raise ToolError(f"Tool '{tool_name}' is blocked by workspace-write policy")
if policy not in {"full-access", "workspace-write", "read-only"}: if policy not in {
"full-access",
"workspace-write",
"read-only",
"ask-shell",
"ask-write",
"ask-all",
}:
raise ToolError(f"Unknown tool policy: {policy}") raise ToolError(f"Unknown tool policy: {policy}")
def requires_approval(self, tool_name: str) -> bool:
policy = self.config.tool_policy
write_tools = {"replace_in_file", "write_file", "make_directory"}
shell_tools = {"exec_command"}
if policy == "ask-all":
return True
if policy == "ask-shell":
return tool_name in shell_tools
if policy == "ask-write":
return tool_name in shell_tools or tool_name in write_tools
return False
def execute(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]: def execute(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
handler = self._handlers.get(name) handler = self._handlers.get(name)
if not handler: if not handler: