from __future__ import annotations

import json
import time
import uuid
from typing import Any
from urllib import error, parse, request

from config import ServerConfig


class GigaChatError(RuntimeError):
    """Raised when a GigaChat OAuth token cannot be obtained."""


class GigaChatAuthManager:
    """Obtain GigaChat OAuth access tokens and cache them on disk.

    The most recent token is persisted as JSON in
    ``<config.state_dir>/gigachat_token.json`` and reused until it has
    ``TOKEN_REFRESH_MARGIN`` seconds or fewer left, at which point a new one is
    requested from ``config.gigachat_oauth_url``.
    """

    # Refresh a cached token once it has this many seconds or fewer left.
    TOKEN_REFRESH_MARGIN = 30
    # Socket timeout, in seconds, for the OAuth HTTP request.
    TOKEN_REQUEST_TIMEOUT = 60
    # Epoch timestamps above this are treated as milliseconds: 10**11 seconds is
    # thousands of years in the future, while current millisecond timestamps
    # are around 1.7e12.
    _MILLISECOND_EPOCH_THRESHOLD = 10**11

    def __init__(self, config: ServerConfig) -> None:
        self.config = config
        self.token_path = config.state_dir / "gigachat_token.json"
        # Create the state directory up front so save_token can simply write.
        self.token_path.parent.mkdir(parents=True, exist_ok=True)

    def _raw_auth_key(self) -> str:
        """Return the configured auth key stripped of whitespace ("" if unset)."""
        return (self.config.gigachat_auth_key or "").strip()

    def is_configured(self) -> bool:
        """Return True if a non-blank GigaChat auth key is configured."""
        return bool(self._raw_auth_key())

    def _authorization_header(self) -> str:
        """Build the ``Authorization`` header value for the OAuth request.

        The configured key may be given bare or already prefixed with
        ``Basic ``; the prefix is added only when it is missing.

        Raises:
            GigaChatError: if no auth key is configured.
        """
        raw = self._raw_auth_key()
        if not raw:
            raise GigaChatError("GigaChat auth key is not configured")
        if raw.lower().startswith("basic "):
            return raw
        return f"Basic {raw}"

    @classmethod
    def _normalize_expires_at(cls, value: Any) -> int:
        """Convert an ``expires_at`` timestamp to integer Unix seconds.

        The GigaChat OAuth endpoint reports ``expires_at`` in milliseconds,
        while this class compares it against ``time.time()`` seconds. Values
        large enough to be milliseconds are scaled down; values already in
        seconds pass through unchanged.

        Raises:
            TypeError, ValueError, OverflowError: if *value* is not a finite
                number convertible to ``int``.
        """
        expires_at = int(value)
        if expires_at > cls._MILLISECOND_EPOCH_THRESHOLD:
            expires_at //= 1000
        return expires_at

    def load_token(self) -> dict[str, Any] | None:
        """Return the cached token from ``token_path``, or None if it is unusable.

        None is returned when the file is missing or unreadable, is not valid
        UTF-8 JSON, is not a JSON object, has no ``access_token``, or has a
        non-numeric ``expires_at``. In the returned dict ``expires_at`` is
        normalized to Unix seconds; a missing value becomes 0 (already expired).
        """
        try:
            data = json.loads(self.token_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # OSError: missing/unreadable file.
            # ValueError: invalid UTF-8 (UnicodeDecodeError) or JSON (JSONDecodeError).
            return None
        if not isinstance(data, dict) or not data.get("access_token"):
            return None
        try:
            expires_at = self._normalize_expires_at(data.get("expires_at", 0))
        except (TypeError, ValueError, OverflowError):
            return None
        return {**data, "expires_at": expires_at}

    def save_token(self, payload: dict[str, Any]) -> None:
        """Write *payload* to ``token_path`` as pretty-printed UTF-8 JSON."""
        self.token_path.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    def fetch_token(self) -> dict[str, Any]:
        """Request a fresh access token from the OAuth endpoint and cache it.

        Returns:
            ``{"access_token": str, "expires_at": int}`` with ``expires_at``
            in Unix seconds.

        Raises:
            GigaChatError: if no auth key is configured, the request fails
                (HTTP error status, network error, timeout), or the response
                body is malformed.
            OSError: if the token cannot be written to ``token_path``.
        """
        data = parse.urlencode({"scope": self.config.gigachat_scope}).encode("utf-8")
        req = request.Request(
            self.config.gigachat_oauth_url,
            data=data,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
                # Fresh unique request id for every OAuth call.
                "RqUID": str(uuid.uuid4()),
                "Authorization": self._authorization_header(),
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=self.TOKEN_REQUEST_TIMEOUT) as response:
                raw_body = response.read()
        except error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")
            raise GigaChatError(f"GigaChat token request failed with HTTP {exc.code}: {body}") from exc
        except OSError as exc:
            # URLError (DNS/connection/TLS failures) and socket timeouts are
            # OSError subclasses. HTTPError is one too, hence it is caught first.
            raise GigaChatError(f"GigaChat token request failed: {exc}") from exc

        try:
            payload = json.loads(raw_body.decode("utf-8"))
            token = {
                "access_token": str(payload["access_token"]),
                "expires_at": self._normalize_expires_at(payload["expires_at"]),
            }
        except (ValueError, KeyError, TypeError, OverflowError) as exc:
            # ValueError also covers UnicodeDecodeError and JSONDecodeError.
            raise GigaChatError(f"GigaChat token response is malformed: {exc!r}") from exc
        self.save_token(token)
        return token

    def get_valid_token(self) -> str:
        """Return an access token with more than ``TOKEN_REFRESH_MARGIN`` seconds left.

        The cached token is reused when it is still fresh enough; otherwise a
        new one is fetched (and cached).

        Raises:
            GigaChatError: if no auth key is configured or the refresh fails.
        """
        if not self.is_configured():
            raise GigaChatError("GigaChat auth key is not configured")
        token = self.load_token()
        now = int(time.time())
        if token and int(token.get("expires_at", 0)) - now > self.TOKEN_REFRESH_MARGIN:
            return str(token["access_token"])
        refreshed = self.fetch_token()
        return str(refreshed["access_token"])