398 lines
15 KiB
Python
398 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from http import HTTPStatus
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from config import ServerConfig
|
|
from approvals import ApprovalStore
|
|
from jobs import JobStore
|
|
from llm import QwenAgent
|
|
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
|
|
from sessions import SessionStore
|
|
from tools import ToolRegistry
|
|
|
|
|
|
class AppState:
    """Process-wide server state shared by every request handler.

    Bundles the OAuth manager, the session/job/approval stores, the tool
    registry and the agent, plus the in-progress OAuth device flows. The device
    flows are mirrored to ``<state_dir>/oauth_flows.json`` so they survive a
    restart.
    """

    def __init__(self, config: ServerConfig) -> None:
        self.config = config
        self.oauth = QwenOAuthManager()
        self.sessions = SessionStore(config.session_dir)
        self.tools = ToolRegistry(config)
        self.agent = QwenAgent(config, self.oauth, self.tools)
        self.jobs = JobStore(
            config.state_dir / "jobs",
            retention_seconds=config.jobs_retention_seconds,
        )
        self.approvals = ApprovalStore(
            config.state_dir / "approvals",
            retention_seconds=config.approvals_retention_seconds,
        )
        # Guards pending_device_flows and the file mirroring it. It is created
        # before the flows are loaded so it already exists for any code path
        # that touches them.
        self.lock = threading.Lock()
        self.pending_flows_path = config.state_dir / "oauth_flows.json"
        self.pending_device_flows: dict[str, DeviceAuthState] = self._load_pending_flows()

    def _load_pending_flows(self) -> dict[str, DeviceAuthState]:
        """Load the unexpired device flows persisted on disk.

        A missing, unreadable, non-UTF-8, non-JSON or non-object file yields an
        empty mapping instead of aborting start-up. Individual entries that
        lack a field, are not objects, or carry non-numeric timestamps are
        skipped, as are expired entries. If anything was dropped, the file is
        rewritten with only the surviving flows.

        Returns:
            Mapping of flow id to its ``DeviceAuthState``.
        """
        self.config.state_dir.mkdir(parents=True, exist_ok=True)
        if not self.pending_flows_path.exists():
            return {}
        try:
            payload = json.loads(self.pending_flows_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            return {}
        if not isinstance(payload, dict):
            # Only a JSON object keyed by flow id can be interpreted.
            return {}

        flows: dict[str, DeviceAuthState] = {}
        now = time.time()
        for flow_id, value in payload.items():
            try:
                state = DeviceAuthState(
                    device_code=value["device_code"],
                    code_verifier=value["code_verifier"],
                    user_code=value["user_code"],
                    verification_uri=value["verification_uri"],
                    verification_uri_complete=value["verification_uri_complete"],
                    expires_at=float(value["expires_at"]),
                    interval_seconds=float(value.get("interval_seconds", 2.0)),
                )
            except (KeyError, TypeError, ValueError):
                # Missing field, non-object entry, or a non-numeric timestamp.
                continue
            if state.expires_at > now:
                flows[flow_id] = state

        if len(flows) != len(payload):
            # _save_pending_flows serializes self.pending_device_flows, so it
            # has to reference the filtered mapping before saving.
            self.pending_device_flows = flows
            self._save_pending_flows()
        return flows

    def _save_pending_flows(self) -> None:
        """Write the unexpired entries of ``pending_device_flows`` to disk.

        The JSON goes to a sibling ``.tmp`` file first and is then moved over
        the real file with ``Path.replace``. A crash mid-write therefore cannot
        leave a truncated flows file behind. The request handlers and
        ``cleanup_state`` call this with ``self.lock`` held.
        """
        self.config.state_dir.mkdir(parents=True, exist_ok=True)
        now = time.time()
        payload = {
            flow_id: {
                "device_code": state.device_code,
                "code_verifier": state.code_verifier,
                "user_code": state.user_code,
                "verification_uri": state.verification_uri,
                "verification_uri_complete": state.verification_uri_complete,
                "expires_at": state.expires_at,
                "interval_seconds": state.interval_seconds,
            }
            for flow_id, state in self.pending_device_flows.items()
            if state.expires_at > now
        }
        target = self.pending_flows_path
        tmp_path = target.with_name(target.name + ".tmp")
        tmp_path.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        tmp_path.replace(target)

    def cleanup_state(self) -> None:
        """Clean up the job and approval stores, then drop expired device flows.

        The job and approval stores apply their own cleanup rules. The
        in-memory device flows are filtered to unexpired ones, and the flows
        file is re-saved under the lock.
        """
        self.jobs.cleanup()
        self.approvals.cleanup()
        with self.lock:
            now = time.time()
            self.pending_device_flows = {
                flow_id: state
                for flow_id, state in self.pending_device_flows.items()
                if state.expires_at > now
            }
            self._save_pending_flows()

    def auth_status(self) -> dict[str, Any]:
        """Build the payload for ``GET /api/v1/auth/status``.

        ``authenticated`` reflects whether the OAuth manager returned stored
        credentials. Only in that case are ``resource_url`` and ``expires_at``
        (the credentials' ``expiry_date``) included. ``pending_flows`` counts
        the in-memory device flows, and ``pending_approvals`` counts what the
        approval store lists as pending.
        """
        creds = self.oauth.load_credentials()
        status: dict[str, Any] = {"authenticated": bool(creds)}
        if creds:
            status["resource_url"] = creds.get("resource_url")
            status["expires_at"] = creds.get("expiry_date")
        status["tool_policy"] = self.config.tool_policy
        status["pending_flows"] = len(self.pending_device_flows)
        status["pending_approvals"] = len(self.approvals.list_pending())
        return status
|
|
|
|
|
|
class _BadRequestError(ValueError):
    """Raised when a request body or field is malformed; answered with HTTP 400."""


class RequestHandler(BaseHTTPRequestHandler):
    """JSON HTTP API over the shared :class:`AppState`.

    ``do_GET`` / ``do_POST`` look up the request path, ignoring any query
    string, in ``_GET_ROUTES`` / ``_POST_ROUTES``. They call the named
    ``_get_*`` / ``_post_*`` method, which returns ``(status, payload)``.
    Exceptions are converted into JSON error responses, and each request is
    answered exactly once.
    """

    server_version = "new-qwen-serv/0.1"

    # Request path -> handler method name. Names rather than function objects
    # let these tables sit above the methods they refer to.
    _GET_ROUTES: dict[str, str] = {
        "/health": "_get_health",
        "/api/v1/auth/status": "_get_auth_status",
        "/api/v1/sessions": "_get_sessions",
        "/api/v1/approvals": "_get_approvals",
    }
    _POST_ROUTES: dict[str, str] = {
        "/api/v1/auth/device/start": "_post_device_start",
        "/api/v1/auth/device/poll": "_post_device_poll",
        "/api/v1/chat": "_post_chat",
        "/api/v1/chat/start": "_post_chat_start",
        "/api/v1/chat/poll": "_post_chat_poll",
        "/api/v1/approval/respond": "_post_approval_respond",
        "/api/v1/session/get": "_post_session_get",
        "/api/v1/session/clear": "_post_session_clear",
    }

    # ------------------------------------------------------------------ plumbing

    @property
    def app(self) -> AppState:
        """The shared AppState stored on the server as ``app_state``."""
        return self.server.app_state  # type: ignore[attr-defined]

    def _route_path(self) -> str:
        """Return the request path without any ``?query`` suffix."""
        return self.path.split("?", 1)[0]

    def _json_body(self) -> dict[str, Any]:
        """Read and decode the request body as a JSON object.

        Returns ``{}`` when there is no body (missing, zero or negative
        Content-Length, or zero bytes read).

        Raises:
            _BadRequestError: Content-Length is not an integer, the body is not
                valid UTF-8/JSON, or the JSON value is not an object.
        """
        raw_length = self.headers.get("Content-Length", "0")
        try:
            length = int(raw_length)
        except (TypeError, ValueError):
            raise _BadRequestError(f"Invalid Content-Length: {raw_length!r}") from None
        if length <= 0:
            return {}
        raw = self.rfile.read(length)
        if not raw:
            return {}
        try:
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as exc:  # UnicodeDecodeError and json.JSONDecodeError
            raise _BadRequestError(f"Invalid JSON body: {exc}") from exc
        if not isinstance(payload, dict):
            raise _BadRequestError("JSON body must be an object")
        return payload

    @staticmethod
    def _parse_bool(value: Any, field: str) -> bool:
        """Interpret a JSON field as a boolean, strictly.

        Accepts JSON booleans, the integers 0/1, and the strings
        true/false/yes/no/1/0 (case-insensitive). Plain ``bool()`` would treat
        the string ``"false"`` as True, which for approvals would grant a tool
        call the caller meant to deny.

        Raises:
            _BadRequestError: ``value`` is anything else.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, int) and value in (0, 1):
            return bool(value)
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in ("true", "yes", "1"):
                return True
            if normalized in ("false", "no", "0"):
                return False
        raise _BadRequestError(f"Field {field!r} must be a boolean")

    @staticmethod
    def _chat_params(body: dict[str, Any]) -> tuple[Any, str, Any]:
        """Extract ``(session_id, user_id, message)`` from a chat request body.

        A fresh hex session id is minted when none is supplied. ``user_id``
        defaults to ``"anonymous"``. A missing ``message`` raises KeyError.
        """
        session_id = body.get("session_id") or uuid.uuid4().hex
        user_id = str(body.get("user_id") or "anonymous")
        return session_id, user_id, body["message"]

    def _send(self, status: int, payload: dict[str, Any]) -> None:
        """Write ``payload`` as a UTF-8 JSON response with the given status."""
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    # ------------------------------------------------------------------ dispatch

    def do_GET(self) -> None:
        """Serve the read-only endpoints listed in ``_GET_ROUTES``."""
        handler_name = self._GET_ROUTES.get(self._route_path())
        if handler_name is None:
            self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
            return
        try:
            status, payload = getattr(self, handler_name)()
        except Exception as exc:
            # Answer with JSON instead of letting the error escape to
            # socketserver, which would drop the connection without a reply.
            status, payload = HTTPStatus.INTERNAL_SERVER_ERROR, {"error": str(exc)}
        # Sent outside the try so a failed write is never followed by a
        # second, error response on the same connection.
        self._send(status, payload)

    def do_POST(self) -> None:
        """Serve the state-changing endpoints listed in ``_POST_ROUTES``.

        Error mapping:
        - malformed body or field -> 400
        - KeyError (e.g. a missing required field) -> 400 "Missing field: ..."
        - OAuthError -> 502
        - anything else -> 500
        """
        handler_name = self._POST_ROUTES.get(self._route_path())
        if handler_name is None:
            self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
            return
        try:
            status, payload = getattr(self, handler_name)()
        except _BadRequestError as exc:
            status, payload = HTTPStatus.BAD_REQUEST, {"error": str(exc)}
        except KeyError as exc:
            status, payload = HTTPStatus.BAD_REQUEST, {"error": f"Missing field: {exc}"}
        except OAuthError as exc:
            status, payload = HTTPStatus.BAD_GATEWAY, {"error": str(exc)}
        except Exception as exc:
            status, payload = HTTPStatus.INTERNAL_SERVER_ERROR, {"error": str(exc)}
        self._send(status, payload)

    # ------------------------------------------------------------ GET handlers

    def _get_health(self) -> tuple[int, dict[str, Any]]:
        """Liveness probe with the current Unix time in seconds."""
        return HTTPStatus.OK, {"ok": True, "time": int(time.time())}

    def _get_auth_status(self) -> tuple[int, dict[str, Any]]:
        """Report authentication state as built by ``AppState.auth_status``."""
        return HTTPStatus.OK, self.app.auth_status()

    def _get_sessions(self) -> tuple[int, dict[str, Any]]:
        """List sessions known to the session store."""
        return HTTPStatus.OK, {"sessions": self.app.sessions.list_sessions()}

    def _get_approvals(self) -> tuple[int, dict[str, Any]]:
        """List approvals the approval store reports as pending."""
        return HTTPStatus.OK, {"approvals": self.app.approvals.list_pending()}

    # ----------------------------------------------------------- POST handlers

    def _post_device_start(self) -> tuple[int, dict[str, Any]]:
        """Start an OAuth device flow and remember it under a new ``flow_id``."""
        flow_id = uuid.uuid4().hex
        state = self.app.oauth.start_device_flow(open_browser=False)
        with self.app.lock:
            self.app.pending_device_flows[flow_id] = state
            self.app._save_pending_flows()
        return HTTPStatus.OK, {
            "flow_id": flow_id,
            "user_code": state.user_code,
            "verification_uri": state.verification_uri,
            "verification_uri_complete": state.verification_uri_complete,
            "expires_at": state.expires_at,
            "interval_seconds": state.interval_seconds,
        }

    def _post_device_poll(self) -> tuple[int, dict[str, Any]]:
        """Poll a pending device flow once.

        Returns ``done: False`` while ``poll_device_flow`` yields None. Once
        credentials arrive, the flow is forgotten and ``done: True`` is returned
        together with the credentials' ``resource_url``.
        """
        body = self._json_body()
        flow_id = body["flow_id"]
        with self.app.lock:
            state = self.app.pending_device_flows.get(flow_id)
        # Respond outside the lock so network I/O never blocks other threads.
        if not state:
            return HTTPStatus.NOT_FOUND, {"error": "Unknown flow_id"}
        creds = self.app.oauth.poll_device_flow(state)
        if creds is None:
            return HTTPStatus.OK, {"done": False, "interval_seconds": state.interval_seconds}
        with self.app.lock:
            self.app.pending_device_flows.pop(flow_id, None)
            self.app._save_pending_flows()
        return HTTPStatus.OK, {"done": True, "credentials": {"resource_url": creds.get("resource_url")}}

    def _post_chat(self) -> tuple[int, dict[str, Any]]:
        """Run one chat turn synchronously and persist the updated session.

        The agent is called without an approval callback here, unlike the job
        path in ``_run_chat_job``.
        """
        session_id, user_id, message = self._chat_params(self._json_body())
        session = self.app.sessions.load(session_id)
        history = session.get("messages", [])
        result = self.app.agent.run(history, message)
        self._persist_session(session_id, user_id, result)
        return HTTPStatus.OK, {
            "session_id": session_id,
            "answer": result["answer"],
            "events": result["events"],
            "usage": result.get("usage"),
        }

    def _post_chat_start(self) -> tuple[int, dict[str, Any]]:
        """Queue a chat turn as a job and run it on a background daemon thread."""
        session_id, user_id, message = self._chat_params(self._json_body())
        job = self.app.jobs.create(session_id, user_id, message)
        thread = threading.Thread(
            target=self._run_chat_job,
            args=(job["job_id"], session_id, user_id, message),
            daemon=True,
        )
        thread.start()
        return HTTPStatus.OK, {
            "job_id": job["job_id"],
            "session_id": session_id,
            "status": "queued",
        }

    def _post_chat_poll(self) -> tuple[int, dict[str, Any]]:
        """Return a job's status/result and its events with ``seq > since_seq``."""
        body = self._json_body()
        job_id = body["job_id"]
        raw_since = body.get("since_seq", 0)
        try:
            since_seq = int(raw_since)
        except (TypeError, ValueError):
            raise _BadRequestError("Field 'since_seq' must be an integer") from None
        job = self.app.jobs.get(job_id)
        if not job:
            return HTTPStatus.NOT_FOUND, {"error": "Unknown job_id"}
        events = [event for event in job.get("events", []) if event.get("seq", 0) > since_seq]
        return HTTPStatus.OK, {
            "job_id": job_id,
            "session_id": job["session_id"],
            "status": job["status"],
            "events": events,
            "answer": job.get("answer"),
            "usage": job.get("usage"),
            "error": job.get("error"),
        }

    def _post_approval_respond(self) -> tuple[int, dict[str, Any]]:
        """Record an approve/deny decision for a pending tool-call approval."""
        body = self._json_body()
        approval_id = body["approval_id"]
        approved = self._parse_bool(body["approved"], "approved")
        actor = str(body.get("actor") or "unknown")
        approval = self.app.approvals.respond(
            approval_id,
            approved=approved,
            actor=actor,
        )
        return HTTPStatus.OK, approval

    def _post_session_get(self) -> tuple[int, dict[str, Any]]:
        """Return the stored session document for ``session_id``."""
        body = self._json_body()
        return HTTPStatus.OK, self.app.sessions.load(body["session_id"])

    def _post_session_clear(self) -> tuple[int, dict[str, Any]]:
        """Clear the session identified by ``session_id``."""
        body = self._json_body()
        session_id = body["session_id"]
        self.app.sessions.clear(session_id)
        return HTTPStatus.OK, {"ok": True, "session_id": session_id}

    # ---------------------------------------------------------------- chat jobs

    def _persist_session(self, session_id: Any, user_id: str, result: dict[str, Any]) -> None:
        """Save the agent's message list and last answer for a session.

        ``result["messages"][0]`` is not persisted. Presumably it is a
        system/preamble message the agent prepends on every run; TODO confirm
        against ``QwenAgent.run``.
        """
        self.app.sessions.save(
            session_id,
            {
                "session_id": session_id,
                "user_id": user_id,
                "updated_at": int(time.time()),
                "messages": result["messages"][1:],
                "last_answer": result["answer"],
            },
        )

    def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None:
        """Background-thread body for ``/api/v1/chat/start``.

        Streams agent events into the job, routes tool approvals through
        ``_request_approval``, persists the session, and finishes the job. Any
        exception marks the job failed. The job store is the only channel back
        to the client, since this runs after the HTTP response was sent.
        """
        jobs = self.app.jobs
        try:
            jobs.set_status(job_id, "running")
            # Message text means "Request accepted by the server".
            jobs.append_event(job_id, {"type": "job_status", "message": "Запрос принят сервером"})
            session = self.app.sessions.load(session_id)
            history = session.get("messages", [])
            result = self.app.agent.run(
                history,
                message,
                on_event=lambda event: jobs.append_event(job_id, event),
                approval_callback=lambda tool_name, arguments: self._request_approval(
                    job_id,
                    tool_name,
                    arguments,
                ),
            )
            self._persist_session(session_id, user_id, result)
            # Message text means "Answer ready".
            jobs.append_event(job_id, {"type": "job_status", "message": "Ответ готов"})
            jobs.finish(
                job_id,
                answer=result["answer"],
                usage=result.get("usage"),
            )
        except Exception as exc:
            error = str(exc)
            try:
                jobs.append_event(job_id, {"type": "error", "message": error})
            finally:
                # Mark the job failed even if recording the error event fails,
                # so pollers do not see it "running" forever.
                jobs.fail(job_id, error)

    def _request_approval(
        self,
        job_id: str,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> dict[str, Any]:
        """Register a pending approval for a tool call and block until it resolves.

        Emits an ``approval_request`` event on the job and sets the job to
        ``waiting_approval``. It then waits via ``ApprovalStore.wait`` with
        ``config.approval_timeout_seconds`` as the timeout, sets the job back
        to ``running``, and returns the store's decision unchanged.
        """
        approval = self.app.approvals.create(
            job_id=job_id,
            tool_name=tool_name,
            arguments=arguments,
        )
        self.app.jobs.append_event(
            job_id,
            {
                "type": "approval_request",
                "approval_id": approval["approval_id"],
                "tool_name": tool_name,
                "arguments": arguments,
            },
        )
        self.app.jobs.set_status(job_id, "waiting_approval")
        decision = self.app.approvals.wait(
            approval["approval_id"],
            timeout_seconds=float(self.app.config.approval_timeout_seconds),
        )
        self.app.jobs.set_status(job_id, "running")
        return decision

    def log_message(self, fmt: str, *args: Any) -> None:
        """Suppress BaseHTTPRequestHandler's default per-request stderr log line."""
        return
|
|
|
|
|
|
def main() -> None:
    """Load the server config, prepare state directories and serve HTTP.

    Runs until interrupted. Ctrl+C (KeyboardInterrupt) stops the server
    without a traceback, and the listening socket is always closed on the way
    out.
    """
    config = ServerConfig.load()
    config.session_dir.mkdir(parents=True, exist_ok=True)
    config.state_dir.mkdir(parents=True, exist_ok=True)
    httpd = ThreadingHTTPServer((config.host, config.port), RequestHandler)
    try:
        # Handlers reach the shared state through server.app_state.
        httpd.app_state = AppState(config)  # type: ignore[attr-defined]
        httpd.app_state.cleanup_state()  # type: ignore[attr-defined]
        print(f"new-qwen serv listening on http://{config.host}:{config.port}")
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        httpd.server_close()
|
|
|
|
|
|
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|