211 lines
7.1 KiB
Python
211 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
import aiosqlite
|
|
from pydantic import BaseModel
|
|
|
|
from duck_core.tasks.store import utc_now
|
|
|
|
|
|
class Conversation(BaseModel):
|
|
    # One row of the `conversations` table: same columns, same order as the
    # DDL in ConversationStore.init and ConversationStore._row_to_conversation.
    # SQLite autoincrement row id; None only when the model is built without
    # a database row.
    id: int | None = None
|
|
    # Public identifier used for lookups; ConversationStore.create generates
    # it as "chat_" followed by the first 12 hex characters of a uuid4.
    conversation_id: str
|
|
    # Title supplied by the caller when the conversation is created.
    title: str
|
|
    # Workspace string supplied by the caller; stored verbatim (its meaning
    # is not interpreted in this file).
    workspace: str
|
|
    # Value of utc_now() at creation, stored as text. Presumably an ISO-8601
    # UTC string -- TODO confirm against duck_core.tasks.store.utc_now.
    created_at: str
|
|
    # Same format as created_at; refreshed by ConversationStore.add_message
    # and used to order ConversationStore.list results (newest first).
    updated_at: str
|
|
|
|
|
|
class ConversationMessage(BaseModel):
|
|
    # One row of the `conversation_messages` table: same columns, same order
    # as the DDL in ConversationStore.init and ConversationStore._row_to_message.
    # SQLite autoincrement row id; list_messages orders messages by it.
    # None only when the model is built without a database row.
    id: int | None = None
|
|
    # Public id of the owning conversation. The table declares no foreign
    # key, so nothing here guarantees the conversation exists.
    conversation_id: str
|
|
    # Author role as a free-form string; allowed values are not constrained
    # in this file (presumably e.g. "user" / "assistant" -- verify with callers).
    role: str
|
|
    # Message body text.
    content: str
|
|
    # Optional extra text stored alongside the content (nullable column).
    reasoning_content: str | None = None
|
|
    # Optional task identifier; ConversationStore.get_conversation_id_for_task
    # looks messages up by this value.
    task_id: str | None = None
|
|
    # Optional free-form status string (nullable column); values are not
    # constrained in this file.
    status: str | None = None
|
|
    # Value of utc_now() when the message was stored, kept as text.
    # Presumably ISO-8601 UTC -- TODO confirm.
    created_at: str
