Persist server state and add tool policy
This commit is contained in:
parent
940bef2f4a
commit
ac7f1bd493
10
README.md
10
README.md
|
|
@ -35,6 +35,8 @@ Qwen OAuth + OpenAI-compatible endpoint
|
|||
- автоматический polling OAuth flow в боте
|
||||
- очередь сообщений, пришедших до завершения OAuth
|
||||
- job-based chat polling между `bot` и `serv`
|
||||
- persistence для chat jobs и pending OAuth flows на стороне `serv`
|
||||
- policy mode для инструментов: `full-access`, `workspace-write`, `read-only`
|
||||
|
||||
## Ограничения текущей реализации
|
||||
|
||||
|
|
@ -50,6 +52,14 @@ Qwen OAuth + OpenAI-compatible endpoint
|
|||
cp serv/.env.example serv/.env
|
||||
```
|
||||
|
||||
Ключевые параметры сервера:
|
||||
|
||||
- `NEW_QWEN_STATE_DIR` - где хранить jobs и pending OAuth flows
|
||||
- `NEW_QWEN_TOOL_POLICY` - режим инструментов:
|
||||
`full-access` - все инструменты
|
||||
`workspace-write` - без `exec_command`
|
||||
`read-only` - только чтение и поиск
|
||||
|
||||
Бот:
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ NEW_QWEN_PORT=8080
|
|||
NEW_QWEN_MODEL=qwen3.6-plus
|
||||
NEW_QWEN_WORKSPACE_ROOT=/home/mirivlad/git
|
||||
NEW_QWEN_SESSION_DIR=/home/mirivlad/git/new-qwen/.new-qwen/sessions
|
||||
NEW_QWEN_STATE_DIR=/home/mirivlad/git/new-qwen/.new-qwen/state
|
||||
NEW_QWEN_SYSTEM_PROMPT=
|
||||
NEW_QWEN_MAX_TOOL_ROUNDS=8
|
||||
NEW_QWEN_MAX_FILE_READ_BYTES=200000
|
||||
NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES=12000
|
||||
NEW_QWEN_TOOL_POLICY=full-access
|
||||
|
|
|
|||
64
serv/app.py
64
serv/app.py
|
|
@ -6,6 +6,7 @@ import time
|
|||
import uuid
|
||||
from http import HTTPStatus
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from config import ServerConfig
|
||||
|
|
@ -23,18 +24,72 @@ class AppState:
|
|||
self.sessions = SessionStore(config.session_dir)
|
||||
self.tools = ToolRegistry(config)
|
||||
self.agent = QwenAgent(config, self.oauth, self.tools)
|
||||
self.jobs = JobStore()
|
||||
self.pending_device_flows: dict[str, DeviceAuthState] = {}
|
||||
self.jobs = JobStore(config.state_dir / "jobs")
|
||||
self.pending_flows_path = config.state_dir / "oauth_flows.json"
|
||||
self.pending_device_flows: dict[str, DeviceAuthState] = self._load_pending_flows()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def _load_pending_flows(self) -> dict[str, DeviceAuthState]:
|
||||
self.config.state_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not self.pending_flows_path.exists():
|
||||
return {}
|
||||
try:
|
||||
payload = json.loads(self.pending_flows_path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {}
|
||||
flows: dict[str, DeviceAuthState] = {}
|
||||
now = time.time()
|
||||
for flow_id, value in payload.items():
|
||||
try:
|
||||
state = DeviceAuthState(
|
||||
device_code=value["device_code"],
|
||||
code_verifier=value["code_verifier"],
|
||||
user_code=value["user_code"],
|
||||
verification_uri=value["verification_uri"],
|
||||
verification_uri_complete=value["verification_uri_complete"],
|
||||
expires_at=float(value["expires_at"]),
|
||||
interval_seconds=float(value.get("interval_seconds", 2.0)),
|
||||
)
|
||||
except KeyError:
|
||||
continue
|
||||
if state.expires_at > now:
|
||||
flows[flow_id] = state
|
||||
return flows
|
||||
|
||||
def _save_pending_flows(self) -> None:
|
||||
self.config.state_dir.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
flow_id: {
|
||||
"device_code": state.device_code,
|
||||
"code_verifier": state.code_verifier,
|
||||
"user_code": state.user_code,
|
||||
"verification_uri": state.verification_uri,
|
||||
"verification_uri_complete": state.verification_uri_complete,
|
||||
"expires_at": state.expires_at,
|
||||
"interval_seconds": state.interval_seconds,
|
||||
}
|
||||
for flow_id, state in self.pending_device_flows.items()
|
||||
if state.expires_at > time.time()
|
||||
}
|
||||
self.pending_flows_path.write_text(
|
||||
json.dumps(payload, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def auth_status(self) -> dict[str, Any]:
|
||||
creds = self.oauth.load_credentials()
|
||||
if not creds:
|
||||
return {"authenticated": False}
|
||||
return {
|
||||
"authenticated": False,
|
||||
"tool_policy": self.config.tool_policy,
|
||||
"pending_flows": len(self.pending_device_flows),
|
||||
}
|
||||
return {
|
||||
"authenticated": True,
|
||||
"resource_url": creds.get("resource_url"),
|
||||
"expires_at": creds.get("expiry_date"),
|
||||
"tool_policy": self.config.tool_policy,
|
||||
"pending_flows": len(self.pending_device_flows),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -120,6 +175,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
state = self.app.oauth.start_device_flow(open_browser=False)
|
||||
with self.app.lock:
|
||||
self.app.pending_device_flows[flow_id] = state
|
||||
self.app._save_pending_flows()
|
||||
self._send(
|
||||
HTTPStatus.OK,
|
||||
{
|
||||
|
|
@ -147,6 +203,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
return
|
||||
with self.app.lock:
|
||||
self.app.pending_device_flows.pop(flow_id, None)
|
||||
self.app._save_pending_flows()
|
||||
self._send(HTTPStatus.OK, {"done": True, "credentials": {"resource_url": creds.get("resource_url")}})
|
||||
return
|
||||
|
||||
|
|
@ -255,6 +312,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
def main() -> None:
|
||||
config = ServerConfig.load()
|
||||
config.session_dir.mkdir(parents=True, exist_ok=True)
|
||||
config.state_dir.mkdir(parents=True, exist_ok=True)
|
||||
httpd = ThreadingHTTPServer((config.host, config.port), RequestHandler)
|
||||
httpd.app_state = AppState(config) # type: ignore[attr-defined]
|
||||
print(f"new-qwen serv listening on http://{config.host}:{config.port}")
|
||||
|
|
|
|||
|
|
@ -23,10 +23,12 @@ class ServerConfig:
|
|||
model: str
|
||||
workspace_root: Path
|
||||
session_dir: Path
|
||||
state_dir: Path
|
||||
system_prompt: str
|
||||
max_tool_rounds: int
|
||||
max_file_read_bytes: int
|
||||
max_command_output_bytes: int
|
||||
tool_policy: str
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> "ServerConfig":
|
||||
|
|
@ -41,12 +43,19 @@ class ServerConfig:
|
|||
str(base_dir.parent / ".new-qwen" / "sessions"),
|
||||
)
|
||||
).expanduser()
|
||||
state_dir = Path(
|
||||
os.environ.get(
|
||||
"NEW_QWEN_STATE_DIR",
|
||||
str(base_dir.parent / ".new-qwen" / "state"),
|
||||
)
|
||||
).expanduser()
|
||||
return cls(
|
||||
host=os.environ.get("NEW_QWEN_HOST", "127.0.0.1"),
|
||||
port=int(os.environ.get("NEW_QWEN_PORT", "8080")),
|
||||
model=os.environ.get("NEW_QWEN_MODEL", "qwen3.6-plus"),
|
||||
workspace_root=workspace_root.resolve(),
|
||||
session_dir=session_dir.resolve(),
|
||||
state_dir=state_dir.resolve(),
|
||||
system_prompt=os.environ.get("NEW_QWEN_SYSTEM_PROMPT", "").strip(),
|
||||
max_tool_rounds=int(os.environ.get("NEW_QWEN_MAX_TOOL_ROUNDS", "8")),
|
||||
max_file_read_bytes=int(
|
||||
|
|
@ -55,4 +64,5 @@ class ServerConfig:
|
|||
max_command_output_bytes=int(
|
||||
os.environ.get("NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES", "12000")
|
||||
),
|
||||
tool_policy=os.environ.get("NEW_QWEN_TOOL_POLICY", "full-access").strip(),
|
||||
)
|
||||
|
|
|
|||
34
serv/jobs.py
34
serv/jobs.py
|
|
@ -1,15 +1,41 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class JobStore:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, base_dir: Path) -> None:
|
||||
self.base_dir = base_dir
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._jobs: dict[str, dict[str, Any]] = {}
|
||||
self._lock = threading.RLock()
|
||||
self._load_existing()
|
||||
|
||||
def _path(self, job_id: str) -> Path:
|
||||
return self.base_dir / f"{job_id}.json"
|
||||
|
||||
def _save_job(self, job: dict[str, Any]) -> None:
|
||||
self._path(job["job_id"]).write_text(
|
||||
json.dumps(job, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def _load_existing(self) -> None:
|
||||
for path in sorted(self.base_dir.glob("*.json")):
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
continue
|
||||
if payload.get("status") in {"queued", "running"}:
|
||||
payload["status"] = "failed"
|
||||
payload["error"] = "Server restarted while job was running"
|
||||
payload["updated_at"] = time.time()
|
||||
self._jobs[payload["job_id"]] = payload
|
||||
|
||||
def create(self, session_id: str, user_id: str, message: str) -> dict[str, Any]:
|
||||
job_id = uuid.uuid4().hex
|
||||
|
|
@ -28,6 +54,7 @@ class JobStore:
|
|||
}
|
||||
with self._lock:
|
||||
self._jobs[job_id] = job
|
||||
self._save_job(job)
|
||||
return job
|
||||
|
||||
def get(self, job_id: str) -> dict[str, Any] | None:
|
||||
|
|
@ -46,12 +73,14 @@ class JobStore:
|
|||
seq = len(job["events"]) + 1
|
||||
job["events"].append({"seq": seq, **event})
|
||||
job["updated_at"] = time.time()
|
||||
self._save_job(job)
|
||||
|
||||
def set_status(self, job_id: str, status: str) -> None:
|
||||
with self._lock:
|
||||
job = self._jobs[job_id]
|
||||
job["status"] = status
|
||||
job["updated_at"] = time.time()
|
||||
self._save_job(job)
|
||||
|
||||
def finish(
|
||||
self,
|
||||
|
|
@ -66,6 +95,7 @@ class JobStore:
|
|||
job["answer"] = answer
|
||||
job["usage"] = usage
|
||||
job["updated_at"] = time.time()
|
||||
self._save_job(job)
|
||||
|
||||
def fail(self, job_id: str, error_message: str) -> None:
|
||||
with self._lock:
|
||||
|
|
@ -73,4 +103,4 @@ class JobStore:
|
|||
job["status"] = "failed"
|
||||
job["error"] = error_message
|
||||
job["updated_at"] = time.time()
|
||||
|
||||
self._save_job(job)
|
||||
|
|
|
|||
|
|
@ -179,10 +179,25 @@ class ToolRegistry:
|
|||
raise ToolError("Path escapes workspace root")
|
||||
return resolved
|
||||
|
||||
def _check_policy(self, tool_name: str) -> None:
|
||||
policy = self.config.tool_policy
|
||||
read_only_tools = {"list_files", "glob_search", "grep_text", "stat_path", "read_file"}
|
||||
write_tools = {"replace_in_file", "write_file", "make_directory"}
|
||||
shell_tools = {"exec_command"}
|
||||
if policy == "full-access":
|
||||
return
|
||||
if policy == "read-only" and tool_name not in read_only_tools:
|
||||
raise ToolError(f"Tool '{tool_name}' is blocked by read-only policy")
|
||||
if policy == "workspace-write" and tool_name in shell_tools:
|
||||
raise ToolError(f"Tool '{tool_name}' is blocked by workspace-write policy")
|
||||
if policy not in {"full-access", "workspace-write", "read-only"}:
|
||||
raise ToolError(f"Unknown tool policy: {policy}")
|
||||
|
||||
def execute(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
handler = self._handlers.get(name)
|
||||
if not handler:
|
||||
raise ToolError(f"Unknown tool: {name}")
|
||||
self._check_policy(name)
|
||||
return handler(arguments)
|
||||
|
||||
def list_files(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue