# ducklm/duck_core/conversations/store.py

from __future__ import annotations
from pathlib import Path
from typing import Any
from uuid import uuid4
import aiosqlite
from pydantic import BaseModel
from duck_core.tasks.store import utc_now
class Conversation(BaseModel):
    """A conversation record, mirroring one row of the ``conversations`` table."""

    # Autoincrement row id assigned by the database; None when not set.
    id: int | None = None
    # Public identifier; ConversationStore.create generates "chat_" + 12 hex chars.
    conversation_id: str
    # Human-readable title supplied by the caller.
    title: str
    # Workspace the conversation belongs to, stored as an opaque string.
    workspace: str
    # Creation timestamp as returned by utc_now(); exact format is defined elsewhere.
    created_at: str
    # Last-activity timestamp (utc_now()); ConversationStore.add_message bumps it.
    updated_at: str
class ConversationMessage(BaseModel):
    """A single message, mirroring one row of the ``conversation_messages`` table."""

    # Autoincrement row id; ConversationStore orders a conversation's messages by it.
    id: int | None = None
    # conversation_id of the owning Conversation (no foreign key enforces this).
    conversation_id: str
    # Speaker role as a free-form string, not validated here
    # (presumably values like "user"/"assistant" — confirm with callers).
    role: str
    # Message text.
    content: str
    # Optional separate reasoning text; presumably model reasoning output — confirm.
    reasoning_content: str | None = None
    # Optional id of an associated task; the store can map a task back to its conversation.
    task_id: str | None = None
    # Optional free-form status string, not validated here.
    status: str | None = None
    # Insert timestamp as returned by utc_now().
    created_at: str
