# Source listing header (from the viewer this file was copied from): 181 lines, 6.4 KiB, Python.
from __future__ import annotations

import copy
import json
import threading
import time
import uuid
from pathlib import Path
from typing import Any


class ApprovalStore:
|
|
def __init__(self, base_dir: Path, retention_seconds: int) -> None:
|
|
self.base_dir = base_dir
|
|
self.base_dir.mkdir(parents=True, exist_ok=True)
|
|
self.retention_seconds = retention_seconds
|
|
self._approvals: dict[str, dict[str, Any]] = {}
|
|
self._conditions: dict[str, threading.Condition] = {}
|
|
self._lock = threading.RLock()
|
|
self._load_existing()
|
|
self.cleanup()
|
|
|
|
def _path(self, approval_id: str) -> Path:
|
|
return self.base_dir / f"{approval_id}.json"
|
|
|
|
def _save(self, approval: dict[str, Any]) -> None:
|
|
self._path(approval["approval_id"]).write_text(
|
|
json.dumps(approval, ensure_ascii=False, indent=2),
|
|
encoding="utf-8",
|
|
)
|
|
|
|
def _load_existing(self) -> None:
|
|
now = time.time()
|
|
for path in sorted(self.base_dir.glob("*.json")):
|
|
try:
|
|
approval = json.loads(path.read_text(encoding="utf-8"))
|
|
except (OSError, json.JSONDecodeError):
|
|
continue
|
|
updated_at = float(approval.get("updated_at") or approval.get("created_at") or now)
|
|
if now - updated_at > self.retention_seconds:
|
|
try:
|
|
path.unlink()
|
|
except OSError:
|
|
pass
|
|
continue
|
|
if approval.get("status") == "pending":
|
|
approval["status"] = "rejected"
|
|
approval["reason"] = "Server restarted while waiting for approval"
|
|
approval["updated_at"] = time.time()
|
|
path.write_text(
|
|
json.dumps(approval, ensure_ascii=False, indent=2),
|
|
encoding="utf-8",
|
|
)
|
|
self._approvals[approval["approval_id"]] = approval
|
|
|
|
def create(
|
|
self,
|
|
*,
|
|
job_id: str,
|
|
tool_name: str,
|
|
arguments: dict[str, Any],
|
|
) -> dict[str, Any]:
|
|
approval_id = uuid.uuid4().hex
|
|
approval = {
|
|
"approval_id": approval_id,
|
|
"job_id": job_id,
|
|
"tool_name": tool_name,
|
|
"arguments": arguments,
|
|
"status": "pending",
|
|
"created_at": time.time(),
|
|
"updated_at": time.time(),
|
|
"actor": None,
|
|
"reason": None,
|
|
}
|
|
with self._lock:
|
|
self._approvals[approval_id] = approval
|
|
self._conditions[approval_id] = threading.Condition(self._lock)
|
|
self._save(approval)
|
|
return approval.copy()
|
|
|
|
def get(self, approval_id: str) -> dict[str, Any] | None:
|
|
with self._lock:
|
|
approval = self._approvals.get(approval_id)
|
|
return approval.copy() if approval else None
|
|
|
|
def list_pending(self) -> list[dict[str, Any]]:
|
|
with self._lock:
|
|
pending = [
|
|
approval.copy()
|
|
for approval in self._approvals.values()
|
|
if approval.get("status") == "pending"
|
|
]
|
|
pending.sort(key=lambda item: item.get("created_at", 0))
|
|
return pending
|
|
|
|
def respond(
|
|
self,
|
|
approval_id: str,
|
|
*,
|
|
approved: bool,
|
|
actor: str,
|
|
) -> dict[str, Any]:
|
|
with self._lock:
|
|
approval = self._approvals.get(approval_id)
|
|
if not approval:
|
|
raise KeyError("Unknown approval_id")
|
|
if approval["status"] != "pending":
|
|
return approval.copy()
|
|
approval["status"] = "approved" if approved else "rejected"
|
|
approval["actor"] = actor
|
|
approval["updated_at"] = time.time()
|
|
approval["reason"] = None if approved else "Rejected by operator"
|
|
self._save(approval)
|
|
condition = self._conditions.get(approval_id)
|
|
if condition:
|
|
condition.notify_all()
|
|
return approval.copy()
|
|
|
|
def wait(self, approval_id: str, timeout_seconds: float = 3600.0) -> dict[str, Any]:
|
|
with self._lock:
|
|
approval = self._approvals.get(approval_id)
|
|
if not approval:
|
|
raise KeyError("Unknown approval_id")
|
|
if approval["status"] != "pending":
|
|
return approval.copy()
|
|
condition = self._conditions.setdefault(
|
|
approval_id,
|
|
threading.Condition(self._lock),
|
|
)
|
|
deadline = time.time() + timeout_seconds
|
|
while approval["status"] == "pending":
|
|
remaining = deadline - time.time()
|
|
if remaining <= 0:
|
|
approval["status"] = "rejected"
|
|
approval["reason"] = "Approval timeout"
|
|
approval["updated_at"] = time.time()
|
|
self._save(approval)
|
|
break
|
|
condition.wait(timeout=remaining)
|
|
return approval.copy()
|
|
|
|
def reject_pending_for_job(
|
|
self,
|
|
job_id: str,
|
|
*,
|
|
actor: str,
|
|
reason: str,
|
|
) -> list[dict[str, Any]]:
|
|
rejected: list[dict[str, Any]] = []
|
|
with self._lock:
|
|
for approval_id, approval in self._approvals.items():
|
|
if approval.get("job_id") != job_id or approval.get("status") != "pending":
|
|
continue
|
|
approval["status"] = "rejected"
|
|
approval["actor"] = actor
|
|
approval["reason"] = reason
|
|
approval["updated_at"] = time.time()
|
|
self._save(approval)
|
|
condition = self._conditions.get(approval_id)
|
|
if condition:
|
|
condition.notify_all()
|
|
rejected.append(approval.copy())
|
|
return rejected
|
|
|
|
def cleanup(self) -> None:
|
|
now = time.time()
|
|
with self._lock:
|
|
expired = [
|
|
approval_id
|
|
for approval_id, approval in self._approvals.items()
|
|
if now
|
|
- float(approval.get("updated_at") or approval.get("created_at") or now)
|
|
> self.retention_seconds
|
|
]
|
|
for approval_id in expired:
|
|
self._approvals.pop(approval_id, None)
|
|
self._conditions.pop(approval_id, None)
|
|
try:
|
|
self._path(approval_id).unlink()
|
|
except OSError:
|
|
pass
|