Persist server state and add tool policy

This commit is contained in:
mirivlad 2026-04-07 17:11:29 +08:00
parent 940bef2f4a
commit ac7f1bd493
6 changed files with 130 additions and 5 deletions

View File

@ -35,6 +35,8 @@ Qwen OAuth + OpenAI-compatible endpoint
- автоматический polling OAuth flow в боте - автоматический polling OAuth flow в боте
- очередь сообщений, пришедших до завершения OAuth - очередь сообщений, пришедших до завершения OAuth
- job-based chat polling между `bot` и `serv` - job-based chat polling между `bot` и `serv`
- persistence для chat jobs и pending OAuth flows на стороне `serv`
- policy mode для инструментов: `full-access`, `workspace-write`, `read-only`
## Ограничения текущей реализации ## Ограничения текущей реализации
@ -50,6 +52,14 @@ Qwen OAuth + OpenAI-compatible endpoint
cp serv/.env.example serv/.env cp serv/.env.example serv/.env
``` ```
Ключевые параметры сервера:
- `NEW_QWEN_STATE_DIR` - где хранить jobs и pending OAuth flows
- `NEW_QWEN_TOOL_POLICY` - режим инструментов:
  - `full-access` - все инструменты
  - `workspace-write` - все инструменты, кроме `exec_command`
  - `read-only` - только чтение и поиск
Бот: Бот:
```bash ```bash

View File

@ -3,7 +3,9 @@ NEW_QWEN_PORT=8080
NEW_QWEN_MODEL=qwen3.6-plus NEW_QWEN_MODEL=qwen3.6-plus
NEW_QWEN_WORKSPACE_ROOT=/home/mirivlad/git NEW_QWEN_WORKSPACE_ROOT=/home/mirivlad/git
NEW_QWEN_SESSION_DIR=/home/mirivlad/git/new-qwen/.new-qwen/sessions NEW_QWEN_SESSION_DIR=/home/mirivlad/git/new-qwen/.new-qwen/sessions
NEW_QWEN_STATE_DIR=/home/mirivlad/git/new-qwen/.new-qwen/state
NEW_QWEN_SYSTEM_PROMPT= NEW_QWEN_SYSTEM_PROMPT=
NEW_QWEN_MAX_TOOL_ROUNDS=8 NEW_QWEN_MAX_TOOL_ROUNDS=8
NEW_QWEN_MAX_FILE_READ_BYTES=200000 NEW_QWEN_MAX_FILE_READ_BYTES=200000
NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES=12000 NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES=12000
NEW_QWEN_TOOL_POLICY=full-access

View File

@ -6,6 +6,7 @@ import time
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any from typing import Any
from config import ServerConfig from config import ServerConfig
@ -23,18 +24,72 @@ class AppState:
self.sessions = SessionStore(config.session_dir) self.sessions = SessionStore(config.session_dir)
self.tools = ToolRegistry(config) self.tools = ToolRegistry(config)
self.agent = QwenAgent(config, self.oauth, self.tools) self.agent = QwenAgent(config, self.oauth, self.tools)
self.jobs = JobStore() self.jobs = JobStore(config.state_dir / "jobs")
self.pending_device_flows: dict[str, DeviceAuthState] = {} self.pending_flows_path = config.state_dir / "oauth_flows.json"
self.pending_device_flows: dict[str, DeviceAuthState] = self._load_pending_flows()
self.lock = threading.Lock() self.lock = threading.Lock()
def _load_pending_flows(self) -> dict[str, DeviceAuthState]:
    """Load persisted pending OAuth device flows, skipping unusable entries.

    Ensures ``self.config.state_dir`` exists, then reads
    ``self.pending_flows_path`` and rebuilds one ``DeviceAuthState`` per
    stored flow.  Loading is best-effort so that a damaged state file can
    never prevent the server from starting:

    * a missing, unreadable, non-UTF-8 or non-JSON file yields ``{}``;
    * a top-level value that is not a JSON object yields ``{}``;
    * an entry that is malformed (wrong type, missing field, non-numeric
      ``expires_at``/``interval_seconds``) is skipped;
    * an entry whose ``expires_at`` is not in the future is skipped.

    Returns:
        Mapping of flow id to its still-valid ``DeviceAuthState``.
    """
    self.config.state_dir.mkdir(parents=True, exist_ok=True)
    try:
        raw = self.pending_flows_path.read_text(encoding="utf-8")
        payload = json.loads(raw)
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return {}
    if not isinstance(payload, dict):
        return {}

    flows: dict[str, DeviceAuthState] = {}
    now = time.time()
    for flow_id, value in payload.items():
        try:
            state = DeviceAuthState(
                device_code=value["device_code"],
                code_verifier=value["code_verifier"],
                user_code=value["user_code"],
                verification_uri=value["verification_uri"],
                verification_uri_complete=value["verification_uri_complete"],
                expires_at=float(value["expires_at"]),
                interval_seconds=float(value.get("interval_seconds", 2.0)),
            )
        except (KeyError, TypeError, ValueError, AttributeError):
            # Entry is not an object, lacks a field, or holds a value that
            # cannot be converted to float: drop it, keep the rest.
            continue
        if state.expires_at > now:
            flows[flow_id] = state
    return flows
