# new-qwen/serv/app.py
#
# Repository viewer metadata: 324 lines, 13 KiB, Python.

from __future__ import annotations
import json
import threading
import time
import uuid
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any
from config import ServerConfig
from jobs import JobStore
from llm import QwenAgent
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
from sessions import SessionStore
from tools import ToolRegistry
class AppState:
    """Process-wide state shared by all request handlers.

    Holds the OAuth manager, session store, tool registry, agent and job
    store built from ``config``, together with the table of in-progress
    device-authorization flows. That table is mirrored to
    ``oauth_flows.json`` inside ``config.state_dir``.
    """

    # String-valued fields copied verbatim between DeviceAuthState and JSON.
    _FLOW_TEXT_FIELDS = (
        "device_code",
        "code_verifier",
        "user_code",
        "verification_uri",
        "verification_uri_complete",
    )

    def __init__(self, config: ServerConfig) -> None:
        """Build every collaborator and reload device flows saved on disk."""
        self.config = config
        # Callers hold this lock while mutating pending_device_flows and
        # while rewriting the file that mirrors it.
        self.lock = threading.Lock()
        self.oauth = QwenOAuthManager()
        self.sessions = SessionStore(config.session_dir)
        self.tools = ToolRegistry(config)
        self.agent = QwenAgent(config, self.oauth, self.tools)
        self.jobs = JobStore(config.state_dir / "jobs")
        self.pending_flows_path = config.state_dir / "oauth_flows.json"
        self.pending_device_flows: dict[str, DeviceAuthState] = self._load_pending_flows()

    def _load_pending_flows(self) -> dict[str, DeviceAuthState]:
        """Return the non-expired device flows stored on disk.

        Loading is best-effort: a missing, unreadable or malformed file
        yields an empty mapping, and individual entries that are malformed
        or already expired are skipped instead of aborting startup.
        """
        self.config.state_dir.mkdir(parents=True, exist_ok=True)
        try:
            text = self.pending_flows_path.read_text(encoding="utf-8")
            payload = json.loads(text)
        except (OSError, ValueError):
            # OSError covers a missing/unreadable file; ValueError covers
            # both json.JSONDecodeError and UnicodeDecodeError.
            return {}
        if not isinstance(payload, dict):
            return {}

        flows: dict[str, DeviceAuthState] = {}
        now = time.time()
        for flow_id, value in payload.items():
            if not isinstance(value, dict):
                continue
            try:
                expires_at = float(value["expires_at"])
                interval_seconds = float(value.get("interval_seconds", 2.0))
                text_fields = {name: value[name] for name in self._FLOW_TEXT_FIELDS}
            except (KeyError, TypeError, ValueError):
                continue
            # Written as "not >" so a NaN timestamp is dropped as well.
            if not expires_at > now:
                continue
            flows[flow_id] = DeviceAuthState(
                expires_at=expires_at,
                interval_seconds=interval_seconds,
                **text_fields,
            )
        return flows

    def _save_pending_flows(self) -> None:
        """Write the non-expired device flows to ``pending_flows_path``.

        The JSON is written to a sibling ``.tmp`` file and then moved over
        the real file, so an interrupted write cannot leave a truncated
        state file behind. The request handlers call this while holding
        ``self.lock``, which is why a fixed temporary name is sufficient.
        """
        self.config.state_dir.mkdir(parents=True, exist_ok=True)
        now = time.time()
        payload = {
            flow_id: {
                "device_code": state.device_code,
                "code_verifier": state.code_verifier,
                "user_code": state.user_code,
                "verification_uri": state.verification_uri,
                "verification_uri_complete": state.verification_uri_complete,
                "expires_at": state.expires_at,
                "interval_seconds": state.interval_seconds,
            }
            for flow_id, state in self.pending_device_flows.items()
            if state.expires_at > now
        }
        # NOTE(review): the file contains device_code and code_verifier;
        # consider restricting its permissions.
        target = self.pending_flows_path
        scratch = target.with_name(target.name + ".tmp")
        scratch.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        scratch.replace(target)

    def auth_status(self) -> dict[str, Any]:
        """Describe the current authentication state.

        Returns a dict with ``authenticated``, ``tool_policy`` and
        ``pending_flows`` (the number of device flows that have not expired
        yet). When credentials are available it also carries
        ``resource_url`` and ``expires_at`` taken from them.
        """
        creds = self.oauth.load_credentials()
        # Snapshot under the lock: handlers add/remove flows concurrently
        # and iterating a dict that changes size raises RuntimeError.
        with self.lock:
            states = list(self.pending_device_flows.values())
        now = time.time()
        live_flows = sum(1 for state in states if state.expires_at > now)

        status: dict[str, Any] = {"authenticated": bool(creds)}
        if creds:
            status["resource_url"] = creds.get("resource_url")
            status["expires_at"] = creds.get("expiry_date")
        status["tool_policy"] = self.config.tool_policy
        status["pending_flows"] = live_flows
        return status
