185 lines
6.7 KiB
Python
185 lines
6.7 KiB
Python
from __future__ import annotations

import base64
import hashlib
import json
import os
import secrets
import tempfile
import time
import uuid
import webbrowser
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib import error, parse, request
|
|
|
|
|
|
# Root of the Qwen OAuth server; both flow endpoints below are paths under it.
QWEN_OAUTH_BASE_URL = "https://chat.qwen.ai"
QWEN_DEVICE_CODE_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/device/code"
QWEN_TOKEN_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/token"

# Value sent as ``client_id``; no client secret accompanies it.
QWEN_CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"

# Space-separated scopes requested when starting device authorization.
QWEN_SCOPE = "openid profile email model.completion"

# ``grant_type`` URN for polling the token endpoint (RFC 8628 device grant).
QWEN_DEVICE_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
|
|
|
|
|
|
class OAuthError(RuntimeError):
    """Signal a failed Qwen OAuth request or an unusable credential state."""
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class DeviceAuthState:
|
|
device_code: str
|
|
code_verifier: str
|
|
user_code: str
|
|
verification_uri: str
|
|
verification_uri_complete: str
|
|
expires_at: float
|
|
interval_seconds: float
|
|
|
|
|
|
class QwenOAuthManager:
    """Drive the Qwen OAuth device flow and manage the stored credentials.

    Credentials are persisted as JSON at ``creds_path`` with the keys
    ``access_token``, ``refresh_token``, ``token_type``, ``resource_url`` and
    ``expiry_date`` (Unix epoch time in milliseconds).
    """

    def __init__(self, creds_path: Path | None = None) -> None:
        """Use *creds_path*, defaulting to ``~/.qwen/oauth_creds.json``.

        The parent directory is created immediately.
        """
        self.creds_path = creds_path or Path.home() / ".qwen" / "oauth_creds.json"
        self.creds_path.parent.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ HTTP

    def _post_form(self, url: str, payload: dict[str, str]) -> dict[str, Any]:
        """POST *payload* form-encoded to *url* and return the decoded JSON body.

        Raises:
            OAuthError: when the server answers with an HTTP error status.
                The exception also carries ``error_code`` (the OAuth ``error``
                field, or ``None``) and ``status`` (the HTTP status code).
                Other network failures (``urllib.error.URLError``) propagate
                unchanged.
        """
        req = request.Request(
            url,
            data=parse.urlencode(payload).encode("utf-8"),
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
                "x-request-id": str(uuid.uuid4()),
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=60) as response:
                return json.loads(response.read().decode("utf-8"))
        except error.HTTPError as exc:
            raise self._oauth_error_from_http(exc) from exc

    @staticmethod
    def _oauth_error_from_http(exc: error.HTTPError) -> OAuthError:
        """Translate an HTTP error response into an ``OAuthError``."""
        raw = exc.read().decode("utf-8", errors="replace")
        try:
            details = json.loads(raw)
        except json.JSONDecodeError:
            details = None
        if isinstance(details, dict):
            code = details.get("error")
            failure = OAuthError(details.get("error_description") or code or raw)
        else:
            # Body is not a JSON object, so there is no structured error to read.
            code = None
            failure = OAuthError(f"HTTP {exc.code}: {raw}")
        # Attach the machine-readable OAuth error code so callers never have
        # to infer it from the human-readable message.
        failure.error_code = code if isinstance(code, str) else None
        failure.status = exc.code
        return failure

    @staticmethod
    def _has_error_code(exc: BaseException, code: str) -> bool:
        """Return True if *exc* reports the OAuth error *code*.

        Prefers the structured ``error_code`` attribute. Falls back to a
        substring search of the message for exceptions that lack it.
        """
        return getattr(exc, "error_code", None) == code or code in str(exc)

    # ----------------------------------------------------------- persistence

    def load_credentials(self) -> dict[str, Any] | None:
        """Return the stored credentials, or ``None`` if none are saved."""
        try:
            text = self.creds_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def save_credentials(self, payload: dict[str, Any]) -> None:
        """Atomically replace the stored credentials with *payload*.

        The JSON is written to a temporary file created by ``tempfile.mkstemp``
        (owner read/write only), then moved over ``creds_path`` with
        ``os.replace``. Readers therefore never see a half-written file, and
        the tokens are not left with default, possibly world-readable,
        permissions.
        """
        target = self.creds_path
        target.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(
            dir=target.parent, prefix=f".{target.name}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as handle:
                json.dump(payload, handle, ensure_ascii=True, indent=2)
            os.replace(tmp_name, target)
        except BaseException:
            Path(tmp_name).unlink(missing_ok=True)
            raise

    def clear_credentials(self) -> None:
        """Delete the stored credentials; a missing file is not an error."""
        self.creds_path.unlink(missing_ok=True)

    # ----------------------------------------------------------- device flow

    @staticmethod
    def _generate_pkce_pair() -> tuple[str, str]:
        """Return a fresh PKCE ``(code_verifier, code_challenge)`` pair (S256)."""
        verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("ascii").rstrip("=")
        digest = hashlib.sha256(verifier.encode("utf-8")).digest()
        challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
        return verifier, challenge

    def start_device_flow(self, open_browser: bool = False) -> DeviceAuthState:
        """Request a device code and return the state needed to poll for tokens.

        Args:
            open_browser: when True, try to open the verification URI in a
                browser. Failure to do so is ignored.

        Raises:
            OAuthError: if the device-code request is rejected.
        """
        code_verifier, code_challenge = self._generate_pkce_pair()
        response = self._post_form(
            QWEN_DEVICE_CODE_ENDPOINT,
            {
                "client_id": QWEN_CLIENT_ID,
                "scope": QWEN_SCOPE,
                "code_challenge": code_challenge,
                "code_challenge_method": "S256",
            },
        )
        verification_uri = response["verification_uri"]
        state = DeviceAuthState(
            device_code=response["device_code"],
            code_verifier=code_verifier,
            user_code=response["user_code"],
            verification_uri=verification_uri,
            # RFC 8628 makes verification_uri_complete OPTIONAL.
            verification_uri_complete=response.get("verification_uri_complete") or verification_uri,
            expires_at=time.time() + float(response.get("expires_in") or 600),
            # Honour a server-supplied minimum polling interval (RFC 8628 §3.2).
            interval_seconds=float(response.get("interval") or 2.0),
        )
        if open_browser:
            try:
                webbrowser.open(state.verification_uri_complete)
            except Exception:
                # Best effort: the caller can still show the URI to the user.
                pass
        return state

    def poll_device_flow(self, state: DeviceAuthState) -> dict[str, Any] | None:
        """Poll the token endpoint once.

        Returns:
            The saved credentials once the user has approved, or ``None`` while
            authorization is still pending. On ``slow_down``,
            ``state.interval_seconds`` is increased before returning ``None``.

        Raises:
            OAuthError: if the device code has expired or the server reports
                any error other than ``authorization_pending``/``slow_down``.
        """
        if time.time() >= state.expires_at:
            raise OAuthError("Device authorization expired")
        try:
            response = self._post_form(
                QWEN_TOKEN_ENDPOINT,
                {
                    "grant_type": QWEN_DEVICE_GRANT,
                    "client_id": QWEN_CLIENT_ID,
                    "device_code": state.device_code,
                    "code_verifier": state.code_verifier,
                },
            )
        except OAuthError as exc:
            if self._has_error_code(exc, "authorization_pending"):
                return None
            if self._has_error_code(exc, "slow_down"):
                # Back off by 1.5x, capped at 10s. Never lower an interval that
                # is already above the cap (e.g. one the server prescribed).
                state.interval_seconds = max(
                    state.interval_seconds, min(state.interval_seconds * 1.5, 10.0)
                )
                return None
            raise

        creds = self._credentials_from_token_response(response)
        self.save_credentials(creds)
        return creds

    # ---------------------------------------------------------------- tokens

    @staticmethod
    def _credentials_from_token_response(
        response: dict[str, Any], previous: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Build the persisted credential dict from a token-endpoint response.

        ``refresh_token`` and ``resource_url`` fall back to *previous* when the
        response omits them, since a refresh response may not repeat them.
        """
        prior = previous or {}
        lifetime = response.get("expires_in")
        lifetime_s = 3600 if lifetime is None else int(lifetime)
        return {
            "access_token": response["access_token"],
            "refresh_token": response.get("refresh_token") or prior.get("refresh_token"),
            "token_type": response.get("token_type") or "Bearer",
            "resource_url": response.get("resource_url") or prior.get("resource_url"),
            "expiry_date": int(time.time() * 1000) + lifetime_s * 1000,
        }

    def refresh_credentials(self, creds: dict[str, Any]) -> dict[str, Any]:
        """Exchange the refresh token in *creds* for new credentials and save them.

        Raises:
            OAuthError: if *creds* has no refresh token or the server rejects it.
        """
        refresh_token = creds.get("refresh_token")
        if not refresh_token:
            raise OAuthError("No refresh token available")
        response = self._post_form(
            QWEN_TOKEN_ENDPOINT,
            {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": QWEN_CLIENT_ID,
            },
        )
        updated = self._credentials_from_token_response(response, previous=creds)
        self.save_credentials(updated)
        return updated

    def get_valid_credentials(self) -> dict[str, Any]:
        """Return stored credentials, refreshing them if they are close to expiry.

        Credentials with no ``expiry_date``, or fewer than 30 seconds left,
        are refreshed.

        Raises:
            OAuthError: if nothing is stored or the refresh fails.
        """
        creds = self.load_credentials()
        if not creds:
            raise OAuthError("Qwen OAuth is not configured. Start device flow first.")
        expiry_ms = int(creds.get("expiry_date") or 0)
        if expiry_ms and expiry_ms - int(time.time() * 1000) > 30_000:
            return creds
        return self.refresh_credentials(creds)

    def get_openai_base_url(self, creds: dict[str, Any]) -> str:
        """Return the OpenAI-compatible base URL (always ending in ``/v1``).

        Uses ``creds["resource_url"]`` when present, otherwise the DashScope
        compatible-mode endpoint. A bare host gets an ``https://`` scheme.
        """
        base = (creds.get("resource_url") or "https://dashscope.aliyuncs.com/compatible-mode").rstrip("/")
        # Require a real scheme separator: a host such as "httpbin.example"
        # merely starts with "http" and still needs a scheme prepended.
        if not base.lower().startswith(("http://", "https://")):
            base = f"https://{base}"
        return base if base.endswith("/v1") else f"{base}/v1"
|