Add live tool approvals via Telegram
This commit is contained in:
parent
ac7f1bd493
commit
aa3154e9d7
16
README.md
16
README.md
|
|
@ -37,6 +37,7 @@ Qwen OAuth + OpenAI-compatible endpoint
|
|||
- job-based chat polling между `bot` и `serv`
|
||||
- persistence для chat jobs и pending OAuth flows на стороне `serv`
|
||||
- policy mode для инструментов: `full-access`, `workspace-write`, `read-only`
|
||||
- live approval flow для инструментов через Telegram
|
||||
|
||||
## Ограничения текущей реализации
|
||||
|
||||
|
|
@ -59,6 +60,10 @@ cp serv/.env.example serv/.env
|
|||
`full-access` - все инструменты
|
||||
`workspace-write` - без `exec_command`
|
||||
`read-only` - только чтение и поиск
|
||||
`ask-shell` - shell только после подтверждения
|
||||
`ask-write` - shell и записи только после подтверждения
|
||||
`ask-all` - любой инструмент только после подтверждения
|
||||
- `NEW_QWEN_APPROVAL_TIMEOUT_SECONDS` - сколько сервер ждёт решения по approval
|
||||
|
||||
Бот:
|
||||
|
||||
|
|
@ -101,8 +106,19 @@ curl -X POST http://127.0.0.1:8080/api/v1/auth/device/start
|
|||
- `POST /api/v1/auth/device/start`
|
||||
- `POST /api/v1/auth/device/poll`
|
||||
- `GET /api/v1/sessions`
|
||||
- `GET /api/v1/approvals`
|
||||
- `POST /api/v1/session/get`
|
||||
- `POST /api/v1/session/clear`
|
||||
- `POST /api/v1/chat`
|
||||
- `POST /api/v1/chat/start`
|
||||
- `POST /api/v1/chat/poll`
|
||||
- `POST /api/v1/approval/respond`
|
||||
|
||||
## Telegram Approval Flow
|
||||
|
||||
Если политика инструментов настроена как `ask-shell`, `ask-write` или `ask-all`, бот пришлёт запрос на подтверждение с `approval_id`.
|
||||
|
||||
Дальше можно ответить одной из команд:
|
||||
|
||||
- `/approve <approval_id>`
|
||||
- `/reject <approval_id>`
|
||||
|
|
|
|||
156
bot/app.py
156
bot/app.py
|
|
@ -15,10 +15,18 @@ STATE_FILE = Path(__file__).resolve().parent.parent / ".new-qwen" / "telegram-st
|
|||
|
||||
def load_state() -> dict[str, Any]:
|
||||
if not STATE_FILE.exists():
|
||||
return {"offset": None, "sessions": {}, "auth_flows": {}}
|
||||
return {
|
||||
"offset": None,
|
||||
"sessions": {},
|
||||
"auth_flows": {},
|
||||
"active_jobs": {},
|
||||
"pending_approvals": {},
|
||||
}
|
||||
state = json.loads(STATE_FILE.read_text(encoding="utf-8"))
|
||||
state.setdefault("sessions", {})
|
||||
state.setdefault("auth_flows", {})
|
||||
state.setdefault("active_jobs", {})
|
||||
state.setdefault("pending_approvals", {})
|
||||
return state
|
||||
|
||||
|
||||
|
|
@ -66,9 +74,26 @@ def summarize_event(event: dict[str, Any]) -> str | None:
|
|||
return f"Инструмент {event.get('name')} завершён"
|
||||
if event_type == "error":
|
||||
return f"Ошибка: {event.get('message')}"
|
||||
if event_type == "approval_result":
|
||||
status = event.get("status")
|
||||
tool_name = event.get("tool_name")
|
||||
if status == "approved":
|
||||
return f"Подтверждение получено для {tool_name}"
|
||||
return f"Подтверждение отклонено для {tool_name}"
|
||||
return None
|
||||
|
||||
|
||||
def format_approval_request(event: dict[str, Any]) -> str:
|
||||
return (
|
||||
"Нужно подтверждение инструмента.\n"
|
||||
f"approval_id: {event.get('approval_id')}\n"
|
||||
f"tool: {event.get('tool_name')}\n"
|
||||
f"args: {json.dumps(event.get('arguments', {}), ensure_ascii=False)}\n\n"
|
||||
f"/approve {event.get('approval_id')}\n"
|
||||
f"/reject {event.get('approval_id')}"
|
||||
)
|
||||
|
||||
|
||||
def get_auth_flow(state: dict[str, Any], chat_id: int) -> dict[str, Any] | None:
|
||||
return state.setdefault("auth_flows", {}).get(str(chat_id))
|
||||
|
||||
|
|
@ -145,7 +170,7 @@ def enqueue_pending_message(
|
|||
)
|
||||
|
||||
|
||||
def deliver_chat_message(
|
||||
def start_chat_job(
|
||||
api: TelegramAPI,
|
||||
config: BotConfig,
|
||||
state: dict[str, Any],
|
||||
|
|
@ -168,31 +193,15 @@ def deliver_chat_message(
|
|||
},
|
||||
)
|
||||
state["sessions"][session_key] = start_result["session_id"]
|
||||
job_id = start_result["job_id"]
|
||||
seen_seq = 0
|
||||
sent_statuses: set[str] = set()
|
||||
answer = None
|
||||
while True:
|
||||
poll_result = post_json(
|
||||
f"{config.server_url}/api/v1/chat/poll",
|
||||
{"job_id": job_id, "since_seq": seen_seq},
|
||||
)
|
||||
for event in poll_result.get("events", []):
|
||||
seen_seq = max(seen_seq, int(event.get("seq", 0)))
|
||||
summary = summarize_event(event)
|
||||
if summary and summary not in sent_statuses:
|
||||
api.send_message(chat_id, summary[:4000])
|
||||
sent_statuses.add(summary)
|
||||
if poll_result.get("status") == "completed":
|
||||
answer = poll_result.get("answer")
|
||||
state["sessions"][session_key] = poll_result["session_id"]
|
||||
break
|
||||
if poll_result.get("status") == "failed":
|
||||
raise RuntimeError(poll_result.get("error") or "Chat job failed")
|
||||
time.sleep(1.2)
|
||||
|
||||
answer = answer or "Пустой ответ от модели."
|
||||
send_text_chunks(api, chat_id, answer)
|
||||
state.setdefault("active_jobs", {})[start_result["job_id"]] = {
|
||||
"job_id": start_result["job_id"],
|
||||
"chat_id": chat_id,
|
||||
"user_id": user_id,
|
||||
"session_key": session_key,
|
||||
"session_id": start_result["session_id"],
|
||||
"seen_seq": 0,
|
||||
"sent_statuses": [],
|
||||
}
|
||||
|
||||
|
||||
def poll_auth_flow(
|
||||
|
|
@ -240,7 +249,7 @@ def poll_auth_flow(
|
|||
state["auth_flows"].pop(str(chat_id), None)
|
||||
api.send_message(chat_id, "Qwen OAuth успешно настроен.")
|
||||
for item in flow.get("pending_messages", []):
|
||||
deliver_chat_message(
|
||||
start_chat_job(
|
||||
api,
|
||||
config,
|
||||
state,
|
||||
|
|
@ -267,6 +276,58 @@ def process_auth_flows(
|
|||
print(f"auth flow poll error for chat {chat_id_raw}: {exc}")
|
||||
|
||||
|
||||
def process_active_jobs(
|
||||
api: TelegramAPI,
|
||||
config: BotConfig,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
active_jobs = state.setdefault("active_jobs", {})
|
||||
pending_approvals = state.setdefault("pending_approvals", {})
|
||||
for job_id in list(active_jobs.keys()):
|
||||
job_state = active_jobs[job_id]
|
||||
poll_result = post_json(
|
||||
f"{config.server_url}/api/v1/chat/poll",
|
||||
{"job_id": job_id, "since_seq": job_state.get("seen_seq", 0)},
|
||||
)
|
||||
for event in poll_result.get("events", []):
|
||||
seq = int(event.get("seq", 0))
|
||||
job_state["seen_seq"] = max(job_state.get("seen_seq", 0), seq)
|
||||
if event.get("type") == "approval_request":
|
||||
pending_approvals[str(job_state["chat_id"])] = {
|
||||
"approval_id": event["approval_id"],
|
||||
"job_id": job_id,
|
||||
}
|
||||
send_text_chunks(
|
||||
api,
|
||||
int(job_state["chat_id"]),
|
||||
format_approval_request(event),
|
||||
)
|
||||
continue
|
||||
summary = summarize_event(event)
|
||||
sent_statuses = set(job_state.get("sent_statuses", []))
|
||||
if summary and summary not in sent_statuses:
|
||||
api.send_message(int(job_state["chat_id"]), summary[:4000])
|
||||
sent_statuses.add(summary)
|
||||
job_state["sent_statuses"] = sorted(sent_statuses)
|
||||
|
||||
status = poll_result.get("status")
|
||||
if status == "completed":
|
||||
state["sessions"][job_state["session_key"]] = poll_result["session_id"]
|
||||
send_text_chunks(
|
||||
api,
|
||||
int(job_state["chat_id"]),
|
||||
poll_result.get("answer") or "Пустой ответ от модели.",
|
||||
)
|
||||
active_jobs.pop(job_id, None)
|
||||
elif status == "failed":
|
||||
send_text_chunks(
|
||||
api,
|
||||
int(job_state["chat_id"]),
|
||||
f"Job завершился с ошибкой: {poll_result.get('error')}",
|
||||
)
|
||||
active_jobs.pop(job_id, None)
|
||||
|
||||
|
||||
def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], message: dict[str, Any]) -> None:
|
||||
chat_id = message["chat"]["id"]
|
||||
user_id = str(message.get("from", {}).get("id", chat_id))
|
||||
|
|
@ -282,7 +343,7 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
|||
if text == "/start":
|
||||
api.send_message(
|
||||
chat_id,
|
||||
"new-qwen bot готов.\nКоманды: /help, /auth, /status, /session, /clear.",
|
||||
"new-qwen bot готов.\nКоманды: /help, /auth, /status, /session, /clear, /approve, /reject.",
|
||||
)
|
||||
return
|
||||
|
||||
|
|
@ -294,10 +355,35 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
|||
"/auth_check [flow_id] - проверить авторизацию\n"
|
||||
"/status - статус OAuth и сервера\n"
|
||||
"/session - показать текущую сессию\n"
|
||||
"/approve [approval_id] - подтвердить инструмент\n"
|
||||
"/reject [approval_id] - отклонить инструмент\n"
|
||||
"/clear - очистить контекст",
|
||||
)
|
||||
return
|
||||
|
||||
if text.startswith("/approve") or text.startswith("/reject"):
|
||||
parts = text.split(maxsplit=1)
|
||||
approval = state.setdefault("pending_approvals", {}).get(str(chat_id))
|
||||
approval_id = parts[1] if len(parts) == 2 else approval.get("approval_id") if approval else None
|
||||
if not approval_id:
|
||||
api.send_message(chat_id, "Нет pending approval для этого чата.")
|
||||
return
|
||||
response = post_json(
|
||||
f"{config.server_url}/api/v1/approval/respond",
|
||||
{
|
||||
"approval_id": approval_id,
|
||||
"approved": text.startswith("/approve"),
|
||||
"actor": user_id,
|
||||
},
|
||||
)
|
||||
if response.get("status") != "pending":
|
||||
state["pending_approvals"].pop(str(chat_id), None)
|
||||
api.send_message(
|
||||
chat_id,
|
||||
f"Approval {approval_id}: {response.get('status')}",
|
||||
)
|
||||
return
|
||||
|
||||
if text == "/auth":
|
||||
start_auth_flow(api, config, state, chat_id, force_new=True)
|
||||
return
|
||||
|
|
@ -324,7 +410,9 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
|||
"Сервер доступен.\n"
|
||||
f"OAuth: {'configured' if status.get('authenticated') else 'not configured'}\n"
|
||||
f"resource_url: {status.get('resource_url')}\n"
|
||||
f"expires_at: {status.get('expires_at')}",
|
||||
f"expires_at: {status.get('expires_at')}\n"
|
||||
f"tool_policy: {status.get('tool_policy')}\n"
|
||||
f"pending_approvals: {status.get('pending_approvals')}",
|
||||
)
|
||||
return
|
||||
|
||||
|
|
@ -359,7 +447,7 @@ def handle_message(api: TelegramAPI, config: BotConfig, state: dict[str, Any], m
|
|||
api.send_message(chat_id, "Сообщение поставлено в очередь до завершения авторизации.")
|
||||
return
|
||||
|
||||
deliver_chat_message(api, config, state, chat_id, user_id, session_key, text)
|
||||
start_chat_job(api, config, state, chat_id, user_id, session_key, text)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
@ -370,7 +458,11 @@ def main() -> None:
|
|||
while True:
|
||||
try:
|
||||
process_auth_flows(api, config, state)
|
||||
updates = api.get_updates(state.get("offset"), config.poll_timeout)
|
||||
process_active_jobs(api, config, state)
|
||||
timeout = config.poll_timeout
|
||||
if state.get("active_jobs"):
|
||||
timeout = min(timeout, 3)
|
||||
updates = api.get_updates(state.get("offset"), timeout)
|
||||
for update in updates:
|
||||
state["offset"] = update["update_id"] + 1
|
||||
message = update.get("message")
|
||||
|
|
|
|||
|
|
@ -9,3 +9,4 @@ NEW_QWEN_MAX_TOOL_ROUNDS=8
|
|||
NEW_QWEN_MAX_FILE_READ_BYTES=200000
|
||||
NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES=12000
|
||||
NEW_QWEN_TOOL_POLICY=full-access
|
||||
NEW_QWEN_APPROVAL_TIMEOUT_SECONDS=3600
|
||||
|
|
|
|||
53
serv/app.py
53
serv/app.py
|
|
@ -10,6 +10,7 @@ from pathlib import Path
|
|||
from typing import Any
|
||||
|
||||
from config import ServerConfig
|
||||
from approvals import ApprovalStore
|
||||
from jobs import JobStore
|
||||
from llm import QwenAgent
|
||||
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
|
||||
|
|
@ -25,6 +26,7 @@ class AppState:
|
|||
self.tools = ToolRegistry(config)
|
||||
self.agent = QwenAgent(config, self.oauth, self.tools)
|
||||
self.jobs = JobStore(config.state_dir / "jobs")
|
||||
self.approvals = ApprovalStore(config.state_dir / "approvals")
|
||||
self.pending_flows_path = config.state_dir / "oauth_flows.json"
|
||||
self.pending_device_flows: dict[str, DeviceAuthState] = self._load_pending_flows()
|
||||
self.lock = threading.Lock()
|
||||
|
|
@ -83,6 +85,7 @@ class AppState:
|
|||
"authenticated": False,
|
||||
"tool_policy": self.config.tool_policy,
|
||||
"pending_flows": len(self.pending_device_flows),
|
||||
"pending_approvals": len(self.approvals.list_pending()),
|
||||
}
|
||||
return {
|
||||
"authenticated": True,
|
||||
|
|
@ -90,6 +93,7 @@ class AppState:
|
|||
"expires_at": creds.get("expiry_date"),
|
||||
"tool_policy": self.config.tool_policy,
|
||||
"pending_flows": len(self.pending_device_flows),
|
||||
"pending_approvals": len(self.approvals.list_pending()),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -125,6 +129,9 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
if self.path == "/api/v1/sessions":
|
||||
self._send(HTTPStatus.OK, {"sessions": self.app.sessions.list_sessions()})
|
||||
return
|
||||
if self.path == "/api/v1/approvals":
|
||||
self._send(HTTPStatus.OK, {"approvals": self.app.approvals.list_pending()})
|
||||
return
|
||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
||||
|
||||
def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None:
|
||||
|
|
@ -140,6 +147,11 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
history,
|
||||
message,
|
||||
on_event=lambda event: self.app.jobs.append_event(job_id, event),
|
||||
approval_callback=lambda tool_name, arguments: self._request_approval(
|
||||
job_id,
|
||||
tool_name,
|
||||
arguments,
|
||||
),
|
||||
)
|
||||
persisted_messages = result["messages"][1:]
|
||||
self.app.sessions.save(
|
||||
|
|
@ -168,6 +180,34 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
)
|
||||
self.app.jobs.fail(job_id, str(exc))
|
||||
|
||||
def _request_approval(
|
||||
self,
|
||||
job_id: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
approval = self.app.approvals.create(
|
||||
job_id=job_id,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
self.app.jobs.append_event(
|
||||
job_id,
|
||||
{
|
||||
"type": "approval_request",
|
||||
"approval_id": approval["approval_id"],
|
||||
"tool_name": tool_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
)
|
||||
self.app.jobs.set_status(job_id, "waiting_approval")
|
||||
decision = self.app.approvals.wait(
|
||||
approval["approval_id"],
|
||||
timeout_seconds=float(self.app.config.approval_timeout_seconds),
|
||||
)
|
||||
self.app.jobs.set_status(job_id, "running")
|
||||
return decision
|
||||
|
||||
def do_POST(self) -> None:
|
||||
try:
|
||||
if self.path == "/api/v1/auth/device/start":
|
||||
|
|
@ -284,6 +324,19 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
)
|
||||
return
|
||||
|
||||
if self.path == "/api/v1/approval/respond":
|
||||
body = self._json_body()
|
||||
approval_id = body["approval_id"]
|
||||
approved = bool(body["approved"])
|
||||
actor = str(body.get("actor") or "unknown")
|
||||
approval = self.app.approvals.respond(
|
||||
approval_id,
|
||||
approved=approved,
|
||||
actor=actor,
|
||||
)
|
||||
self._send(HTTPStatus.OK, approval)
|
||||
return
|
||||
|
||||
if self.path == "/api/v1/session/get":
|
||||
body = self._json_body()
|
||||
session_id = body["session_id"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,130 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ApprovalStore:
|
||||
def __init__(self, base_dir: Path) -> None:
|
||||
self.base_dir = base_dir
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._approvals: dict[str, dict[str, Any]] = {}
|
||||
self._conditions: dict[str, threading.Condition] = {}
|
||||
self._lock = threading.RLock()
|
||||
self._load_existing()
|
||||
|
||||
def _path(self, approval_id: str) -> Path:
|
||||
return self.base_dir / f"{approval_id}.json"
|
||||
|
||||
def _save(self, approval: dict[str, Any]) -> None:
|
||||
self._path(approval["approval_id"]).write_text(
|
||||
json.dumps(approval, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def _load_existing(self) -> None:
|
||||
for path in sorted(self.base_dir.glob("*.json")):
|
||||
try:
|
||||
approval = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
continue
|
||||
if approval.get("status") == "pending":
|
||||
approval["status"] = "rejected"
|
||||
approval["reason"] = "Server restarted while waiting for approval"
|
||||
approval["updated_at"] = time.time()
|
||||
path.write_text(
|
||||
json.dumps(approval, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
self._approvals[approval["approval_id"]] = approval
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
job_id: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
approval_id = uuid.uuid4().hex
|
||||
approval = {
|
||||
"approval_id": approval_id,
|
||||
"job_id": job_id,
|
||||
"tool_name": tool_name,
|
||||
"arguments": arguments,
|
||||
"status": "pending",
|
||||
"created_at": time.time(),
|
||||
"updated_at": time.time(),
|
||||
"actor": None,
|
||||
"reason": None,
|
||||
}
|
||||
with self._lock:
|
||||
self._approvals[approval_id] = approval
|
||||
self._conditions[approval_id] = threading.Condition(self._lock)
|
||||
self._save(approval)
|
||||
return approval.copy()
|
||||
|
||||
def get(self, approval_id: str) -> dict[str, Any] | None:
|
||||
with self._lock:
|
||||
approval = self._approvals.get(approval_id)
|
||||
return approval.copy() if approval else None
|
||||
|
||||
def list_pending(self) -> list[dict[str, Any]]:
|
||||
with self._lock:
|
||||
pending = [
|
||||
approval.copy()
|
||||
for approval in self._approvals.values()
|
||||
if approval.get("status") == "pending"
|
||||
]
|
||||
pending.sort(key=lambda item: item.get("created_at", 0))
|
||||
return pending
|
||||
|
||||
def respond(
|
||||
self,
|
||||
approval_id: str,
|
||||
*,
|
||||
approved: bool,
|
||||
actor: str,
|
||||
) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
approval = self._approvals.get(approval_id)
|
||||
if not approval:
|
||||
raise KeyError("Unknown approval_id")
|
||||
if approval["status"] != "pending":
|
||||
return approval.copy()
|
||||
approval["status"] = "approved" if approved else "rejected"
|
||||
approval["actor"] = actor
|
||||
approval["updated_at"] = time.time()
|
||||
approval["reason"] = None if approved else "Rejected by operator"
|
||||
self._save(approval)
|
||||
condition = self._conditions.get(approval_id)
|
||||
if condition:
|
||||
condition.notify_all()
|
||||
return approval.copy()
|
||||
|
||||
def wait(self, approval_id: str, timeout_seconds: float = 3600.0) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
approval = self._approvals.get(approval_id)
|
||||
if not approval:
|
||||
raise KeyError("Unknown approval_id")
|
||||
if approval["status"] != "pending":
|
||||
return approval.copy()
|
||||
condition = self._conditions.setdefault(
|
||||
approval_id,
|
||||
threading.Condition(self._lock),
|
||||
)
|
||||
deadline = time.time() + timeout_seconds
|
||||
while approval["status"] == "pending":
|
||||
remaining = deadline - time.time()
|
||||
if remaining <= 0:
|
||||
approval["status"] = "rejected"
|
||||
approval["reason"] = "Approval timeout"
|
||||
approval["updated_at"] = time.time()
|
||||
self._save(approval)
|
||||
break
|
||||
condition.wait(timeout=remaining)
|
||||
return approval.copy()
|
||||
|
||||
|
|
@ -29,6 +29,7 @@ class ServerConfig:
|
|||
max_file_read_bytes: int
|
||||
max_command_output_bytes: int
|
||||
tool_policy: str
|
||||
approval_timeout_seconds: int
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> "ServerConfig":
|
||||
|
|
@ -65,4 +66,7 @@ class ServerConfig:
|
|||
os.environ.get("NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES", "12000")
|
||||
),
|
||||
tool_policy=os.environ.get("NEW_QWEN_TOOL_POLICY", "full-access").strip(),
|
||||
approval_timeout_seconds=int(
|
||||
os.environ.get("NEW_QWEN_APPROVAL_TIMEOUT_SECONDS", "3600")
|
||||
),
|
||||
)
|
||||
|
|
|
|||
25
serv/llm.py
25
serv/llm.py
|
|
@ -54,6 +54,7 @@ class QwenAgent:
|
|||
history: list[dict[str, Any]],
|
||||
user_message: str,
|
||||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
||||
approval_callback: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
emit = on_event or (lambda _event: None)
|
||||
system_prompt = self.config.system_prompt or DEFAULT_SYSTEM_PROMPT
|
||||
|
|
@ -99,6 +100,30 @@ class QwenAgent:
|
|||
tool_call_event = {"type": "tool_call", "name": tool_name, "arguments": arguments}
|
||||
events.append(tool_call_event)
|
||||
emit(tool_call_event)
|
||||
if approval_callback and self.tools.requires_approval(tool_name):
|
||||
approval_result = approval_callback(tool_name, arguments)
|
||||
approval_event = {
|
||||
"type": "approval_result",
|
||||
"tool_name": tool_name,
|
||||
"approval_id": approval_result["approval_id"],
|
||||
"status": approval_result["status"],
|
||||
"actor": approval_result.get("actor"),
|
||||
}
|
||||
events.append(approval_event)
|
||||
emit(approval_event)
|
||||
if approval_result["status"] != "approved":
|
||||
result = {"error": f"Tool '{tool_name}' was rejected by operator"}
|
||||
tool_result_event = {"type": "tool_result", "name": tool_name, "result": result}
|
||||
events.append(tool_result_event)
|
||||
emit(tool_result_event)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": call["id"],
|
||||
"content": self.tools.encode_result(result),
|
||||
}
|
||||
)
|
||||
continue
|
||||
try:
|
||||
result = self.tools.execute(tool_name, arguments)
|
||||
except ToolError as exc:
|
||||
|
|
|
|||
|
|
@ -182,17 +182,35 @@ class ToolRegistry:
|
|||
def _check_policy(self, tool_name: str) -> None:
|
||||
policy = self.config.tool_policy
|
||||
read_only_tools = {"list_files", "glob_search", "grep_text", "stat_path", "read_file"}
|
||||
write_tools = {"replace_in_file", "write_file", "make_directory"}
|
||||
shell_tools = {"exec_command"}
|
||||
if policy == "full-access":
|
||||
if policy in {"full-access", "ask-shell", "ask-write", "ask-all"}:
|
||||
return
|
||||
if policy == "read-only" and tool_name not in read_only_tools:
|
||||
raise ToolError(f"Tool '{tool_name}' is blocked by read-only policy")
|
||||
if policy == "workspace-write" and tool_name in shell_tools:
|
||||
raise ToolError(f"Tool '{tool_name}' is blocked by workspace-write policy")
|
||||
if policy not in {"full-access", "workspace-write", "read-only"}:
|
||||
if policy not in {
|
||||
"full-access",
|
||||
"workspace-write",
|
||||
"read-only",
|
||||
"ask-shell",
|
||||
"ask-write",
|
||||
"ask-all",
|
||||
}:
|
||||
raise ToolError(f"Unknown tool policy: {policy}")
|
||||
|
||||
def requires_approval(self, tool_name: str) -> bool:
|
||||
policy = self.config.tool_policy
|
||||
write_tools = {"replace_in_file", "write_file", "make_directory"}
|
||||
shell_tools = {"exec_command"}
|
||||
if policy == "ask-all":
|
||||
return True
|
||||
if policy == "ask-shell":
|
||||
return tool_name in shell_tools
|
||||
if policy == "ask-write":
|
||||
return tool_name in shell_tools or tool_name in write_tools
|
||||
return False
|
||||
|
||||
def execute(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
handler = self._handlers.get(name)
|
||||
if not handler:
|
||||
|
|
|
|||
Loading…
Reference in New Issue