ducklm/app/api/server.py

171 lines
6.0 KiB
Python

from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from pydantic import BaseModel
class CriticFeedbackRequest(BaseModel):
    """Request body for ``POST /critic/feedback``.

    Only ``feedback`` is required. The endpoint forwards every field, by
    keyword and unchanged, to ``runtime.handle_critic_feedback``.
    """

    feedback: str  # the feedback text itself (required)
    task_id: str | None = None  # presumably the task the feedback is about; confirm with handler
    session_id: str | None = None  # presumably the chat session the feedback is about; confirm
    feedback_type: str | None = None  # free-form string; allowed values not validated here
    severity: str | None = None  # free-form string; allowed values not validated here
    correction: str | None = None  # optional corrected content supplied by the client
    remember: bool = True  # defaults to True; effect decided by handle_critic_feedback
    retry: bool = False  # defaults to False; effect decided by handle_critic_feedback
    assistant_answer: str | None = None  # the answer being critiqued, if the client sends it
    correctness_override: float | None = None  # NOTE(review): score override; valid range not shown here
    usefulness_override: float | None = None  # NOTE(review): score override; valid range not shown here
    safety_override: float | None = None  # NOTE(review): score override; valid range not shown here
from app.core.permission_resolution import PermissionResolutionRequest, SecretResolutionRequest, PasswordResolutionRequest, ReviewResolutionRequest
from app.core.contracts import UserTask
from app.runtime.runtime_controller import RuntimeController
from app.streaming.manager import StreamingManager
def _load_models_at_startup() -> bool:
    """Load the runtime's models, reporting (not raising) any failure.

    Returns:
        True when ``runtime.load_models_at_startup()`` completed, False when it
        raised. On failure the error and its traceback are printed.
    """
    print("Lifespan: Loading models...")
    try:
        runtime.load_models_at_startup()
    except Exception as e:  # best-effort: keep the server up even without models
        import traceback

        print(f"Lifespan: Failed to load models: {e}")
        traceback.print_exc()
        return False
    print("Lifespan: Models loaded")
    return True


