from __future__ import annotations import asyncio from contextlib import asynccontextmanager from pathlib import Path from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse from pydantic import BaseModel class CriticFeedbackRequest(BaseModel): feedback: str task_id: str | None = None session_id: str | None = None feedback_type: str | None = None severity: str | None = None correction: str | None = None remember: bool = True retry: bool = False assistant_answer: str | None = None correctness_override: float | None = None usefulness_override: float | None = None safety_override: float | None = None from app.core.permission_resolution import PermissionResolutionRequest, SecretResolutionRequest, PasswordResolutionRequest, ReviewResolutionRequest from app.core.contracts import UserTask from app.runtime.runtime_controller import RuntimeController from app.streaming.manager import StreamingManager @asynccontextmanager async def lifespan(app: FastAPI): """Load models on startup.""" print("Lifespan: Starting model loading...") try: print("Lifespan: Loading models...") runtime.load_models_at_startup() print("Lifespan: Models loaded") # Rebuild vector index if empty but memory store has data. if runtime._memory_interface: store_count = runtime._memory_interface.count() if store_count > 0: idx_count = runtime._memory_interface._vector_index.element_count if idx_count == 0: print(f"Lifespan: Rebuilding vector index ({store_count} entries)...") runtime._memory_interface.reindex() print("Lifespan: Vector index rebuilt") except Exception as e: print(f"Lifespan: Failed to load models: {e}") import traceback traceback.print_exc() yield # Server runs here print("Lifespan: Shutting down...") app = FastAPI(title="ducklm", lifespan=lifespan) runtime = RuntimeController(base_dir=Path(__file__).resolve().parents[2]) streaming = StreamingManager(runtime.event_bus) @app.get("/") def index() -> FileResponse: return FileResponse(Path(__file__).resolve().parent / "static" / "index.html") @app.get("/health") def health() -> dict[str, str]: return {"status": "ok"} @app.get("/events") def list_events(limit: int = 500) -> dict[str, object]: safe_limit = max(1, min(limit, 2000)) return { "events": [ event.model_dump(mode="json") for event in runtime.event_bus.list_recent(limit=safe_limit) ] } @app.post("/chat") def chat(task: UserTask) -> dict[str, object]: submit = getattr(runtime, "submit_task", None) if callable(submit): return submit(task) return runtime.handle_task(task) @app.post("/permissions/resolve") def resolve_permission(request: PermissionResolutionRequest) -> dict[str, object]: submit = getattr(runtime, "submit_permission_resolution", None) if callable(submit): return submit(task_id=request.task_id, decision=request.decision) return runtime.resolve_permission(task_id=request.task_id, decision=request.decision) @app.post("/secrets/resolve") def resolve_secret(request: SecretResolutionRequest) -> dict[str, object]: submit = getattr(runtime, "submit_secret_resolution", None) if callable(submit): return submit(task_id=request.task_id, secret=request.secret) return runtime.resolve_secret(task_id=request.task_id, secret=request.secret) @app.post("/password/resolve") def resolve_password(request: PasswordResolutionRequest) -> dict[str, object]: submit = getattr(runtime, "submit_password_resolution", None) if callable(submit): return submit(task_id=request.task_id, password=request.password) return runtime.resolve_password(task_id=request.task_id, password=request.password) @app.post("/review/resolve") def resolve_review(request: ReviewResolutionRequest) -> dict[str, object]: submit = getattr(runtime, "submit_review_resolution", None) if callable(submit): return submit(task_id=request.task_id, decision=request.decision, correction=request.correction) return runtime.resolve_review(task_id=request.task_id, decision=request.decision, correction=request.correction) @app.post("/critic/feedback") def critic_feedback(request: CriticFeedbackRequest) -> dict[str, object]: feedback = runtime.handle_critic_feedback( feedback=request.feedback, task_id=request.task_id, session_id=request.session_id, feedback_type=request.feedback_type, severity=request.severity, correction=request.correction, remember=request.remember, retry=request.retry, assistant_answer=request.assistant_answer, correctness_override=request.correctness_override, usefulness_override=request.usefulness_override, safety_override=request.safety_override, ) return feedback @app.websocket("/stream/{task_id}") async def stream_task(websocket: WebSocket, task_id: str) -> None: await websocket.accept() replayed_events = streaming.replay_events(task_id) for event in replayed_events: await websocket.send_json(event.model_dump(mode="json")) if replayed_events and replayed_events[-1].type in {"task_completed", "task_failed"}: await websocket.close() return queue = streaming.subscribe(task_id) try: while True: try: event = await asyncio.wait_for(queue.get(), timeout=30) except asyncio.TimeoutError: await websocket.send_json({"type": "heartbeat", "task_id": task_id}) continue await websocket.send_json(event.model_dump(mode="json")) if event.type in {"task_completed", "task_failed", "task_awaiting_permission", "task_awaiting_input", "task_awaiting_review"}: break except WebSocketDisconnect: pass finally: streaming.unsubscribe(task_id, queue) await websocket.close()