ducklm/app/api/server.py

142 lines
4.5 KiB
Python

from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from pydantic import BaseModel
class CriticFeedbackRequest(BaseModel):
    """Request body for ``POST /critic/feedback``.

    Every field is forwarded unchanged, as a keyword argument of the same
    name, to ``runtime.handle_critic_feedback``. Only ``feedback`` is
    required; every other field has a default.
    """

    # Free-text feedback (the only required field).
    feedback: str
    # Optional identifiers linking the feedback to a task / session.
    task_id: str | None = None
    session_id: str | None = None
    # Optional free-form classification; values are not validated here —
    # NOTE(review): confirm the set the runtime actually accepts.
    feedback_type: str | None = None
    severity: str | None = None
    # Optional correction text; presumably the user's corrected answer — verify.
    correction: str | None = None
    # Flags passed straight through to the runtime; presumably "persist this
    # feedback" and "re-run the task" — verify in handle_critic_feedback.
    remember: bool = True
    retry: bool = False
    # Optional copy of the assistant answer being critiqued.
    assistant_answer: str | None = None
    # Optional score overrides; scale/range is not validated here
    # (TODO confirm the expected range with the critic implementation).
    correctness_override: float | None = None
    usefulness_override: float | None = None
    safety_override: float | None = None
from app.core.permission_resolution import PermissionResolutionRequest, SecretResolutionRequest, PasswordResolutionRequest
from app.core.contracts import UserTask
from app.runtime.runtime_controller import RuntimeController
from app.streaming.manager import StreamingManager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup.

    The synchronous model loader is run in the event loop's default executor,
    so it does not execute on the event-loop thread. A loading failure is
    reported (message on stdout, traceback on stderr) but is swallowed, so the
    server still starts.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    print("Lifespan: Starting model loading...")
    # get_running_loop() is the preferred way to obtain the loop from inside a
    # coroutine (get_event_loop() is the legacy API for this use).
    loop = asyncio.get_running_loop()

    def load_models():
        try:
            print("Lifespan: Loading models...")
            runtime.load_models_at_startup()
            print("Lifespan: Models loaded")
        except Exception as e:
            # Deliberate best-effort: keep serving even if loading fails.
            print(f"Lifespan: Failed to load models: {e}")
            import traceback
            traceback.print_exc()

    await loop.run_in_executor(None, load_models)
    try:
        yield  # Server runs here
    finally:
        # Runs on normal shutdown and when an error is thrown into the lifespan.
        print("Lifespan: Shutting down...")
# ASGI application. `lifespan` awaits model loading before yielding, so startup
# completes only after the load attempt has finished (success or failure).
app = FastAPI(title="ducklm", lifespan=lifespan)
# Runtime rooted at parents[2] of this file's resolved path — presumably the
# project root that contains the `app/` package (verify if this file moves).
# Defined after `app` on purpose-free ordering: `lifespan` only looks `runtime`
# up when it runs at startup, after this module has finished importing.
runtime = RuntimeController(base_dir=Path(__file__).resolve().parents[2])
# Streaming manager fed from the runtime's event bus; backs /stream/{task_id}.
streaming = StreamingManager(runtime.event_bus)
@app.get("/")
def index() -> FileResponse:
    """Serve ``static/index.html`` from the directory containing this module."""
    static_dir = Path(__file__).resolve().parent / "static"
    return FileResponse(static_dir / "index.html")
@app.get("/health")
def health() -> dict[str, str]:
    """Report that the server is up; the payload is always ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
@app.get("/events")
def list_events(limit: int = 500) -> dict[str, object]:
    """Return recent event-bus events serialized in pydantic JSON mode.

    ``limit`` is clamped into the range [1, 2000] instead of being rejected.
    """
    bounded = min(max(limit, 1), 2000)
    recent = runtime.event_bus.list_recent(limit=bounded)
    serialized = [item.model_dump(mode="json") for item in recent]
    return {"events": serialized}
@app.post("/chat")
def chat(task: UserTask) -> dict[str, object]:
    """Pass a user task to the runtime and return its result unchanged."""
    result = runtime.handle_task(task)
    return result
@app.post("/permissions/resolve")
def resolve_permission(request: PermissionResolutionRequest) -> dict[str, object]:
    """Forward a permission decision for a task to the runtime."""
    task_id, decision = request.task_id, request.decision
    return runtime.resolve_permission(task_id=task_id, decision=decision)
@app.post("/secrets/resolve")
def resolve_secret(request: SecretResolutionRequest) -> dict[str, object]:
    """Forward a user-supplied secret for a task to the runtime."""
    task_id, secret = request.task_id, request.secret
    return runtime.resolve_secret(task_id=task_id, secret=secret)
@app.post("/password/resolve")
def resolve_password(request: PasswordResolutionRequest) -> dict[str, object]:
    """Forward a user-supplied password for a task to the runtime."""
    task_id, password = request.task_id, request.password
    return runtime.resolve_password(task_id=task_id, password=password)
@app.post("/critic/feedback")
def critic_feedback(request: CriticFeedbackRequest) -> dict[str, object]:
    """Relay critic feedback to the runtime, field for field, and return its reply."""
    return runtime.handle_critic_feedback(
        feedback=request.feedback,
        task_id=request.task_id,
        session_id=request.session_id,
        feedback_type=request.feedback_type,
        severity=request.severity,
        correction=request.correction,
        remember=request.remember,
        retry=request.retry,
        assistant_answer=request.assistant_answer,
        correctness_override=request.correctness_override,
        usefulness_override=request.usefulness_override,
        safety_override=request.safety_override,
    )
@app.websocket("/stream/{task_id}")
async def stream_task(websocket: WebSocket, task_id: str) -> None:
    """Stream a task's events to a WebSocket client.

    Steps:
    1. Replay the events already recorded for ``task_id``. If the last replayed
       event is ``task_completed`` or ``task_failed``, close immediately.
    2. Otherwise, subscribe to live events and forward them.
    3. Stop on a terminal or "awaiting" event, after 15 seconds with no event,
       or when the client disconnects.

    The socket is always closed and the subscription always released on exit.
    """
    await websocket.accept()
    queue = None
    try:
        replayed_events = streaming.replay_events(task_id)
        for event in replayed_events:
            await websocket.send_json(event.model_dump(mode="json"))
        if replayed_events and replayed_events[-1].type in {"task_completed", "task_failed"}:
            return
        # NOTE(review): events published between replay_events() and subscribe()
        # may be missed by this client; confirm whether StreamingManager covers
        # that window.
        queue = streaming.subscribe(task_id)
        while True:
            event = await asyncio.wait_for(queue.get(), timeout=15)
            await websocket.send_json(event.model_dump(mode="json"))
            if event.type in {"task_completed", "task_failed", "task_awaiting_permission", "task_awaiting_input"}:
                break
    except (asyncio.TimeoutError, WebSocketDisconnect):
        # Idle timeout or client went away (also during replay): just clean up.
        pass
    finally:
        if queue is not None:
            streaming.unsubscribe(task_id, queue)
        try:
            await websocket.close()
        except (RuntimeError, WebSocketDisconnect):
            # Starlette refuses to send a close frame on an already-disconnected
            # socket; nothing left to do in that case.
            pass