|
|
|
|
|
|
class ConversationStore:
    """SQLite-backed persistence for conversations and their messages.

    Each public coroutine opens its own short-lived ``aiosqlite`` connection
    to ``db_path`` and closes it before returning. The schema (tables and
    indexes) is created lazily by the first operation and remembered per
    instance, so later operations skip the DDL round-trip.
    """

    # Set to True by init() once the schema exists. The class-level default
    # keeps _ensure_schema() working even for instances built without __init__.
    _schema_ready: bool = False

    def __init__(self, db_path: str):
        """Remember the database location; no I/O happens until first use.

        Args:
            db_path: Filesystem path of the SQLite database file. Missing
                parent directories are created by :meth:`init`.
        """
        self.db_path = Path(db_path)

    async def init(self) -> None:
        """Create the parent directory, tables and indexes if they are missing.

        Every statement uses ``if not exists``, so repeated calls are harmless.
        The other methods invoke this automatically (once per instance). Call
        it explicitly to recreate the schema if the database file was removed
        or replaced while this store object was alive.
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                """
                create table if not exists conversations (
                    id integer primary key autoincrement,
                    conversation_id text not null unique,
                    title text not null,
                    workspace text not null,
                    created_at text not null,
                    updated_at text not null
                )
                """
            )
            await db.execute(
                """
                create table if not exists conversation_messages (
                    id integer primary key autoincrement,
                    conversation_id text not null,
                    role text not null,
                    content text not null,
                    reasoning_content text,
                    task_id text,
                    status text,
                    created_at text not null
                )
                """
            )
            # list_messages() filters on conversation_id and
            # get_conversation_id_for_task() filters on task_id; without these
            # indexes both queries scan the whole message table.
            await db.execute(
                "create index if not exists idx_conversation_messages_conversation "
                "on conversation_messages(conversation_id)"
            )
            await db.execute(
                "create index if not exists idx_conversation_messages_task "
                "on conversation_messages(task_id)"
            )
            await db.commit()
        self._schema_ready = True

    async def _ensure_schema(self) -> None:
        """Run init() the first time an operation needs the database."""
        if not self._schema_ready:
            await self.init()

    async def create(self, title: str, workspace: str) -> Conversation:
        """Insert and return a new conversation with a freshly generated id.

        The id has the form ``chat_<12 hex chars>`` (taken from a uuid4).
        ``created_at`` and ``updated_at`` are both set to the same
        ``utc_now()`` value. The returned model carries the new SQLite row id.
        """
        await self._ensure_schema()
        now = utc_now()
        conversation_id = f"chat_{uuid4().hex[:12]}"
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                insert into conversations(conversation_id, title, workspace, created_at, updated_at)
                values (?, ?, ?, ?, ?)
                """,
                (conversation_id, title, workspace, now, now),
            )
            await db.commit()
            row_id = cursor.lastrowid
        return Conversation(
            id=row_id,
            conversation_id=conversation_id,
            title=title,
            workspace=workspace,
            created_at=now,
            updated_at=now,
        )

    async def ensure(
        self, conversation_id: str | None, title: str, workspace: str
    ) -> Conversation:
        """Return an existing conversation, or create a new one.

        If *conversation_id* is truthy and refers to a stored conversation,
        that conversation is returned unchanged (*title* and *workspace* are
        ignored). Otherwise a new conversation is created. Its id is freshly
        generated; the supplied *conversation_id* is not reused.
        """
        if conversation_id:
            existing = await self.get(conversation_id)
            if existing is not None:
                return existing
        return await self.create(title, workspace)

    async def get(self, conversation_id: str) -> Conversation | None:
        """Return the conversation with this public id, or None if unknown."""
        await self._ensure_schema()
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                "select * from conversations where conversation_id = ?", (conversation_id,)
            )
            row = await cursor.fetchone()
        return self._row_to_conversation(row) if row else None

    async def list(self, limit: int = 50) -> list[Conversation]:
        """Return up to *limit* conversations, most recently updated first.

        ``id desc`` breaks ties between rows with identical ``updated_at``,
        so the order is deterministic. ``updated_at`` is compared as text;
        this is chronological only if utc_now() yields fixed-width,
        lexicographically ordered strings (presumably ISO-8601 -- verify).
        """
        await self._ensure_schema()
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                "select * from conversations order by updated_at desc, id desc limit ?",
                (limit,),
            )
            rows = await cursor.fetchall()
        return [self._row_to_conversation(row) for row in rows]

    async def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        reasoning_content: str | None = None,
        task_id: str | None = None,
        status: str | None = None,
    ) -> ConversationMessage:
        """Append a message to a conversation and return the stored record.

        The insert and the parent's ``updated_at`` bump are committed
        together. The parent conversation is not checked: if it does not
        exist, the message is still stored and the update touches no rows.
        """
        await self._ensure_schema()
        now = utc_now()
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                insert into conversation_messages(
                    conversation_id, role, content, reasoning_content, task_id, status, created_at
                ) values (?, ?, ?, ?, ?, ?, ?)
                """,
                (conversation_id, role, content, reasoning_content, task_id, status, now),
            )
            await db.execute(
                "update conversations set updated_at = ? where conversation_id = ?",
                (now, conversation_id),
            )
            await db.commit()
            row_id = cursor.lastrowid
        return ConversationMessage(
            id=row_id,
            conversation_id=conversation_id,
            role=role,
            content=content,
            reasoning_content=reasoning_content,
            task_id=task_id,
            status=status,
            created_at=now,
        )

    async def list_messages(
        self, conversation_id: str, limit: int | None = None
    ) -> list[ConversationMessage]:
        """Return a conversation's messages in insertion (id) order.

        With *limit*, only the newest *limit* messages are returned, still
        oldest-first. The inner query picks the tail by ``id desc``; the outer
        query restores ascending order.
        """
        await self._ensure_schema()
        sql = "select * from conversation_messages where conversation_id = ? order by id"
        params: tuple[Any, ...] = (conversation_id,)
        if limit is not None:
            sql = (
                "select * from (select * from conversation_messages where conversation_id = ? "
                "order by id desc limit ?) order by id"
            )
            params = (conversation_id, limit)
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(sql, params)
            rows = await cursor.fetchall()
        return [self._row_to_message(row) for row in rows]

    async def get_conversation_id_for_task(self, task_id: str) -> str | None:
        """Return the conversation id of the earliest message tagged *task_id*.

        Returns None when no stored message carries that task id.
        """
        await self._ensure_schema()
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                select conversation_id from conversation_messages
                where task_id = ?
                order by id
                limit 1
                """,
                (task_id,),
            )
            row = await cursor.fetchone()
        return row[0] if row else None

    def _row_to_conversation(self, row: aiosqlite.Row) -> Conversation:
        """Build a Conversation from a `conversations` row fetched by name."""
        return Conversation(
            id=row["id"],
            conversation_id=row["conversation_id"],
            title=row["title"],
            workspace=row["workspace"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    def _row_to_message(self, row: aiosqlite.Row) -> ConversationMessage:
        """Build a ConversationMessage from a `conversation_messages` row."""
        return ConversationMessage(
            id=row["id"],
            conversation_id=row["conversation_id"],
            role=row["role"],
            content=row["content"],
            reasoning_content=row["reasoning_content"],
            task_id=row["task_id"],
            status=row["status"],
            created_at=row["created_at"],
        )
|