# new-qwen/serv/approvals.py

from __future__ import annotations
import json
import threading
import time
import uuid
from pathlib import Path
from typing import Any
class ApprovalStore:
def __init__(self, base_dir: Path, retention_seconds: int) -> None:
self.base_dir = base_dir
self.base_dir.mkdir(parents=True, exist_ok=True)
self.retention_seconds = retention_seconds
self._approvals: dict[str, dict[str, Any]] = {}
self._conditions: dict[str, threading.Condition] = {}
self._lock = threading.RLock()
self._load_existing()
self.cleanup()
def _path(self, approval_id: str) -> Path:
return self.base_dir / f"{approval_id}.json"
def _save(self, approval: dict[str, Any]) -> None:
self._path(approval["approval_id"]).write_text(
json.dumps(approval, ensure_ascii=False, indent=2),
encoding="utf-8",
)
def _load_existing(self) -> None:
now = time.time()
for path in sorted(self.base_dir.glob("*.json")):
try:
approval = json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
continue
updated_at = float(approval.get("updated_at") or approval.get("created_at") or now)
if now - updated_at > self.retention_seconds:
try:
path.unlink()
except OSError:
pass
continue
if approval.get("status") == "pending":
approval["status"] = "rejected"
approval["reason"] = "Server restarted while waiting for approval"
approval["updated_at"] = time.time()
path.write_text(
json.dumps(approval, ensure_ascii=False, indent=2),
encoding="utf-8",
)
self._approvals[approval["approval_id"]] = approval
def create(
self,
*,
job_id: str,
tool_name: str,
arguments: dict[str, Any],
) -> dict[str, Any]:
approval_id = uuid.uuid4().hex
approval = {
"approval_id": approval_id,
"job_id": job_id,
"tool_name": tool_name,
"arguments": arguments,
"status": "pending",
"created_at": time.time(),
"updated_at": time.time(),
"actor": None,
"reason": None,
}
with self._lock:
self._approvals[approval_id] = approval
self._conditions[approval_id] = threading.Condition(self._lock)
self._save(approval)
return approval.copy()
def get(self, approval_id: str) -> dict[str, Any] | None:
with self._lock:
approval = self._approvals.get(approval_id)
return approval.copy() if approval else None
def list_pending(self) -> list[dict[str, Any]]:
with self._lock:
pending = [
approval.copy()
for approval in self._approvals.values()
if approval.get("status") == "pending"
]
pending.sort(key=lambda item: item.get("created_at", 0))
return pending
def respond(
self,
approval_id: str,
*,
approved: bool,
actor: str,
) -> dict[str, Any]:
with self._lock:
approval = self._approvals.get(approval_id)
if not approval:
raise KeyError("Unknown approval_id")
if approval["status"] != "pending":
return approval.copy()
approval["status"] = "approved" if approved else "rejected"
approval["actor"] = actor
approval["updated_at"] = time.time()
approval["reason"] = None if approved else "Rejected by operator"
self._save(approval)
condition = self._conditions.get(approval_id)
if condition:
condition.notify_all()
return approval.copy()
def wait(self, approval_id: str, timeout_seconds: float = 3600.0) -> dict[str, Any]:
with self._lock:
approval = self._approvals.get(approval_id)
if not approval:
raise KeyError("Unknown approval_id")
if approval["status"] != "pending":
return approval.copy()
condition = self._conditions.setdefault(
approval_id,
threading.Condition(self._lock),
)
deadline = time.time() + timeout_seconds
while approval["status"] == "pending":
remaining = deadline - time.time()
if remaining <= 0:
approval["status"] = "rejected"
approval["reason"] = "Approval timeout"
approval["updated_at"] = time.time()
self._save(approval)
break
condition.wait(timeout=remaining)
return approval.copy()
def reject_pending_for_job(
self,
job_id: str,
*,
actor: str,
reason: str,
) -> list[dict[str, Any]]:
rejected: list[dict[str, Any]] = []
with self._lock:
for approval_id, approval in self._approvals.items():
if approval.get("job_id") != job_id or approval.get("status") != "pending":
continue
approval["status"] = "rejected"
approval["actor"] = actor
approval["reason"] = reason
approval["updated_at"] = time.time()
self._save(approval)
condition = self._conditions.get(approval_id)
if condition:
condition.notify_all()
rejected.append(approval.copy())
return rejected
def cleanup(self) -> None:
now = time.time()
with self._lock:
expired = [
approval_id
for approval_id, approval in self._approvals.items()
if now
- float(approval.get("updated_at") or approval.get("created_at") or now)
> self.retention_seconds
]
for approval_id in expired:
self._approvals.pop(approval_id, None)
self._conditions.pop(approval_id, None)
try:
self._path(approval_id).unlink()
except OSError:
pass