Add chat job polling and edit tool
This commit is contained in:
parent
f9b9d7d242
commit
940bef2f4a
|
|
@ -28,12 +28,13 @@ Qwen OAuth + OpenAI-compatible endpoint
|
|||
- хранение токенов в `~/.qwen/oauth_creds.json`
|
||||
- HTTP API сервера
|
||||
- агентный цикл с tool calling
|
||||
- инструменты: `list_files`, `glob_search`, `grep_text`, `stat_path`, `read_file`, `write_file`, `make_directory`, `exec_command`
|
||||
- инструменты: `list_files`, `glob_search`, `grep_text`, `stat_path`, `read_file`, `replace_in_file`, `write_file`, `make_directory`, `exec_command`
|
||||
- Telegram polling без внешних библиотек
|
||||
- JSON-хранилище сессий
|
||||
- API списка и просмотра сессий
|
||||
- автоматический polling OAuth flow в боте
|
||||
- очередь сообщений, пришедших до завершения OAuth
|
||||
- job-based chat polling между `bot` и `serv`
|
||||
|
||||
## Ограничения текущей реализации
|
||||
|
||||
|
|
@ -93,3 +94,5 @@ curl -X POST http://127.0.0.1:8080/api/v1/auth/device/start
|
|||
- `POST /api/v1/session/get`
|
||||
- `POST /api/v1/session/clear`
|
||||
- `POST /api/v1/chat`
|
||||
- `POST /api/v1/chat/start`
|
||||
- `POST /api/v1/chat/poll`
|
||||
|
|
|
|||
49
bot/app.py
49
bot/app.py
|
|
@ -51,6 +51,24 @@ def send_text_chunks(api: TelegramAPI, chat_id: int, text: str) -> None:
|
|||
api.send_message(chat_id, normalized[start : start + chunk_size])
|
||||
|
||||
|
||||
def summarize_event(event: dict[str, Any]) -> str | None:
|
||||
event_type = event.get("type")
|
||||
if event_type == "job_status":
|
||||
return event.get("message")
|
||||
if event_type == "model_request":
|
||||
return "Думаю над ответом"
|
||||
if event_type == "tool_call":
|
||||
return f"Вызываю инструмент: {event.get('name')}"
|
||||
if event_type == "tool_result":
|
||||
result = event.get("result", {})
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
return f"Инструмент {event.get('name')} завершился с ошибкой"
|
||||
return f"Инструмент {event.get('name')} завершён"
|
||||
if event_type == "error":
|
||||
return f"Ошибка: {event.get('message')}"
|
||||
return None
|
||||
|
||||
|
||||
def get_auth_flow(state: dict[str, Any], chat_id: int) -> dict[str, Any] | None:
|
||||
return state.setdefault("auth_flows", {}).get(str(chat_id))
|
||||
|
||||
|
|
@ -141,16 +159,39 @@ def deliver_chat_message(
|
|||
session_id = state.setdefault("sessions", {}).get(session_key)
|
||||
prefix = "Обрабатываю отложенный запрос..." if delayed else "Обрабатываю запрос..."
|
||||
api.send_message(chat_id, prefix)
|
||||
result = post_json(
|
||||
f"{config.server_url}/api/v1/chat",
|
||||
start_result = post_json(
|
||||
f"{config.server_url}/api/v1/chat/start",
|
||||
{
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"message": text,
|
||||
},
|
||||
)
|
||||
state["sessions"][session_key] = result["session_id"]
|
||||
answer = result.get("answer") or "Пустой ответ от модели."
|
||||
state["sessions"][session_key] = start_result["session_id"]
|
||||
job_id = start_result["job_id"]
|
||||
seen_seq = 0
|
||||
sent_statuses: set[str] = set()
|
||||
answer = None
|
||||
while True:
|
||||
poll_result = post_json(
|
||||
f"{config.server_url}/api/v1/chat/poll",
|
||||
{"job_id": job_id, "since_seq": seen_seq},
|
||||
)
|
||||
for event in poll_result.get("events", []):
|
||||
seen_seq = max(seen_seq, int(event.get("seq", 0)))
|
||||
summary = summarize_event(event)
|
||||
if summary and summary not in sent_statuses:
|
||||
api.send_message(chat_id, summary[:4000])
|
||||
sent_statuses.add(summary)
|
||||
if poll_result.get("status") == "completed":
|
||||
answer = poll_result.get("answer")
|
||||
state["sessions"][session_key] = poll_result["session_id"]
|
||||
break
|
||||
if poll_result.get("status") == "failed":
|
||||
raise RuntimeError(poll_result.get("error") or "Chat job failed")
|
||||
time.sleep(1.2)
|
||||
|
||||
answer = answer or "Пустой ответ от модели."
|
||||
send_text_chunks(api, chat_id, answer)
|
||||
|
||||
|
||||
|
|
|
|||
90
serv/app.py
90
serv/app.py
|
|
@ -9,6 +9,7 @@ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|||
from typing import Any
|
||||
|
||||
from config import ServerConfig
|
||||
from jobs import JobStore
|
||||
from llm import QwenAgent
|
||||
from oauth import DeviceAuthState, OAuthError, QwenOAuthManager
|
||||
from sessions import SessionStore
|
||||
|
|
@ -22,6 +23,7 @@ class AppState:
|
|||
self.sessions = SessionStore(config.session_dir)
|
||||
self.tools = ToolRegistry(config)
|
||||
self.agent = QwenAgent(config, self.oauth, self.tools)
|
||||
self.jobs = JobStore()
|
||||
self.pending_device_flows: dict[str, DeviceAuthState] = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
|
|
@ -70,6 +72,47 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
return
|
||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Not found"})
|
||||
|
||||
def _run_chat_job(self, job_id: str, session_id: str, user_id: str, message: str) -> None:
    """Execute one chat job and record its progress in the job store.

    Marks the job as running, runs the agent over the stored session
    history with every agent event appended to the job's event log, saves
    the updated session, and finally marks the job completed. Any exception
    is recorded as an ``error`` event and the job is marked failed.

    Args:
        job_id: Identifier of the job whose status and events are updated.
        session_id: Session whose history is loaded and then overwritten.
        user_id: Stored alongside the saved session.
        message: The user message passed to the agent.
    """

    def record(event: dict[str, Any]) -> None:
        # Single funnel for everything that goes into the job's event log.
        self.app.jobs.append_event(job_id, event)

    try:
        self.app.jobs.set_status(job_id, "running")
        record({"type": "job_status", "message": "Запрос принят сервером"})

        history = self.app.sessions.load(session_id).get("messages", [])
        outcome = self.app.agent.run(history, message, on_event=record)

        # The first message is skipped when persisting -- presumably the
        # system prompt that the agent prepends; confirm against the agent.
        persisted = outcome["messages"][1:]
        session_record = {
            "session_id": session_id,
            "user_id": user_id,
            "updated_at": int(time.time()),
            "messages": persisted,
            "last_answer": outcome["answer"],
        }
        self.app.sessions.save(session_id, session_record)

        record({"type": "job_status", "message": "Ответ готов"})
        self.app.jobs.finish(
            job_id,
            answer=outcome["answer"],
            usage=outcome.get("usage"),
        )
    except Exception as exc:
        # Surface the failure to pollers both as an event and as job state.
        failure_text = str(exc)
        record({"type": "error", "message": failure_text})
        self.app.jobs.fail(job_id, failure_text)
|
||||
|
||||
def do_POST(self) -> None:
|
||||
try:
|
||||
if self.path == "/api/v1/auth/device/start":
|
||||
|
|
@ -137,6 +180,53 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||
)
|
||||
return
|
||||
|
||||
if self.path == "/api/v1/chat/start":
|
||||
body = self._json_body()
|
||||
session_id = body.get("session_id") or uuid.uuid4().hex
|
||||
user_id = str(body.get("user_id") or "anonymous")
|
||||
message = body["message"]
|
||||
job = self.app.jobs.create(session_id, user_id, message)
|
||||
thread = threading.Thread(
|
||||
target=self._run_chat_job,
|
||||
args=(job["job_id"], session_id, user_id, message),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
self._send(
|
||||
HTTPStatus.OK,
|
||||
{
|
||||
"job_id": job["job_id"],
|
||||
"session_id": session_id,
|
||||
"status": "queued",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if self.path == "/api/v1/chat/poll":
|
||||
body = self._json_body()
|
||||
job_id = body["job_id"]
|
||||
since_seq = int(body.get("since_seq", 0))
|
||||
job = self.app.jobs.get(job_id)
|
||||
if not job:
|
||||
self._send(HTTPStatus.NOT_FOUND, {"error": "Unknown job_id"})
|
||||
return
|
||||
events = [
|
||||
event for event in job.get("events", []) if event.get("seq", 0) > since_seq
|
||||
]
|
||||
self._send(
|
||||
HTTPStatus.OK,
|
||||
{
|
||||
"job_id": job_id,
|
||||
"session_id": job["session_id"],
|
||||
"status": job["status"],
|
||||
"events": events,
|
||||
"answer": job.get("answer"),
|
||||
"usage": job.get("usage"),
|
||||
"error": job.get("error"),
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if self.path == "/api/v1/session/get":
|
||||
body = self._json_body()
|
||||
session_id = body["session_id"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,76 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
||||
class JobStore:
|
||||
def __init__(self) -> None:
|
||||
self._jobs: dict[str, dict[str, Any]] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def create(self, session_id: str, user_id: str, message: str) -> dict[str, Any]:
|
||||
job_id = uuid.uuid4().hex
|
||||
job = {
|
||||
"job_id": job_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"message": message,
|
||||
"status": "queued",
|
||||
"created_at": time.time(),
|
||||
"updated_at": time.time(),
|
||||
"events": [],
|
||||
"answer": None,
|
||||
"usage": None,
|
||||
"error": None,
|
||||
}
|
||||
with self._lock:
|
||||
self._jobs[job_id] = job
|
||||
return job
|
||||
|
||||
def get(self, job_id: str) -> dict[str, Any] | None:
|
||||
with self._lock:
|
||||
job = self._jobs.get(job_id)
|
||||
if not job:
|
||||
return None
|
||||
return {
|
||||
key: (value.copy() if isinstance(value, list) else value)
|
||||
for key, value in job.items()
|
||||
}
|
||||
|
||||
def append_event(self, job_id: str, event: dict[str, Any]) -> None:
|
||||
with self._lock:
|
||||
job = self._jobs[job_id]
|
||||
seq = len(job["events"]) + 1
|
||||
job["events"].append({"seq": seq, **event})
|
||||
job["updated_at"] = time.time()
|
||||
|
||||
def set_status(self, job_id: str, status: str) -> None:
|
||||
with self._lock:
|
||||
job = self._jobs[job_id]
|
||||
job["status"] = status
|
||||
job["updated_at"] = time.time()
|
||||
|
||||
def finish(
|
||||
self,
|
||||
job_id: str,
|
||||
*,
|
||||
answer: str,
|
||||
usage: dict[str, Any] | None,
|
||||
) -> None:
|
||||
with self._lock:
|
||||
job = self._jobs[job_id]
|
||||
job["status"] = "completed"
|
||||
job["answer"] = answer
|
||||
job["usage"] = usage
|
||||
job["updated_at"] = time.time()
|
||||
|
||||
def fail(self, job_id: str, error_message: str) -> None:
|
||||
with self._lock:
|
||||
job = self._jobs[job_id]
|
||||
job["status"] = "failed"
|
||||
job["error"] = error_message
|
||||
job["updated_at"] = time.time()
|
||||
|
||||
27
serv/llm.py
27
serv/llm.py
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
from urllib import error, request
|
||||
|
||||
from config import ServerConfig
|
||||
|
|
@ -49,7 +49,13 @@ class QwenAgent:
|
|||
body = exc.read().decode("utf-8", errors="replace")
|
||||
raise OAuthError(f"LLM request failed with HTTP {exc.code}: {body}") from exc
|
||||
|
||||
def run(self, history: list[dict[str, Any]], user_message: str) -> dict[str, Any]:
|
||||
def run(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
user_message: str,
|
||||
on_event: Callable[[dict[str, Any]], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
emit = on_event or (lambda _event: None)
|
||||
system_prompt = self.config.system_prompt or DEFAULT_SYSTEM_PROMPT
|
||||
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
messages.extend(history)
|
||||
|
|
@ -57,12 +63,15 @@ class QwenAgent:
|
|||
events: list[dict[str, Any]] = []
|
||||
|
||||
for _ in range(self.config.max_tool_rounds):
|
||||
emit({"type": "model_request", "message": "Запрашиваю ответ модели"})
|
||||
response = self._request_completion(messages)
|
||||
choice = response["choices"][0]["message"]
|
||||
tool_calls = choice.get("tool_calls") or []
|
||||
content = choice.get("content")
|
||||
if content:
|
||||
events.append({"type": "assistant", "content": content})
|
||||
assistant_event = {"type": "assistant", "content": content}
|
||||
events.append(assistant_event)
|
||||
emit(assistant_event)
|
||||
|
||||
if not tool_calls:
|
||||
return {
|
||||
|
|
@ -87,7 +96,9 @@ class QwenAgent:
|
|||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
|
||||
events.append({"type": "tool_call", "name": tool_name, "arguments": arguments})
|
||||
tool_call_event = {"type": "tool_call", "name": tool_name, "arguments": arguments}
|
||||
events.append(tool_call_event)
|
||||
emit(tool_call_event)
|
||||
try:
|
||||
result = self.tools.execute(tool_name, arguments)
|
||||
except ToolError as exc:
|
||||
|
|
@ -95,7 +106,9 @@ class QwenAgent:
|
|||
except Exception as exc:
|
||||
result = {"error": f"Unexpected tool failure: {exc}"}
|
||||
|
||||
events.append({"type": "tool_result", "name": tool_name, "result": result})
|
||||
tool_result_event = {"type": "tool_result", "name": tool_name, "result": result}
|
||||
events.append(tool_result_event)
|
||||
emit(tool_result_event)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
|
|
@ -107,7 +120,9 @@ class QwenAgent:
|
|||
final_message = (
|
||||
"Остановлено по лимиту tool rounds. Попробуй сузить задачу или продолжить отдельным сообщением."
|
||||
)
|
||||
events.append({"type": "assistant", "content": final_message})
|
||||
final_event = {"type": "assistant", "content": final_message}
|
||||
events.append(final_event)
|
||||
emit(final_event)
|
||||
return {
|
||||
"answer": final_message,
|
||||
"events": events,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class ToolRegistry:
|
|||
"grep_text": self.grep_text,
|
||||
"stat_path": self.stat_path,
|
||||
"read_file": self.read_file,
|
||||
"replace_in_file": self.replace_in_file,
|
||||
"write_file": self.write_file,
|
||||
"make_directory": self.make_directory,
|
||||
"exec_command": self.exec_command,
|
||||
|
|
@ -106,6 +107,23 @@ class ToolRegistry:
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "replace_in_file",
|
||||
"description": "Replace exact text in a workspace file without rewriting unrelated content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string"},
|
||||
"old_text": {"type": "string"},
|
||||
"new_text": {"type": "string"},
|
||||
"expected_count": {"type": "integer"},
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -276,6 +294,30 @@ class ToolRegistry:
|
|||
"bytes_written": len(arguments["content"].encode("utf-8")),
|
||||
}
|
||||
|
||||
def replace_in_file(self, arguments: dict[str, Any]) -> dict[str, Any]:
    """Replace every exact occurrence of ``old_text`` in a workspace file.

    Args:
        arguments: Tool arguments with the keys ``path``, ``old_text`` and
            ``new_text``, plus an optional ``expected_count`` that must
            equal the number of occurrences found.

    Returns:
        A dict with the workspace-relative ``path`` and the number of
        ``replacements`` performed.

    Raises:
        ToolError: If the path is missing or not a file, ``old_text`` is
            empty or not found, or ``expected_count`` is invalid or does
            not match the number of occurrences.
    """
    target = self._resolve(arguments["path"])
    if not target.exists():
        raise ToolError("File does not exist")
    if not target.is_file():
        raise ToolError("Path is not a file")

    old_text = arguments["old_text"]
    new_text = arguments["new_text"]
    if not old_text:
        # "".count("") is len + 1 and replace("", x) inserts x between every
        # character, which would corrupt the whole file.
        raise ToolError("old_text must not be empty")
    expected_count = arguments.get("expected_count")

    # NOTE(review): read_text/write_text use text mode, so line endings may
    # be normalised on a round trip -- confirm this is acceptable.
    content = target.read_text(encoding="utf-8")
    count = content.count(old_text)
    if count == 0:
        raise ToolError("old_text not found in file")

    if expected_count is not None:
        try:
            expected = int(expected_count)
        except (TypeError, ValueError) as exc:
            raise ToolError("expected_count must be an integer") from exc
        if count != expected:
            raise ToolError(
                f"expected_count mismatch: found {count}, expected {expected}"
            )

    updated = content.replace(old_text, new_text)
    target.write_text(updated, encoding="utf-8")
    return {
        "path": target.relative_to(self.workspace_root).as_posix(),
        "replacements": count,
    }
|
||||
|
||||
def make_directory(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
target = self._resolve(arguments["path"])
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue