176 lines
6.7 KiB
Python
176 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from http import HTTPStatus
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from typing import Any
|
|
|
|
from config import ServerConfig
|
|
from llm import QwenAgent
|
|
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
|
|
from sessions import SessionStore
|
|
from tools import ToolRegistry
|
|
|
|
|
|
class AppState:
    """Services shared by every request-handling thread of the server.

    One instance is attached to the HTTP server as ``app_state`` and reached
    from handlers through ``RequestHandler.app``.
    """

    def __init__(self, config: ServerConfig) -> None:
        self.config = config
        oauth_manager = QwenOAuthManager()
        self.oauth = oauth_manager
        self.sessions = SessionStore(config.session_dir)
        tool_registry = ToolRegistry(config)
        self.tools = tool_registry
        self.agent = QwenAgent(config, oauth_manager, tool_registry)
        # Device-authorization flows that are still in progress, keyed by the
        # flow_id handed to the client. Handlers run on ThreadingHTTPServer
        # worker threads, so they access this dict only while holding `lock`.
        self.pending_device_flows: dict[str, DeviceAuthState] = {}
        self.lock = threading.Lock()

    def auth_status(self) -> dict[str, Any]:
        """Summarise the stored OAuth credentials.

        Returns ``{"authenticated": False}`` when ``load_credentials`` yields
        nothing (or an empty value). Otherwise it returns ``authenticated``
        set to True, plus the stored ``resource_url`` and ``expiry_date``
        (the latter exposed as ``expires_at``); each is ``None`` when missing.
        """
        creds = self.oauth.load_credentials()
        has_creds = bool(creds)
        status: dict[str, Any] = {"authenticated": has_creds}
        if has_creds:
            status["resource_url"] = creds.get("resource_url")
            status["expires_at"] = creds.get("expiry_date")
        return status
|
|
|
|
|
|
class RequestHandler(BaseHTTPRequestHandler):
    """JSON API handler backed by the ``app_state`` attached to the server.

    Every response body is a JSON object. Status codes:
      - 400 for client mistakes: a bad ``Content-Length``, a body that is
        not UTF-8 JSON, a non-object body, or a missing required field;
      - 502 for ``OAuthError``;
      - 500 for any other exception;
      - 404 for unknown paths.
    Query strings are ignored when matching routes.
    """

    server_version = "new-qwen-serv/0.1"

    class _BadRequest(Exception):
        """Malformed client input; reported to the client as HTTP 400."""

    # URL path -> name of the handler method. Names are resolved with getattr
    # at dispatch time, so the tables can sit before the methods they name.
    _GET_ROUTES = {
        "/health": "_get_health",
        "/api/v1/auth/status": "_get_auth_status",
        "/api/v1/sessions": "_get_sessions",
    }
    _POST_ROUTES = {
        "/api/v1/auth/device/start": "_post_device_start",
        "/api/v1/auth/device/poll": "_post_device_poll",
        "/api/v1/chat": "_post_chat",
        "/api/v1/session/get": "_post_session_get",
        "/api/v1/session/clear": "_post_session_clear",
    }

    # ------------------------------------------------------------------ I/O

    def _json_body(self) -> dict[str, Any]:
        """Read and decode the request body as a JSON object.

        An absent or non-positive ``Content-Length`` (or an empty read)
        yields ``{}``.

        Raises:
            _BadRequest: the header is not an integer, the body is not valid
                UTF-8 JSON, or the decoded value is not a JSON object.
        """
        raw_length = self.headers.get("Content-Length", "0")
        try:
            length = int(raw_length)
        except ValueError:
            raise self._BadRequest(f"Invalid Content-Length header: {raw_length!r}") from None
        if length <= 0:
            return {}
        raw = self.rfile.read(length)
        try:
            text = raw.decode("utf-8")
            payload = json.loads(text) if text else {}
        except ValueError as exc:  # UnicodeDecodeError and JSONDecodeError
            raise self._BadRequest(f"Invalid JSON body: {exc}") from None
        # Handlers call .get()/[] on the body, so a list/str/number must be
        # rejected here rather than crash later as a 500.
        if not isinstance(payload, dict):
            raise self._BadRequest("Request body must be a JSON object")
        return payload

    def _require(self, body: dict[str, Any], field: str) -> Any:
        """Return ``body[field]``, or raise ``_BadRequest`` naming the field."""
        try:
            return body[field]
        except KeyError:
            raise self._BadRequest(f"Missing field: {field!r}") from None

    def _send(self, status: int, payload: dict[str, Any]) -> None:
        """Write a complete JSON response (UTF-8, non-ASCII kept as-is)."""
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    @property
    def app(self) -> AppState:
        """Shared state that main() attaches to the server as ``app_state``."""
        return self.server.app_state  # type: ignore[attr-defined]

    # -------------------------------------------------------------- routing

    def _route_path(self) -> str:
        """Return the request path with any query string removed."""
        return self.path.partition("?")[0]

    def do_GET(self) -> None:
        """Serve the read-only GET endpoints."""
        self._dispatch(self._GET_ROUTES)

    def do_POST(self) -> None:
        """Serve the POST endpoints."""
        self._dispatch(self._POST_ROUTES)

    def _dispatch(self, routes: dict[str, str]) -> None:
        """Run the handler for the current path and map errors to JSON responses."""
        handler_name = routes.get(self._route_path())
        if handler_name is None:
            self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
            return
        try:
            getattr(self, handler_name)()
        except self._BadRequest as exc:
            self._send(HTTPStatus.BAD_REQUEST, {"error": str(exc)})
        except OAuthError as exc:
            self._send(HTTPStatus.BAD_GATEWAY, {"error": str(exc)})
        except Exception as exc:
            self._send(HTTPStatus.INTERNAL_SERVER_ERROR, {"error": str(exc)})

    # ------------------------------------------------------------ GET routes

    def _get_health(self) -> None:
        """Liveness probe: report ok and the current Unix time in seconds."""
        self._send(HTTPStatus.OK, {"ok": True, "time": int(time.time())})

    def _get_auth_status(self) -> None:
        """Report whether OAuth credentials are stored (see AppState.auth_status)."""
        self._send(HTTPStatus.OK, self.app.auth_status())

    def _get_sessions(self) -> None:
        """List the sessions known to the session store."""
        self._send(HTTPStatus.OK, {"sessions": self.app.sessions.list_sessions()})

    # ----------------------------------------------------------- POST routes

    def _post_device_start(self) -> None:
        """Start an OAuth device flow and return the code the user must enter."""
        flow_id = uuid.uuid4().hex
        state = self.app.oauth.start_device_flow(open_browser=False)
        with self.app.lock:
            self.app.pending_device_flows[flow_id] = state
        self._send(
            HTTPStatus.OK,
            {
                "flow_id": flow_id,
                "user_code": state.user_code,
                "verification_uri": state.verification_uri,
                "verification_uri_complete": state.verification_uri_complete,
                "expires_at": state.expires_at,
                "interval_seconds": state.interval_seconds,
            },
        )

    def _post_device_poll(self) -> None:
        """Poll a pending device flow once.

        Replies ``done: False`` while authorization is pending. Once
        credentials arrive it forgets the flow and replies ``done: True``.
        """
        body = self._json_body()
        flow_id = self._require(body, "flow_id")
        # Flow ids are the uuid4 hex strings minted by _post_device_start, so
        # no other type can match. Rejecting them here also stops unhashable
        # values (lists, objects) from surfacing as a 500 on the dict lookup.
        if not isinstance(flow_id, str):
            raise self._BadRequest("Field 'flow_id' must be a string")
        with self.app.lock:
            state = self.app.pending_device_flows.get(flow_id)
        if not state:
            self._send(HTTPStatus.NOT_FOUND, {"error": "Unknown flow_id"})
            return
        # NOTE(review): a flow is removed only on success. Flows that expire
        # or make poll_device_flow raise stay in pending_device_flows forever.
        creds = self.app.oauth.poll_device_flow(state)
        if creds is None:
            self._send(HTTPStatus.OK, {"done": False, "interval_seconds": state.interval_seconds})
            return
        with self.app.lock:
            self.app.pending_device_flows.pop(flow_id, None)
        self._send(HTTPStatus.OK, {"done": True, "credentials": {"resource_url": creds.get("resource_url")}})

    def _post_chat(self) -> None:
        """Run one agent turn for a session and persist the updated history."""
        body = self._json_body()
        # A new session id is minted when the client does not continue one.
        session_id = body.get("session_id") or uuid.uuid4().hex
        user_id = str(body.get("user_id") or "anonymous")
        message = self._require(body, "message")
        # NOTE(review): session_id is client-controlled and passed straight to
        # SessionStore. If the store maps ids to file names, confirm it rejects
        # values such as "../x".
        session = self.app.sessions.load(session_id)
        history = session.get("messages", [])
        result = self.app.agent.run(history, message)
        # The first returned message is not persisted. Presumably it is a
        # prompt the agent prepends on every run -- TODO confirm against
        # QwenAgent.run.
        persisted_messages = result["messages"][1:]
        self.app.sessions.save(
            session_id,
            {
                "session_id": session_id,
                "user_id": user_id,
                "updated_at": int(time.time()),
                "messages": persisted_messages,
                "last_answer": result["answer"],
            },
        )
        self._send(
            HTTPStatus.OK,
            {
                "session_id": session_id,
                "answer": result["answer"],
                "events": result["events"],
                "usage": result.get("usage"),
            },
        )

    def _post_session_get(self) -> None:
        """Return the stored record for ``session_id``."""
        body = self._json_body()
        session_id = self._require(body, "session_id")
        self._send(HTTPStatus.OK, self.app.sessions.load(session_id))

    def _post_session_clear(self) -> None:
        """Clear the stored record for ``session_id``."""
        body = self._json_body()
        session_id = self._require(body, "session_id")
        self.app.sessions.clear(session_id)
        self._send(HTTPStatus.OK, {"ok": True, "session_id": session_id})

    def log_message(self, fmt: str, *args: Any) -> None:
        """Silence BaseHTTPRequestHandler's default request/error logging to stderr."""
        return
|
|
|
|
|
|
def main() -> None:
    """Load the configuration and serve the JSON API until interrupted.

    The function:
      - creates the session directory if needed;
      - binds a ThreadingHTTPServer on the configured host and port;
      - attaches a shared AppState as ``app_state`` (read by
        ``RequestHandler.app``);
      - serves requests until stopped.

    Ctrl-C (KeyboardInterrupt) ends the loop without a traceback.
    """
    config = ServerConfig.load()
    config.session_dir.mkdir(parents=True, exist_ok=True)
    # The server's context manager calls server_close() on exit, so the
    # listening socket is released even if AppState() or serve_forever() raises.
    with ThreadingHTTPServer((config.host, config.port), RequestHandler) as httpd:
        httpd.app_state = AppState(config)  # type: ignore[attr-defined]
        # flush so the banner shows up immediately even when stdout is piped.
        print(f"new-qwen serv listening on http://{config.host}:{config.port}", flush=True)
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the server; exit cleanly.
            print("new-qwen serv shutting down", flush=True)


if __name__ == "__main__":
    main()
|