# new-qwen/serv/oauth.py
#
# 185 lines, 6.7 KiB, Python

from __future__ import annotations
import base64
import hashlib
import json
import secrets
import time
import uuid
import webbrowser
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib import error, parse, request
# Base URL that both OAuth endpoints below are built from.
QWEN_OAUTH_BASE_URL = "https://chat.qwen.ai"

# Endpoint that issues the device code / user code pair.
QWEN_DEVICE_CODE_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/device/code"

# Endpoint used both for device-code polling and for refresh-token exchange.
QWEN_TOKEN_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/token"

# Client identifier sent with every request to the endpoints above.
QWEN_CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"

# Space-separated scopes requested when starting the device flow.
QWEN_SCOPE = " ".join(("openid", "profile", "email", "model.completion"))

# `grant_type` value sent while polling the token endpoint with a device code.
QWEN_DEVICE_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
class OAuthError(RuntimeError):
    """Error raised by ``QwenOAuthManager`` when an OAuth operation fails."""
@dataclass(slots=True)
class DeviceAuthState:
device_code: str
code_verifier: str
user_code: str
verification_uri: str
verification_uri_complete: str
expires_at: float
interval_seconds: float
class QwenOAuthManager:
    """Run the Qwen OAuth device flow and persist the resulting credentials.

    Credentials are stored as a JSON document at ``creds_path`` with the keys
    ``access_token``, ``refresh_token``, ``token_type``, ``resource_url`` and
    ``expiry_date`` (milliseconds since the epoch).
    """

    def __init__(self, creds_path: Path | None = None) -> None:
        """Remember the credential file location and create its directory.

        Args:
            creds_path: Where credentials are stored. Defaults to
                ``~/.qwen/oauth_creds.json``.
        """
        self.creds_path = creds_path or Path.home() / ".qwen" / "oauth_creds.json"
        self.creds_path.parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _require(response: dict[str, Any], key: str) -> Any:
        """Return ``response[key]`` or raise ``OAuthError`` if it is absent."""
        try:
            return response[key]
        except KeyError as exc:
            raise OAuthError(f"OAuth response is missing required field '{key}'") from exc

    @staticmethod
    def _expiry_millis(response: dict[str, Any]) -> int:
        """Return the absolute expiry (epoch milliseconds) for a token response."""
        lifetime_seconds = int(response.get("expires_in", 3600))
        return int(time.time() * 1000) + lifetime_seconds * 1000

    def _post_form(self, url: str, payload: dict[str, str]) -> dict[str, Any]:
        """POST ``payload`` as a URL-encoded form and return the decoded JSON reply.

        Args:
            url: Endpoint to call.
            payload: Form fields to send.

        Returns:
            The JSON-decoded response body.

        Raises:
            OAuthError: The server answered with an HTTP error status. When the
                error body is a JSON object, its ``error`` field is also exposed
                on the exception as ``error_code``.
        """
        data = parse.urlencode(payload).encode("utf-8")
        req = request.Request(
            url,
            data=data,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
                "x-request-id": str(uuid.uuid4()),
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=60) as response:
                return json.loads(response.read().decode("utf-8"))
        except error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")
            try:
                details = json.loads(body)
            except json.JSONDecodeError:
                details = None
            # A JSON array/string/number body has no ``.get``; treat it like
            # a non-JSON body instead of crashing with AttributeError.
            if not isinstance(details, dict):
                raise OAuthError(f"HTTP {exc.code}: {body}") from exc
            message = details.get("error_description") or details.get("error") or body
            failure = OAuthError(message)
            # Keep the machine-readable code: the human-readable description
            # takes precedence in the message and may not contain it.
            failure.error_code = details.get("error")
            raise failure from exc

    def load_credentials(self) -> dict[str, Any] | None:
        """Return the stored credentials, or ``None`` when no file exists."""
        try:
            raw = self.creds_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(raw)

    def save_credentials(self, payload: dict[str, Any]) -> None:
        """Write ``payload`` as JSON to the credential file, owner-only.

        The file holds bearer and refresh tokens, so it is created with mode
        ``0o600`` and an already existing file is tightened to the same mode
        before the new content is written.
        """
        serialized = json.dumps(payload, ensure_ascii=True, indent=2)
        self.creds_path.touch(mode=0o600, exist_ok=True)
        self.creds_path.chmod(0o600)
        self.creds_path.write_text(serialized, encoding="utf-8")

    def clear_credentials(self) -> None:
        """Delete the credential file if it exists."""
        self.creds_path.unlink(missing_ok=True)

    def start_device_flow(self, open_browser: bool = False) -> DeviceAuthState:
        """Request a device code and return the state needed to poll for tokens.

        Args:
            open_browser: When true, try to open the complete verification URL
                in the default browser (failures are ignored).

        Returns:
            A ``DeviceAuthState`` for use with ``poll_device_flow``.

        Raises:
            OAuthError: The request failed or the response lacks a required field.
        """
        # PKCE: random verifier, SHA-256 challenge, both unpadded base64url.
        code_verifier = (
            base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("ascii").rstrip("=")
        )
        digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
        code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")

        response = self._post_form(
            QWEN_DEVICE_CODE_ENDPOINT,
            {
                "client_id": QWEN_CLIENT_ID,
                "scope": QWEN_SCOPE,
                "code_challenge": code_challenge,
                "code_challenge_method": "S256",
            },
        )
        verification_uri = self._require(response, "verification_uri")
        state = DeviceAuthState(
            device_code=self._require(response, "device_code"),
            code_verifier=code_verifier,
            user_code=self._require(response, "user_code"),
            verification_uri=verification_uri,
            # Fall back to the plain URI if the pre-filled one is not provided.
            verification_uri_complete=(
                response.get("verification_uri_complete") or verification_uri
            ),
            expires_at=time.time() + float(response.get("expires_in", 600)),
            interval_seconds=2.0,
        )
        if open_browser:
            try:
                webbrowser.open(state.verification_uri_complete)
            except Exception:
                # Deliberately best-effort: the caller still gets the URL and
                # user code in ``state`` and can show them manually.
                pass
        return state

    def poll_device_flow(self, state: DeviceAuthState) -> dict[str, Any] | None:
        """Poll the token endpoint once for the given device flow.

        Args:
            state: State returned by ``start_device_flow``. Its
                ``interval_seconds`` is increased (x1.5, capped at 10 seconds)
                when the server asks to slow down.

        Returns:
            The saved credentials once authorization completed, or ``None``
            while it is still pending (including after a slow-down request).

        Raises:
            OAuthError: The device code expired, the server reported any other
                error, or the token response lacks ``access_token``.
        """
        if time.time() >= state.expires_at:
            raise OAuthError("Device authorization expired")
        try:
            response = self._post_form(
                QWEN_TOKEN_ENDPOINT,
                {
                    "grant_type": QWEN_DEVICE_GRANT,
                    "client_id": QWEN_CLIENT_ID,
                    "device_code": state.device_code,
                    "code_verifier": state.code_verifier,
                },
            )
        except OAuthError as exc:
            # Prefer the structured error code; keep the substring match as a
            # fallback for errors that only carry the code in their message.
            code = getattr(exc, "error_code", None)
            text = str(exc)
            if code == "authorization_pending" or "authorization_pending" in text:
                return None
            if code == "slow_down" or "slow_down" in text:
                state.interval_seconds = min(state.interval_seconds * 1.5, 10.0)
                return None
            raise
        creds = {
            "access_token": self._require(response, "access_token"),
            "refresh_token": response.get("refresh_token"),
            "token_type": response.get("token_type", "Bearer"),
            "resource_url": response.get("resource_url"),
            "expiry_date": self._expiry_millis(response),
        }
        self.save_credentials(creds)
        return creds

    def refresh_credentials(self, creds: dict[str, Any]) -> dict[str, Any]:
        """Exchange the refresh token in ``creds`` for new credentials and save them.

        Values the server omits (``refresh_token``, ``resource_url``) are
        carried over from ``creds``.

        Raises:
            OAuthError: ``creds`` has no refresh token, the request failed, or
                the response lacks ``access_token``.
        """
        refresh_token = creds.get("refresh_token")
        if not refresh_token:
            raise OAuthError("No refresh token available")
        response = self._post_form(
            QWEN_TOKEN_ENDPOINT,
            {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": QWEN_CLIENT_ID,
            },
        )
        updated = {
            "access_token": self._require(response, "access_token"),
            "refresh_token": response.get("refresh_token") or refresh_token,
            "token_type": response.get("token_type", "Bearer"),
            "resource_url": response.get("resource_url") or creds.get("resource_url"),
            "expiry_date": self._expiry_millis(response),
        }
        self.save_credentials(updated)
        return updated

    def get_valid_credentials(self) -> dict[str, Any]:
        """Return stored credentials, refreshing them when close to expiry.

        Credentials are reused only if they expire more than 30 seconds from
        now; a missing, null or unparsable ``expiry_date`` forces a refresh.

        Raises:
            OAuthError: No credentials are stored, or the refresh failed.
        """
        creds = self.load_credentials()
        if not creds:
            raise OAuthError("Qwen OAuth is not configured. Start device flow first.")
        try:
            expiry = int(creds.get("expiry_date") or 0)
        except (TypeError, ValueError):
            expiry = 0
        if expiry and expiry - int(time.time() * 1000) > 30_000:
            return creds
        return self.refresh_credentials(creds)

    def get_openai_base_url(self, creds: dict[str, Any]) -> str:
        """Return the base URL, ending in ``/v1``, derived from ``creds``.

        Uses ``creds["resource_url"]`` when set, otherwise the DashScope
        compatible-mode URL. A missing scheme is filled in with ``https://``
        and trailing slashes are dropped before ``/v1`` is appended, so a
        value that already ends in ``/v1`` or ``/v1/`` is not suffixed twice.
        """
        resource_url = creds.get("resource_url") or "https://dashscope.aliyuncs.com/compatible-mode"
        if not resource_url.lower().startswith(("http://", "https://")):
            resource_url = f"https://{resource_url}"
        resource_url = resource_url.rstrip("/")
        if resource_url.endswith("/v1"):
            return resource_url
        return resource_url + "/v1"