class _BadRequest(Exception):
    """Signals a malformed client request; reported as HTTP 400."""


class RequestHandler(BaseHTTPRequestHandler):
    """JSON-over-HTTP front end for the shared application state.

    GET serves health, auth status and the session list. POST serves the
    device-authorization flow, synchronous and background chat, and session
    inspection/clearing. Every response body is a JSON object.
    """

    server_version = "new-qwen-serv/0.1"

    # POST path -> name of the method that serves it.
    _POST_ROUTES = {
        "/api/v1/auth/device/start": "_post_device_start",
        "/api/v1/auth/device/poll": "_post_device_poll",
        "/api/v1/chat": "_post_chat",
        "/api/v1/chat/start": "_post_chat_start",
        "/api/v1/chat/poll": "_post_chat_poll",
        "/api/v1/session/get": "_post_session_get",
        "/api/v1/session/clear": "_post_session_clear",
    }

    @property
    def app(self) -> AppState:
        """The application state attached to the HTTP server instance."""
        return self.server.app_state  # type: ignore[attr-defined]

    # ------------------------------------------------------------------
    # Request/response helpers
    # ------------------------------------------------------------------

    def _route(self) -> str:
        """Return the request path without any ``?query`` suffix."""
        return self.path.split("?", 1)[0]

    def _json_body(self) -> dict[str, Any]:
        """Read and decode the request body as a JSON object.

        Returns an empty dict when there is no body.

        Raises:
            _BadRequest: if ``Content-Length`` is not an integer, the body
                is not valid UTF-8 JSON, or the JSON value is not an object.
        """
        try:
            length = int(self.headers.get("Content-Length", "0"))
        except (TypeError, ValueError):
            raise _BadRequest("Invalid Content-Length header") from None
        if length <= 0:
            return {}
        raw = self.rfile.read(length)
        if not raw:
            return {}
        try:
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            raise _BadRequest(f"Invalid JSON body: {exc}") from None
        if not isinstance(payload, dict):
            raise _BadRequest("JSON body must be an object")
        return payload

    @staticmethod
    def _require(body: dict[str, Any], field: str) -> Any:
        """Return ``body[field]`` or raise ``_BadRequest`` if it is absent."""
        try:
            return body[field]
        except KeyError:
            raise _BadRequest(f"Missing field: {field!r}") from None

    def _send(self, status: int, payload: dict[str, Any]) -> None:
        """Serialize ``payload`` as UTF-8 JSON and write a full response."""
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    # ------------------------------------------------------------------
    # Chat pipeline
    # ------------------------------------------------------------------

    def _execute_chat(
        self,
        session_id: str,
        user_id: str,
        message: str,
        on_event: Any = None,
    ) -> dict[str, Any]:
        """Run one agent turn for a session and persist the new history.

        Loads the session, calls the agent with its stored messages and the
        new ``message`` (forwarding ``on_event`` only when one is given),
        saves the updated session and returns the agent's result dict.
        """
        session = self.app.sessions.load(session_id)
        history = session.get("messages", [])
        if on_event is None:
            result = self.app.agent.run(history, message)
        else:
            result = self.app.agent.run(history, message, on_event=on_event)
        # The first returned message is not persisted — presumably the
        # system prompt, which the agent re-adds on each run (TODO confirm).
        persisted_messages = result["messages"][1:]
        self.app.sessions.save(
            session_id,
            {
                "session_id": session_id,
                "user_id": user_id,
                "updated_at": int(time.time()),
                "messages": persisted_messages,
                "last_answer": result["answer"],
            },
        )
        return result

    def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None:
        """Thread target: run a chat turn and record its progress on a job.

        Agent events are appended to the job as they arrive. On success the
        job is finished with the answer and usage; on any exception an
        error event is appended and the job is marked failed.
        """
        jobs = self.app.jobs
        try:
            jobs.set_status(job_id, "running")
            jobs.append_event(
                job_id,
                {"type": "job_status", "message": "Запрос принят сервером"},
            )
            result = self._execute_chat(
                session_id,
                user_id,
                message,
                on_event=lambda event: jobs.append_event(job_id, event),
            )
            jobs.append_event(
                job_id,
                {"type": "job_status", "message": "Ответ готов"},
            )
            jobs.finish(
                job_id,
                answer=result["answer"],
                usage=result.get("usage"),
            )
        except Exception as exc:
            try:
                jobs.append_event(
                    job_id,
                    {"type": "error", "message": str(exc)},
                )
            finally:
                # Mark the job failed even if recording the event itself
                # blew up; otherwise pollers would see "running" forever.
                jobs.fail(job_id, str(exc))

    # ------------------------------------------------------------------
    # GET
    # ------------------------------------------------------------------

    def do_GET(self) -> None:
        """Serve ``/health``, ``/api/v1/auth/status`` and ``/api/v1/sessions``."""
        route = self._route()
        if route == "/health":
            self._send(HTTPStatus.OK, {"ok": True, "time": int(time.time())})
        elif route == "/api/v1/auth/status":
            self._send(HTTPStatus.OK, self.app.auth_status())
        elif route == "/api/v1/sessions":
            self._send(HTTPStatus.OK, {"sessions": self.app.sessions.list_sessions()})
        else:
            self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})

    # ------------------------------------------------------------------
    # POST
    # ------------------------------------------------------------------

    def do_POST(self) -> None:
        """Dispatch a POST request and translate failures to JSON errors.

        Malformed requests yield 400, ``OAuthError`` yields 502, a stray
        ``KeyError`` yields 400 and anything else yields 500.
        """
        try:
            handler_name = self._POST_ROUTES.get(self._route())
            if handler_name is None:
                self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
                return
            getattr(self, handler_name)()
        except _BadRequest as exc:
            self._send(HTTPStatus.BAD_REQUEST, {"error": str(exc)})
        except OAuthError as exc:
            self._send(HTTPStatus.BAD_GATEWAY, {"error": str(exc)})
        except KeyError as exc:
            # Required body fields are validated via _require; this clause
            # is kept so any other KeyError is still reported as before.
            self._send(HTTPStatus.BAD_REQUEST, {"error": f"Missing field: {exc}"})
        except Exception as exc:
            self._send(HTTPStatus.INTERNAL_SERVER_ERROR, {"error": str(exc)})

    def _post_device_start(self) -> None:
        """Start a device-authorization flow and return its user-facing data."""
        flow_id = uuid.uuid4().hex
        state = self.app.oauth.start_device_flow(open_browser=False)
        with self.app.lock:
            self.app.pending_device_flows[flow_id] = state
            self.app._save_pending_flows()
        self._send(
            HTTPStatus.OK,
            {
                "flow_id": flow_id,
                "user_code": state.user_code,
                "verification_uri": state.verification_uri,
                "verification_uri_complete": state.verification_uri_complete,
                "expires_at": state.expires_at,
                "interval_seconds": state.interval_seconds,
            },
        )

    def _post_device_poll(self) -> None:
        """Poll a pending device flow; drop it once credentials arrive."""
        body = self._json_body()
        flow_id = self._require(body, "flow_id")
        with self.app.lock:
            state = self.app.pending_device_flows.get(flow_id)
        if not state:
            self._send(HTTPStatus.NOT_FOUND, {"error": "Unknown flow_id"})
            return
        # Polled outside the lock so other handlers are not blocked on it.
        creds = self.app.oauth.poll_device_flow(state)
        if creds is None:
            self._send(
                HTTPStatus.OK,
                {"done": False, "interval_seconds": state.interval_seconds},
            )
            return
        with self.app.lock:
            self.app.pending_device_flows.pop(flow_id, None)
            self.app._save_pending_flows()
        self._send(
            HTTPStatus.OK,
            {"done": True, "credentials": {"resource_url": creds.get("resource_url")}},
        )

    def _post_chat(self) -> None:
        """Run a chat turn synchronously and return the answer and events."""
        body = self._json_body()
        session_id = body.get("session_id") or uuid.uuid4().hex
        user_id = str(body.get("user_id") or "anonymous")
        message = self._require(body, "message")
        result = self._execute_chat(session_id, user_id, message)
        self._send(
            HTTPStatus.OK,
            {
                "session_id": session_id,
                "answer": result["answer"],
                "events": result["events"],
                "usage": result.get("usage"),
            },
        )

    def _post_chat_start(self) -> None:
        """Create a job, run the chat turn on a daemon thread, reply at once."""
        body = self._json_body()
        session_id = body.get("session_id") or uuid.uuid4().hex
        user_id = str(body.get("user_id") or "anonymous")
        message = self._require(body, "message")
        job = self.app.jobs.create(session_id, user_id, message)
        worker = threading.Thread(
            target=self._run_chat_job,
            args=(job["job_id"], session_id, user_id, message),
            daemon=True,
        )
        worker.start()
        self._send(
            HTTPStatus.OK,
            {
                "job_id": job["job_id"],
                "session_id": session_id,
                "status": "queued",
            },
        )

    def _post_chat_poll(self) -> None:
        """Return a job's status plus the events newer than ``since_seq``."""
        body = self._json_body()
        job_id = self._require(body, "job_id")
        try:
            since_seq = int(body.get("since_seq", 0))
        except (TypeError, ValueError):
            raise _BadRequest("since_seq must be an integer") from None
        job = self.app.jobs.get(job_id)
        if not job:
            self._send(HTTPStatus.NOT_FOUND, {"error": "Unknown job_id"})
            return
        events = [
            event for event in job.get("events", []) if event.get("seq", 0) > since_seq
        ]
        self._send(
            HTTPStatus.OK,
            {
                "job_id": job_id,
                "session_id": job["session_id"],
                "status": job["status"],
                "events": events,
                "answer": job.get("answer"),
                "usage": job.get("usage"),
                "error": job.get("error"),
            },
        )

    def _post_session_get(self) -> None:
        """Return the stored session for ``session_id``."""
        body = self._json_body()
        session_id = self._require(body, "session_id")
        self._send(HTTPStatus.OK, self.app.sessions.load(session_id))

    def _post_session_clear(self) -> None:
        """Clear the stored session for ``session_id``."""
        body = self._json_body()
        session_id = self._require(body, "session_id")
        self.app.sessions.clear(session_id)
        self._send(HTTPStatus.OK, {"ok": True, "session_id": session_id})

    def log_message(self, fmt: str, *args: Any) -> None:
        """Suppress the base class's per-request logging."""
        return
def main() -> None:
    """Load the configuration and serve HTTP requests until interrupted.

    Creates the session and state directories, binds a threading HTTP
    server to ``config.host``/``config.port``, attaches a fresh ``AppState``
    to it and blocks in ``serve_forever``.
    """
    config = ServerConfig.load()
    config.session_dir.mkdir(parents=True, exist_ok=True)
    config.state_dir.mkdir(parents=True, exist_ok=True)
    # The context manager closes the listening socket on every exit path,
    # including a failure while building AppState.
    with ThreadingHTTPServer((config.host, config.port), RequestHandler) as httpd:
        httpd.app_state = AppState(config)  # type: ignore[attr-defined]
        print(f"new-qwen serv listening on http://{config.host}:{config.port}")
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the expected way to stop the server; exit quietly.
            pass


if __name__ == "__main__":
    main()