119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import FileResponse
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class CriticFeedbackRequest(BaseModel):
    """Request body accepted by the ``POST /critic/feedback`` endpoint.

    Only ``feedback`` is required; every other field defaults to ``None``.
    The endpoint passes all six fields through unchanged to
    ``runtime.handle_critic_feedback``.
    """

    # Feedback text supplied by the caller (required).
    feedback: str

    # Optional identifiers.
    # NOTE(review): presumably these select the task/session the feedback
    # applies to -- confirm against RuntimeController.
    task_id: str | None = None
    session_id: str | None = None

    # Optional per-dimension score overrides; ``None`` means "not supplied".
    # NOTE(review): the valid numeric range is not visible in this file.
    correctness_override: float | None = None
    usefulness_override: float | None = None
    safety_override: float | None = None
|
|
|
|
from app.core.permission_resolution import PermissionResolutionRequest, SecretResolutionRequest, PasswordResolutionRequest
|
|
from app.core.contracts import UserTask
|
|
from app.runtime.runtime_controller import RuntimeController
|
|
from app.streaming.manager import StreamingManager
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models before the server starts serving; log on shutdown.

    The blocking ``runtime.load_models_at_startup()`` call is run in the
    event loop's default executor so the loop itself is not blocked while
    models load. Startup is best-effort: if loading raises, the error and
    traceback are printed and the server still starts.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control is handed to the server for its whole lifetime.
    """
    print("Lifespan: Starting model loading...")

    def load_models():
        """Run the blocking model load, reporting (not raising) failures."""
        try:
            print("Lifespan: Loading models...")
            # ``runtime`` is a module-level global resolved at call time.
            runtime.load_models_at_startup()
            print("Lifespan: Models loaded")
        except Exception as e:
            # Deliberate best-effort: a failed load must not abort startup.
            print(f"Lifespan: Failed to load models: {e}")
            import traceback

            traceback.print_exc()

    # Inside a coroutine the running loop is the one we want;
    # get_running_loop() is the documented way to obtain it, whereas
    # get_event_loop() is discouraged in this context.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, load_models)

    yield  # Server runs here

    print("Lifespan: Shutting down...")
|
|
|
|
|
|
# ASGI application; ``lifespan`` runs the startup/shutdown hooks.
app = FastAPI(
    title="ducklm",
    lifespan=lifespan,
)

# Shared runtime controller, rooted two directories above this file's folder.
runtime = RuntimeController(
    base_dir=Path(__file__).resolve().parents[2],
)

# Streaming manager built on the runtime's event bus; used by the
# ``/stream/{task_id}`` websocket endpoint.
streaming = StreamingManager(runtime.event_bus)
|
|
|
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve ``static/index.html`` from the directory containing this module."""
    static_dir = Path(__file__).resolve().parent / "static"
    return FileResponse(static_dir / "index.html")
|
|
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Return a constant payload whose ``status`` key is ``ok``."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@app.post("/chat")
def chat(task: UserTask) -> dict[str, object]:
    """Pass the submitted task to ``runtime.handle_task`` and return its result."""
    result = runtime.handle_task(task)
    return result
|
|
|
|
|
|
@app.post("/permissions/resolve")
def resolve_permission(request: PermissionResolutionRequest) -> dict[str, object]:
    """Forward a permission decision for a task to the runtime.

    Returns whatever ``runtime.resolve_permission`` returns.
    """
    task_id = request.task_id
    decision = request.decision
    return runtime.resolve_permission(task_id=task_id, decision=decision)
|
|
|
|
|
|
@app.post("/secrets/resolve")
def resolve_secret(request: SecretResolutionRequest) -> dict[str, object]:
    """Forward a secret supplied for a task to the runtime.

    Returns whatever ``runtime.resolve_secret`` returns.
    """
    task_id = request.task_id
    secret = request.secret
    return runtime.resolve_secret(task_id=task_id, secret=secret)
|
|
|
|
|
|
@app.post("/password/resolve")
def resolve_password(request: PasswordResolutionRequest) -> dict[str, object]:
    """Forward a password supplied for a task to the runtime.

    Returns whatever ``runtime.resolve_password`` returns.
    """
    task_id = request.task_id
    password = request.password
    return runtime.resolve_password(task_id=task_id, password=password)
|
|
|
|
|
|
@app.post("/critic/feedback")
def critic_feedback(request: CriticFeedbackRequest) -> dict[str, object]:
    """Relay critic feedback to the runtime and return its response.

    Every field of the request body is passed through by keyword to
    ``runtime.handle_critic_feedback``; nothing is transformed here.
    """
    return runtime.handle_critic_feedback(
        feedback=request.feedback,
        task_id=request.task_id,
        session_id=request.session_id,
        correctness_override=request.correctness_override,
        usefulness_override=request.usefulness_override,
        safety_override=request.safety_override,
    )
|
|
|
|
|
|
@app.websocket("/stream/{task_id}")
async def stream_task(websocket: WebSocket, task_id: str) -> None:
    """Stream a task's events to a websocket client.

    Flow:
      1. Accept the connection and send every event returned by
         ``streaming.replay_events(task_id)``.
      2. If the last replayed event is ``task_completed`` or ``task_failed``,
         close immediately without subscribing.
      3. Otherwise subscribe and forward live events until one of
         ``task_completed``, ``task_failed``, ``task_awaiting_permission`` or
         ``task_awaiting_input`` is sent, no event arrives for 15 seconds,
         or the client disconnects.

    The subscription (if one was made) is always released and the socket is
    always closed on the way out.

    Args:
        websocket: The client connection.
        task_id: Identifier of the task whose events are streamed.
    """
    await websocket.accept()

    # Stays None until we actually subscribe, so cleanup knows what to undo.
    queue = None
    try:
        # NOTE(review): events published between replay_events() and
        # subscribe() could be missed unless StreamingManager covers that
        # window -- confirm.
        replayed_events = streaming.replay_events(task_id)
        for event in replayed_events:
            await websocket.send_json(event.model_dump(mode="json"))
        if replayed_events and replayed_events[-1].type in {"task_completed", "task_failed"}:
            return

        queue = streaming.subscribe(task_id)
        while True:
            # Give up if the task stays silent for 15 seconds.
            event = await asyncio.wait_for(queue.get(), timeout=15)
            await websocket.send_json(event.model_dump(mode="json"))
            if event.type in {"task_completed", "task_failed", "task_awaiting_permission", "task_awaiting_input"}:
                break
    except (asyncio.TimeoutError, WebSocketDisconnect):
        # Idle timeout or the client went away (during replay or live
        # streaming): end the stream quietly.
        pass
    finally:
        if queue is not None:
            streaming.unsubscribe(task_id, queue)
        try:
            await websocket.close()
        except (RuntimeError, WebSocketDisconnect):
            # The peer is already gone / the socket is already closed;
            # there is nothing left to close, so do not mask the exit path.
            pass
|