def _rebuild_vector_index_if_empty() -> None:
    """Rebuild the memory vector index when it is empty but the store is not.

    Best-effort: any failure is printed with a traceback and reported as an
    index-rebuild problem, separately from model loading.
    """
    try:
        memory = runtime._memory_interface
        if not memory:
            return
        store_count = memory.count()
        if store_count <= 0:
            return
        # NOTE(review): reaches into private attributes of the memory
        # interface; a public "index size" accessor would be sturdier.
        if memory._vector_index.element_count != 0:
            return
        print(f"Lifespan: Rebuilding vector index ({store_count} entries)...")
        memory.reindex()
        print("Lifespan: Vector index rebuilt")
    except Exception as e:  # best-effort: a stale/empty index must not block startup
        import traceback

        print(f"Lifespan: Failed to rebuild vector index: {e}")
        traceback.print_exc()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the shared runtime before the server starts accepting requests.

    Loads models first. Only if that succeeds, it rebuilds the memory vector
    index, and only when the index is empty while the memory store has
    entries. Both steps are best-effort: errors are printed and the server
    still starts.
    """
    print("Lifespan: Starting model loading...")
    if _load_models_at_startup():
        _rebuild_vector_index_if_empty()
    yield  # Server runs here
    print("Lifespan: Shutting down...")
app = FastAPI(title="ducklm", lifespan=lifespan)

# One runtime shared by every endpoint. Its base directory is three levels
# above this file (server.py -> api/ -> app/ -> project root).
runtime = RuntimeController(
    base_dir=Path(__file__).resolve().parent.parent.parent,
)

# Per-task replay/subscription over the runtime's event bus, used by the
# /stream/{task_id} WebSocket.
streaming = StreamingManager(runtime.event_bus)
@app.get("/")
def index() -> FileResponse:
    """Serve ``static/index.html`` from the directory containing this module."""
    static_dir = Path(__file__).resolve().parent.joinpath("static")
    page = static_dir.joinpath("index.html")
    return FileResponse(page)
@app.get("/health")
def health() -> dict[str, str]:
    """Report that the HTTP server is up.

    Always returns ``{"status": "ok"}``; runtime and model state are not
    inspected.
    """
    payload: dict[str, str] = {"status": "ok"}
    return payload
@app.get("/events")
def list_events(limit: int = 500) -> dict[str, object]:
    """Return the most recent runtime events as JSON-ready dicts.

    ``limit`` is clamped into the inclusive range [1, 2000] before it is
    passed to the event bus.
    """
    clamped = min(max(limit, 1), 2000)
    recent = runtime.event_bus.list_recent(limit=clamped)
    serialized = [item.model_dump(mode="json") for item in recent]
    return {"events": serialized}
@app.post("/chat")
def chat(task: UserTask) -> dict[str, object]:
    """Hand a chat task to the runtime.

    Uses ``runtime.submit_task`` when the runtime provides a callable one,
    and otherwise falls back to ``runtime.handle_task``.
    """
    submit_task = getattr(runtime, "submit_task", None)
    if not callable(submit_task):
        return runtime.handle_task(task)
    return submit_task(task)
@app.post("/permissions/resolve")
def resolve_permission(request: PermissionResolutionRequest) -> dict[str, object]:
    """Pass a permission decision for a task to the runtime.

    Prefers ``runtime.submit_permission_resolution`` when it exists and is
    callable; otherwise calls ``runtime.resolve_permission`` with the same
    arguments.
    """
    handler = getattr(runtime, "submit_permission_resolution", None)
    if not callable(handler):
        handler = runtime.resolve_permission
    return handler(task_id=request.task_id, decision=request.decision)
@app.post("/secrets/resolve")
def resolve_secret(request: SecretResolutionRequest) -> dict[str, object]:
    """Pass a client-supplied secret for a task to the runtime.

    Prefers ``runtime.submit_secret_resolution`` when it exists and is
    callable; otherwise calls ``runtime.resolve_secret`` with the same
    arguments.
    """
    handler = getattr(runtime, "submit_secret_resolution", None)
    if not callable(handler):
        handler = runtime.resolve_secret
    return handler(task_id=request.task_id, secret=request.secret)
@app.post("/password/resolve")
def resolve_password(request: PasswordResolutionRequest) -> dict[str, object]:
    """Pass a client-supplied password for a task to the runtime.

    Prefers ``runtime.submit_password_resolution`` when it exists and is
    callable; otherwise calls ``runtime.resolve_password`` with the same
    arguments.
    """
    handler = getattr(runtime, "submit_password_resolution", None)
    if not callable(handler):
        handler = runtime.resolve_password
    return handler(task_id=request.task_id, password=request.password)
@app.post("/review/resolve")
def resolve_review(request: ReviewResolutionRequest) -> dict[str, object]:
    """Pass a review decision (and optional correction) for a task to the runtime.

    Prefers ``runtime.submit_review_resolution`` when it exists and is
    callable; otherwise calls ``runtime.resolve_review`` with the same
    arguments.
    """
    handler = getattr(runtime, "submit_review_resolution", None)
    if not callable(handler):
        handler = runtime.resolve_review
    return handler(
        task_id=request.task_id,
        decision=request.decision,
        correction=request.correction,
    )
@app.post("/critic/feedback")
def critic_feedback(request: CriticFeedbackRequest) -> dict[str, object]:
    """Forward a critic-feedback request to ``runtime.handle_critic_feedback``.

    Every request field is passed through by keyword, unchanged, and the
    handler's result is returned as the response body.
    """
    return runtime.handle_critic_feedback(
        # What the feedback says and what it refers to.
        feedback=request.feedback,
        task_id=request.task_id,
        session_id=request.session_id,
        assistant_answer=request.assistant_answer,
        # Classification and optional correction.
        feedback_type=request.feedback_type,
        severity=request.severity,
        correction=request.correction,
        # Behaviour flags.
        remember=request.remember,
        retry=request.retry,
        # Optional score overrides.
        correctness_override=request.correctness_override,
        usefulness_override=request.usefulness_override,
        safety_override=request.safety_override,
    )
@app.websocket("/stream/{task_id}")
async def stream_task(websocket: WebSocket, task_id: str) -> None:
    """Stream the events of one task to a WebSocket client.

    Phase 1: the task's recorded events are replayed. If the last replayed
    event is ``task_completed`` or ``task_failed``, the socket is closed
    immediately.

    Phase 2: otherwise, the handler subscribes to live events and forwards
    each one. After 30 s with no event it sends a
    ``{"type": "heartbeat", "task_id": ...}`` message. The stream ends when
    a completed/failed event arrives or when the task starts awaiting
    permission, input or review.

    If the client disconnects at any point, the live subscription (if any)
    is released and ``close()`` is skipped, because closing an
    already-disconnected socket raises.
    """
    finished_types = {"task_completed", "task_failed"}
    stop_types = finished_types | {
        "task_awaiting_permission",
        "task_awaiting_input",
        "task_awaiting_review",
    }

    await websocket.accept()
    queue = None
    client_disconnected = False
    try:
        replayed_events = streaming.replay_events(task_id)
        for event in replayed_events:
            await websocket.send_json(event.model_dump(mode="json"))
        if replayed_events and replayed_events[-1].type in finished_types:
            return  # the finally block closes the socket

        # NOTE(review): an event published between replay_events() and
        # subscribe() may be neither replayed nor delivered; confirm whether
        # StreamingManager guards against this window.
        queue = streaming.subscribe(task_id)
        while True:
            try:
                event = await asyncio.wait_for(queue.get(), timeout=30)
            except asyncio.TimeoutError:
                await websocket.send_json({"type": "heartbeat", "task_id": task_id})
                continue
            await websocket.send_json(event.model_dump(mode="json"))
            if event.type in stop_types:
                break
    except WebSocketDisconnect:
        client_disconnected = True
    finally:
        if queue is not None:
            streaming.unsubscribe(task_id, queue)
        if not client_disconnected:
            await websocket.close()