def _save_pending_flows(self) -> None:
    """Persist the not-yet-expired pending OAuth device flows to disk.

    Serialises every entry of ``self.pending_device_flows`` whose
    ``expires_at`` is still in the future to ``self.pending_flows_path`` as
    a JSON object keyed by flow id.  ``self.config.state_dir`` is created if
    it does not exist.

    The file is written to a sibling ``*.tmp`` path and then renamed over
    the target, so a crash in the middle of a write cannot leave truncated
    JSON behind (which the loader would silently treat as "no flows").

    NOTE(review): the visible call sites appear to invoke this while holding
    ``self.lock``; the fixed temp-file name relies on that - confirm.
    NOTE(review): the payload contains ``device_code`` and ``code_verifier``;
    consider restricting the file mode if the state dir is shared.

    Raises:
        OSError: if the directory or file cannot be written.
    """
    self.config.state_dir.mkdir(parents=True, exist_ok=True)
    # One timestamp for the whole snapshot so every flow is judged equally.
    now = time.time()
    payload = {
        flow_id: {
            "device_code": state.device_code,
            "code_verifier": state.code_verifier,
            "user_code": state.user_code,
            "verification_uri": state.verification_uri,
            "verification_uri_complete": state.verification_uri_complete,
            "expires_at": state.expires_at,
            "interval_seconds": state.interval_seconds,
        }
        for flow_id, state in self.pending_device_flows.items()
        if state.expires_at > now
    }
    target = self.pending_flows_path
    tmp_path = target.with_name(target.name + ".tmp")
    tmp_path.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    tmp_path.replace(target)
def auth_status(self) -> dict[str, Any]:
    """Summarise authentication state together with server policy info.

    Returns:
        A dict that always contains ``authenticated``, ``tool_policy`` and
        ``pending_flows`` (the number of entries currently held in
        ``self.pending_device_flows``).  When credentials are available it
        additionally carries ``resource_url`` and ``expires_at`` (taken from
        the credentials' ``expiry_date`` field).
    """
    creds = self.oauth.load_credentials()
    status: dict[str, Any] = {"authenticated": bool(creds)}
    if creds:
        status["resource_url"] = creds.get("resource_url")
        status["expires_at"] = creds.get("expiry_date")
    # Policy and flow count are reported regardless of authentication.
    status["tool_policy"] = self.config.tool_policy
    status["pending_flows"] = len(self.pending_device_flows)
    return status
