78 lines
2.7 KiB
Python
78 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
class SQLiteTaskStateStore:
    """Durable task state store for runtime lifecycle state.

    Each task is one row of the ``task_states`` table keyed by ``task_id``.
    The task's state dict is stored as JSON in ``state_json``, and its
    ``"session_id"`` entry (``NULL`` when absent) is mirrored into the
    ``session_id`` column so tasks can be listed per session.

    Every method opens a short-lived connection and closes it before
    returning.  Consequently ``":memory:"`` is not a usable ``db_path``:
    each new connection to it would see a separate, empty database.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open (or create) the store backed by the SQLite file *db_path*.

        The parent directory is created if missing, and the schema is
        created or migrated before the constructor returns.
        """
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._initialize()

    def create_task(self, task_id: str, initial_state: dict[str, Any]) -> dict[str, Any]:
        """Store *initial_state* for *task_id*, replacing any existing row.

        Args:
            task_id: Primary key of the task.
            initial_state: JSON-serializable state; its ``"session_id"``
                entry, if present, becomes the task's session.

        Returns:
            A shallow copy of *initial_state* (the value that was stored).

        Raises:
            TypeError: If the state contains values ``json`` cannot serialize.
        """
        state = dict(initial_state)
        self._save(task_id, state)
        return state

    def get_task(self, task_id: str) -> dict[str, Any] | None:
        """Return the decoded state of *task_id*, or ``None`` if it is unknown."""
        rows = self._fetch_all(
            "SELECT state_json FROM task_states WHERE task_id = ?",
            (task_id,),
        )
        # task_id is the primary key, so there is at most one row.
        return json.loads(rows[0][0]) if rows else None

    def update_task(self, task_id: str, patch: dict[str, Any]) -> dict[str, Any]:
        """Shallow-merge *patch* into the stored state of *task_id*.

        A task that does not exist yet is created from *patch* alone.  The
        ``session_id`` column is rewritten from the merged state, so a task
        keeps its session across updates (and moves if *patch* changes
        ``"session_id"``).

        Note: the read and the write use separate connections, so two
        concurrent updates of the same task can overwrite each other.

        Returns:
            The merged state that was stored.
        """
        state = self.get_task(task_id) or {}
        state.update(patch)
        self._save(task_id, state)
        return state

    def get_session_tasks(self, session_id: str, limit: int = 10) -> list[dict[str, Any]]:
        """Return up to *limit* task states belonging to *session_id*.

        Results are ordered by ``rowid`` descending.  Every save is an
        ``INSERT OR REPLACE``, which deletes and re-inserts the row (usually
        under a higher rowid), so recently created or updated tasks come first.
        """
        rows = self._fetch_all(
            "SELECT state_json FROM task_states WHERE session_id = ? "
            "ORDER BY rowid DESC LIMIT ?",
            (session_id, limit),
        )
        return [json.loads(state_json) for (state_json,) in rows]

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the backing database file."""
        return sqlite3.connect(self._db_path)

    def _fetch_all(self, sql: str, params: tuple[Any, ...]) -> list[Any]:
        """Run a read-only query, return all rows, and close the connection."""
        conn = self._connect()
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            # sqlite3's ``with conn:`` only commits/rolls back; it never closes.
            conn.close()

    def _save(self, task_id: str, state: dict[str, Any]) -> None:
        """Upsert *state* for *task_id*, mirroring its ``"session_id"`` column."""
        payload = json.dumps(state)
        conn = self._connect()
        try:
            with conn:  # commit on success, roll back if the statement fails
                conn.execute(
                    """
                    INSERT OR REPLACE INTO task_states (task_id, state_json, session_id)
                    VALUES (?, ?, ?)
                    """,
                    (task_id, payload, state.get("session_id")),
                )
        finally:
            conn.close()

    def _initialize(self) -> None:
        """Create ``task_states`` and add the ``session_id`` column if missing.

        Tables created before ``session_id`` existed are migrated in place
        with ``ALTER TABLE``.
        """
        conn = self._connect()
        try:
            with conn:
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS task_states (
                        task_id TEXT PRIMARY KEY,
                        state_json TEXT NOT NULL
                    )
                    """
                )
            if self._has_session_column(conn):
                return
            try:
                with conn:
                    conn.execute("ALTER TABLE task_states ADD COLUMN session_id TEXT")
            except sqlite3.OperationalError:
                # Another connection may have added the column between the
                # check above and the ALTER; any other failure is real.
                if not self._has_session_column(conn):
                    raise
        finally:
            conn.close()

    @staticmethod
    def _has_session_column(conn: sqlite3.Connection) -> bool:
        """Return whether ``task_states`` already has a ``session_id`` column."""
        # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk).
        columns = conn.execute("PRAGMA table_info(task_states)").fetchall()
        return any(column[1] == "session_id" for column in columns)
|