class ConversationStore:
    """Async SQLite persistence for conversations and their messages.

    Each public method opens its own short-lived ``aiosqlite`` connection and
    calls :meth:`init` first, so the schema is created lazily on first use and
    no connection is kept open between calls.
    """

    def __init__(self, db_path: str):
        """Remember the database location; no I/O is performed here.

        Args:
            db_path: Filesystem path of the SQLite database file. Missing
                parent directories are created by :meth:`init`.
        """
        self.db_path = Path(db_path)

    async def init(self) -> None:
        """Create the parent directory, tables and indexes if they are missing.

        Every statement is ``if not exists``, so calling this repeatedly is
        harmless; the other methods rely on that and call it on every use.
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                """
                create table if not exists conversations (
                    id integer primary key autoincrement,
                    conversation_id text not null unique,
                    title text not null,
                    workspace text not null,
                    created_at text not null,
                    updated_at text not null
                )
                """
            )
            await db.execute(
                """
                create table if not exists conversation_messages (
                    id integer primary key autoincrement,
                    conversation_id text not null,
                    role text not null,
                    content text not null,
                    reasoning_content text,
                    task_id text,
                    status text,
                    created_at text not null
                )
                """
            )
            # This store looks messages up only by conversation_id or task_id.
            # Without these indexes, list_messages() and
            # get_conversation_id_for_task() scan the whole table on every call.
            # ``id`` aliases the rowid, and SQLite index entries carry the
            # rowid, so the ``order by id`` clauses can use the same indexes.
            await db.execute(
                "create index if not exists idx_conversation_messages_conversation_id "
                "on conversation_messages(conversation_id)"
            )
            await db.execute(
                "create index if not exists idx_conversation_messages_task_id "
                "on conversation_messages(task_id)"
            )
            await db.commit()

    async def create(self, title: str, workspace: str) -> Conversation:
        """Insert a new conversation with a freshly generated id.

        Args:
            title: Human-readable title to store.
            workspace: Workspace the conversation belongs to.

        Returns:
            The persisted :class:`Conversation`. ``created_at`` and
            ``updated_at`` are both set to the same ``utc_now()`` value.
        """
        await self.init()
        now = utc_now()
        # The first 12 hex chars of a uuid4 are random. The unique constraint
        # on conversation_id makes a collision fail loudly instead of merging.
        conversation_id = f"chat_{uuid4().hex[:12]}"
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                insert into conversations(conversation_id, title, workspace, created_at, updated_at)
                values (?, ?, ?, ?, ?)
                """,
                (conversation_id, title, workspace, now, now),
            )
            await db.commit()
            row_id = cursor.lastrowid
        return Conversation(
            id=row_id,
            conversation_id=conversation_id,
            title=title,
            workspace=workspace,
            created_at=now,
            updated_at=now,
        )

    async def ensure(
        self, conversation_id: str | None, title: str, workspace: str
    ) -> Conversation:
        """Return an existing conversation, or create one.

        Args:
            conversation_id: Id to look up. Falsy values skip the lookup.
            title: Title used only if a new conversation is created.
            workspace: Workspace used only if a new conversation is created.

        Returns:
            The stored conversation with ``conversation_id`` if it exists.
            Otherwise a newly created conversation. Note that the new one gets
            a freshly generated id, not the requested ``conversation_id``.
        """
        if conversation_id:
            existing = await self.get(conversation_id)
            if existing is not None:
                return existing
        return await self.create(title, workspace)

    async def get(self, conversation_id: str) -> Conversation | None:
        """Return the conversation with ``conversation_id``, or None if absent."""
        await self.init()
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                "select * from conversations where conversation_id = ?", (conversation_id,)
            )
            row = await cursor.fetchone()
        return self._row_to_conversation(row) if row else None

    async def list(self, limit: int = 50) -> list[Conversation]:
        """Return up to ``limit`` conversations, most recently updated first.

        ``updated_at`` is compared as text, which assumes utc_now() produces
        lexicographically sortable strings (e.g. ISO-8601) — confirm.
        ``id desc`` breaks ties between equal timestamps so that the order
        is deterministic.
        """
        await self.init()
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                "select * from conversations order by updated_at desc, id desc limit ?",
                (limit,),
            )
            rows = await cursor.fetchall()
        return [self._row_to_conversation(row) for row in rows]

    async def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        reasoning_content: str | None = None,
        task_id: str | None = None,
        status: str | None = None,
    ) -> ConversationMessage:
        """Append a message to a conversation and bump its ``updated_at``.

        The insert and the update are committed together. No existence check
        is made: if ``conversation_id`` matches no conversation, the message is
        still stored and the ``updated_at`` update affects no rows.

        Returns:
            The persisted :class:`ConversationMessage` with its new row id.
        """
        await self.init()
        now = utc_now()
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                insert into conversation_messages(
                    conversation_id, role, content, reasoning_content, task_id, status, created_at
                ) values (?, ?, ?, ?, ?, ?, ?)
                """,
                (conversation_id, role, content, reasoning_content, task_id, status, now),
            )
            await db.execute(
                "update conversations set updated_at = ? where conversation_id = ?",
                (now, conversation_id),
            )
            await db.commit()
            row_id = cursor.lastrowid
        return ConversationMessage(
            id=row_id,
            conversation_id=conversation_id,
            role=role,
            content=content,
            reasoning_content=reasoning_content,
            task_id=task_id,
            status=status,
            created_at=now,
        )

    async def list_messages(
        self, conversation_id: str, limit: int | None = None
    ) -> list[ConversationMessage]:
        """Return a conversation's messages in insertion (``id``) order.

        Args:
            conversation_id: Conversation whose messages to load.
            limit: If given, only the ``limit`` most recent messages are
                returned, still oldest-first. ``None`` returns every message.
        """
        await self.init()
        sql = "select * from conversation_messages where conversation_id = ? order by id"
        params: tuple[Any, ...] = (conversation_id,)
        if limit is not None:
            # Take the newest `limit` rows, then re-sort them ascending.
            sql = (
                "select * from (select * from conversation_messages where conversation_id = ? "
                "order by id desc limit ?) order by id"
            )
            params = (conversation_id, limit)
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(sql, params)
            rows = await cursor.fetchall()
        return [self._row_to_message(row) for row in rows]

    async def get_conversation_id_for_task(self, task_id: str) -> str | None:
        """Return the conversation of the earliest message tagged with ``task_id``.

        Returns None when no message references the task.
        """
        await self.init()
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                select conversation_id from conversation_messages
                where task_id = ?
                order by id
                limit 1
                """,
                (task_id,),
            )
            row = await cursor.fetchone()
        return row[0] if row else None

    def _row_to_conversation(self, row: aiosqlite.Row) -> Conversation:
        """Build a :class:`Conversation` from a ``conversations`` row."""
        return Conversation(
            id=row["id"],
            conversation_id=row["conversation_id"],
            title=row["title"],
            workspace=row["workspace"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    def _row_to_message(self, row: aiosqlite.Row) -> ConversationMessage:
        """Build a :class:`ConversationMessage` from a ``conversation_messages`` row."""
        return ConversationMessage(
            id=row["id"],
            conversation_id=row["conversation_id"],
            role=row["role"],
            content=row["content"],
            reasoning_content=row["reasoning_content"],
            task_id=row["task_id"],
            status=row["status"],
            created_at=row["created_at"],
        )