@ -120,6 +175,7 @@ class RequestHandler(BaseHTTPRequestHandler):
state = self.app.oauth.start_device_flow(open_browser=False) state = self.app.oauth.start_device_flow(open_browser=False)
with self.app.lock: with self.app.lock:
self.app.pending_device_flows[flow_id] = state self.app.pending_device_flows[flow_id] = state
self.app._save_pending_flows()
self._send( self._send(
HTTPStatus.OK, HTTPStatus.OK,
{ {
@ -147,6 +203,7 @@ class RequestHandler(BaseHTTPRequestHandler):
return return
with self.app.lock: with self.app.lock:
self.app.pending_device_flows.pop(flow_id, None) self.app.pending_device_flows.pop(flow_id, None)
self.app._save_pending_flows()
self._send(HTTPStatus.OK, {"done": True, "credentials": {"resource_url": creds.get("resource_url")}}) self._send(HTTPStatus.OK, {"done": True, "credentials": {"resource_url": creds.get("resource_url")}})
return return
@ -255,6 +312,7 @@ class RequestHandler(BaseHTTPRequestHandler):
def main() -> None: def main() -> None:
config = ServerConfig.load() config = ServerConfig.load()
config.session_dir.mkdir(parents=True, exist_ok=True) config.session_dir.mkdir(parents=True, exist_ok=True)
config.state_dir.mkdir(parents=True, exist_ok=True)
httpd = ThreadingHTTPServer((config.host, config.port), RequestHandler) httpd = ThreadingHTTPServer((config.host, config.port), RequestHandler)
httpd.app_state = AppState(config) # type: ignore[attr-defined] httpd.app_state = AppState(config) # type: ignore[attr-defined]
print(f"new-qwen serv listening on http://{config.host}:{config.port}") print(f"new-qwen serv listening on http://{config.host}:{config.port}")

View File

@ -23,10 +23,12 @@ class ServerConfig:
model: str model: str
workspace_root: Path workspace_root: Path
session_dir: Path session_dir: Path
state_dir: Path
system_prompt: str system_prompt: str
max_tool_rounds: int max_tool_rounds: int
max_file_read_bytes: int max_file_read_bytes: int
max_command_output_bytes: int max_command_output_bytes: int
tool_policy: str
@classmethod @classmethod
def load(cls) -> "ServerConfig": def load(cls) -> "ServerConfig":
@ -41,12 +43,19 @@ class ServerConfig:
str(base_dir.parent / ".new-qwen" / "sessions"), str(base_dir.parent / ".new-qwen" / "sessions"),
) )
).expanduser() ).expanduser()
state_dir = Path(
os.environ.get(
"NEW_QWEN_STATE_DIR",
str(base_dir.parent / ".new-qwen" / "state"),
)
).expanduser()
return cls( return cls(
host=os.environ.get("NEW_QWEN_HOST", "127.0.0.1"), host=os.environ.get("NEW_QWEN_HOST", "127.0.0.1"),
port=int(os.environ.get("NEW_QWEN_PORT", "8080")), port=int(os.environ.get("NEW_QWEN_PORT", "8080")),
model=os.environ.get("NEW_QWEN_MODEL", "qwen3.6-plus"), model=os.environ.get("NEW_QWEN_MODEL", "qwen3.6-plus"),
workspace_root=workspace_root.resolve(), workspace_root=workspace_root.resolve(),
session_dir=session_dir.resolve(), session_dir=session_dir.resolve(),
state_dir=state_dir.resolve(),
system_prompt=os.environ.get("NEW_QWEN_SYSTEM_PROMPT", "").strip(), system_prompt=os.environ.get("NEW_QWEN_SYSTEM_PROMPT", "").strip(),
max_tool_rounds=int(os.environ.get("NEW_QWEN_MAX_TOOL_ROUNDS", "8")), max_tool_rounds=int(os.environ.get("NEW_QWEN_MAX_TOOL_ROUNDS", "8")),
max_file_read_bytes=int( max_file_read_bytes=int(
@ -55,4 +64,5 @@ class ServerConfig:
max_command_output_bytes=int( max_command_output_bytes=int(
os.environ.get("NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES", "12000") os.environ.get("NEW_QWEN_MAX_COMMAND_OUTPUT_BYTES", "12000")
), ),
tool_policy=os.environ.get("NEW_QWEN_TOOL_POLICY", "full-access").strip(),
) )

View File

@ -1,15 +1,41 @@
from __future__ import annotations from __future__ import annotations
import json
import threading import threading
import time import time
import uuid import uuid
from pathlib import Path
from typing import Any from typing import Any
class JobStore: class JobStore:
def __init__(self, base_dir: Path) -> None:
    """Create a job store backed by JSON files under *base_dir*.

    The directory is created when missing, and any job files already in it
    are loaded into memory before the constructor returns.

    Args:
        base_dir: Directory that holds one ``<job_id>.json`` file per job.
    """
    base_dir.mkdir(parents=True, exist_ok=True)
    self.base_dir = base_dir
    # Re-entrant lock guarding the in-memory job table.
    self._lock = threading.RLock()
    self._jobs: dict[str, dict[str, Any]] = {}
    self._load_existing()
def _path(self, job_id: str) -> Path:
    """Return the JSON file path used to persist the job *job_id*."""
    filename = f"{job_id}.json"
    return self.base_dir.joinpath(filename)
def _save_job(self, job: dict[str, Any]) -> None:
    """Write one job record to its JSON file atomically.

    The record is serialised to a sibling ``<job_id>.json.tmp`` file and
    then renamed over ``<job_id>.json``.  A crash mid-write therefore never
    leaves a truncated job file, and the temporary name does not match the
    ``*.json`` pattern scanned when existing jobs are loaded.

    NOTE(review): callers are expected to serialise writes for a given job
    (the visible call sites appear to hold ``self._lock``) - confirm.

    Args:
        job: Job record; must contain a ``job_id`` key.

    Raises:
        KeyError: if ``job`` has no ``job_id``.
        OSError: if the file cannot be written.
    """
    target = self._path(job["job_id"])
    tmp_path = target.with_name(target.name + ".tmp")
    tmp_path.write_text(
        json.dumps(job, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    tmp_path.replace(target)
def _load_existing(self) -> None:
    """Populate the in-memory job table from ``*.json`` files in ``base_dir``.

    Files are visited in sorted name order.  Loading is best-effort: a file
    that cannot be read, is not valid UTF-8/JSON, does not contain a JSON
    object, or has no non-empty string ``job_id`` is skipped instead of
    aborting start-up.

    A job whose stored status is ``queued`` or ``running`` cannot still be
    in progress after a restart, so it is marked ``failed`` in memory with
    an explanatory ``error`` and a fresh ``updated_at``.
    """
    for path in sorted(self.base_dir.glob("*.json")):
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            continue
        if not isinstance(payload, dict):
            continue
        job_id = payload.get("job_id")
        if not isinstance(job_id, str) or not job_id:
            continue
        if payload.get("status") in {"queued", "running"}:
            payload["status"] = "failed"
            payload["error"] = "Server restarted while job was running"
            payload["updated_at"] = time.time()
        self._jobs[job_id] = payload
def create(self, session_id: str, user_id: str, message: str) -> dict[str, Any]: def create(self, session_id: str, user_id: str, message: str) -> dict[str, Any]:
job_id = uuid.uuid4().hex job_id = uuid.uuid4().hex
@ -28,6 +54,7 @@ class JobStore:
} }
with self._lock: with self._lock:
self._jobs[job_id] = job self._jobs[job_id] = job
self._save_job(job)
return job return job
def get(self, job_id: str) -> dict[str, Any] | None: def get(self, job_id: str) -> dict[str, Any] | None:
@ -46,12 +73,14 @@ class JobStore:
seq = len(job["events"]) + 1 seq = len(job["events"]) + 1
job["events"].append({"seq": seq, **event}) job["events"].append({"seq": seq, **event})
job["updated_at"] = time.time() job["updated_at"] = time.time()
self._save_job(job)
def set_status(self, job_id: str, status: str) -> None:
    """Set the status of an existing job and persist the change.

    Args:
        job_id: Identifier of a job already present in the store.
        status: New status value to record.

    Raises:
        KeyError: if no job with ``job_id`` is known.
    """
    with self._lock:
        record = self._jobs[job_id]
        # Stamp the modification time together with the new status.
        record.update(status=status, updated_at=time.time())
        self._save_job(record)
def finish( def finish(
self, self,
@ -66,6 +95,7 @@ class JobStore:
job["answer"] = answer job["answer"] = answer
job["usage"] = usage job["usage"] = usage
job["updated_at"] = time.time() job["updated_at"] = time.time()
self._save_job(job)
def fail(self, job_id: str, error_message: str) -> None: def fail(self, job_id: str, error_message: str) -> None:
with self._lock: with self._lock:
@ -73,4 +103,4 @@ class JobStore:
job["status"] = "failed" job["status"] = "failed"
job["error"] = error_message job["error"] = error_message
job["updated_at"] = time.time() job["updated_at"] = time.time()
self._save_job(job)

View File

@ -179,10 +179,25 @@ class ToolRegistry:
raise ToolError("Path escapes workspace root") raise ToolError("Path escapes workspace root")
return resolved return resolved
def _check_policy(self, tool_name: str) -> None:
    """Reject *tool_name* if the configured tool policy forbids it.

    Policies (read from ``self.config.tool_policy``):

    * ``full-access`` - every tool is allowed;
    * ``workspace-write`` - everything except ``exec_command``;
    * ``read-only`` - only the listing, search, stat and read tools.

    Any other policy value is treated as a configuration error and blocks
    every tool.

    Args:
        tool_name: Name of the tool about to be executed.

    Raises:
        ToolError: if the tool is blocked by the policy, or the policy
            value is not one of the three known modes.
    """
    policy = self.config.tool_policy
    if policy == "full-access":
        return
    if policy == "read-only":
        read_only_tools = {"list_files", "glob_search", "grep_text", "stat_path", "read_file"}
        if tool_name not in read_only_tools:
            raise ToolError(f"Tool '{tool_name}' is blocked by read-only policy")
        return
    if policy == "workspace-write":
        shell_tools = {"exec_command"}
        if tool_name in shell_tools:
            raise ToolError(f"Tool '{tool_name}' is blocked by workspace-write policy")
        return
    # Fail closed: an unrecognised policy must never grant access.
    raise ToolError(f"Unknown tool policy: {policy}")
def execute(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
    """Run the registered tool *name* with *arguments*.

    The tool must be registered, and is then checked against the configured
    tool policy before its handler is invoked.

    Args:
        name: Registered tool name.
        arguments: Arguments passed through to the tool handler unchanged.

    Returns:
        Whatever the tool handler returns.

    Raises:
        ToolError: if the tool is unknown or blocked by the tool policy.
    """
    handler = self._handlers.get(name)
    if handler:
        # Policy is enforced only for tools that actually exist.
        self._check_policy(name)
        return handler(arguments)
    raise ToolError(f"Unknown tool: {name}")
def list_files(self, arguments: dict[str, Any]) -> dict[str, Any]: def list_files(self, arguments: dict[str, Any]) -> dict[str, Any]: