new-qwen/serv/oauth.py

213 lines
8.1 KiB
Python

from __future__ import annotations
import base64
import hashlib
import http.client
import json
import secrets
import time
import uuid
import webbrowser
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib import parse
# Endpoints and client parameters used by the Qwen OAuth device flow.
QWEN_OAUTH_BASE_URL = "https://chat.qwen.ai"
QWEN_DEVICE_CODE_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/device/code"
QWEN_TOKEN_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/token"
QWEN_CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"
QWEN_SCOPE = "openid profile email model.completion"
QWEN_DEVICE_GRANT = "urn:ietf:params:oauth:grant-type:device_code"

# Base URL used when stored credentials carry no ``resource_url``.
QWEN_DEFAULT_RESOURCE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# Hard-coded Qwen OAuth models (single source of truth).
QWEN_OAUTH_MODELS = [
    {
        "id": "coder-model",
        "name": "coder-model",
        "description": "Qwen 3.6 Plus — efficient hybrid model with leading coding performance",
    },
]

# Model ids accepted for OAuth-backed requests, derived from the table above.
QWEN_OAUTH_ALLOWED_MODELS = [entry["id"] for entry in QWEN_OAUTH_MODELS]
class OAuthError(RuntimeError):
    """Error raised when a Qwen OAuth request or flow step fails."""
@dataclass(slots=True)
class DeviceAuthState:
device_code: str
code_verifier: str
user_code: str
verification_uri: str
verification_uri_complete: str
expires_at: float
interval_seconds: float
class QwenOAuthManager:
    """Run the Qwen OAuth device flow and persist the resulting credentials.

    Credentials are stored as a JSON object in ``creds_path`` with the keys
    ``access_token``, ``refresh_token``, ``token_type``, ``resource_url`` and
    ``expiry_date`` (milliseconds since the epoch).
    """

    def __init__(self, creds_path: Path | None = None) -> None:
        """Remember the credentials path and make sure its directory exists.

        Args:
            creds_path: JSON file used to persist tokens. Defaults to
                ``~/.qwen/oauth_creds.json``.
        """
        self.creds_path = creds_path or Path.home() / ".qwen" / "oauth_creds.json"
        self.creds_path.parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _http_error(status: int, raw: bytes) -> OAuthError:
        """Build the ``OAuthError`` for an HTTP error response.

        The message prefers ``error_description``, then ``error``, then the
        raw body. The OAuth ``error`` code (or ``None``) and the HTTP status
        are attached to the exception as ``error_code`` and ``status`` so
        callers can branch on them without parsing the message.
        """
        # Lenient decoding: a malformed body must still produce an OAuthError.
        text = raw.decode("utf-8", errors="replace")
        message = text
        error_code = None
        try:
            parsed_body = json.loads(text)
        except json.JSONDecodeError:
            parsed_body = None
        # Only a JSON object can carry OAuth error fields; arrays, strings
        # and numbers fall back to the raw text.
        if isinstance(parsed_body, dict):
            code = parsed_body.get("error")
            if isinstance(code, str):
                error_code = code
            message = parsed_body.get("error_description") or code or text
        error = OAuthError(f"HTTP {status}: {message}")
        error.error_code = error_code
        error.status = status
        return error

    @staticmethod
    def _build_credentials(
        response: dict[str, Any],
        previous: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Convert a token-endpoint response into the stored credential dict.

        Args:
            response: Decoded token response; must contain ``access_token``.
            previous: Earlier credentials whose ``refresh_token`` and
                ``resource_url`` are kept when the response omits them.

        Raises:
            KeyError: If ``access_token`` is missing from ``response``.
        """
        fallback = previous or {}
        now_ms = int(time.time() * 1000)
        return {
            "access_token": response["access_token"],
            "refresh_token": response.get("refresh_token") or fallback.get("refresh_token"),
            "token_type": response.get("token_type", "Bearer"),
            "resource_url": response.get("resource_url") or fallback.get("resource_url"),
            "expiry_date": now_ms + int(response.get("expires_in", 3600)) * 1000,
        }

    def _post_form(self, url: str, payload: dict[str, str]) -> dict[str, Any]:
        """POST ``payload`` as a URL-encoded form and return the decoded JSON.

        Raises:
            OAuthError: If the server answers with an HTTP status >= 400.
        """
        # Use http.client instead of urllib.request to avoid WAF blocking.
        parsed = parse.urlparse(url)
        path = parsed.path or "/"
        if parsed.query:
            # Keep the query string; it is not part of ``parsed.path``.
            path = f"{path}?{parsed.query}"
        # Standard urlencode (spaces become "+").
        body = parse.urlencode(payload)
        headers = {
            "Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
            "Accept": "application/json",
            "x-request-id": str(uuid.uuid4()),
        }
        conn = http.client.HTTPSConnection(parsed.netloc, timeout=60)
        try:
            conn.request("POST", path, body=body, headers=headers)
            response = conn.getresponse()
            status = response.status
            raw = response.read()
        finally:
            conn.close()
        if status >= 400:
            raise self._http_error(status, raw)
        return json.loads(raw.decode("utf-8"))

    def load_credentials(self) -> dict[str, Any] | None:
        """Return the stored credentials, or ``None`` if the file is absent."""
        try:
            text = self.creds_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def save_credentials(self, payload: dict[str, Any]) -> None:
        """Write ``payload`` as JSON to the credentials file.

        The file holds bearer and refresh tokens, so it is restricted to the
        owner (mode ``0o600``) before any secret is written to it.
        """
        serialized = json.dumps(payload, ensure_ascii=True, indent=2)
        # Create the file owner-only if it is new ...
        self.creds_path.touch(mode=0o600, exist_ok=True)
        try:
            # ... and tighten a file left behind with looser permissions.
            self.creds_path.chmod(0o600)
        except OSError:
            # Best effort: some filesystems do not support chmod; still save.
            pass
        self.creds_path.write_text(serialized, encoding="utf-8")

    def clear_credentials(self) -> None:
        """Delete the credentials file; do nothing if it does not exist."""
        self.creds_path.unlink(missing_ok=True)

    def start_device_flow(self, open_browser: bool = False) -> DeviceAuthState:
        """Request a device code (with PKCE) and return the polling state.

        Args:
            open_browser: When true, try to open the verification URL in the
                user's browser; failures to do so are ignored.

        Raises:
            OAuthError: If the device-code endpoint returns an HTTP error.
            KeyError: If the response lacks a required field.
        """
        # PKCE: unpadded base64url of 32 random bytes, challenged with S256.
        code_verifier = secrets.token_urlsafe(32)
        digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
        code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
        response = self._post_form(
            QWEN_DEVICE_CODE_ENDPOINT,
            {
                "client_id": QWEN_CLIENT_ID,
                "scope": QWEN_SCOPE,
                "code_challenge": code_challenge,
                "code_challenge_method": "S256",
            },
        )
        verification_uri = response["verification_uri"]
        state = DeviceAuthState(
            device_code=response["device_code"],
            code_verifier=code_verifier,
            user_code=response["user_code"],
            verification_uri=verification_uri,
            # Fall back to the plain URI when the complete one is omitted.
            verification_uri_complete=response.get("verification_uri_complete") or verification_uri,
            expires_at=time.time() + float(response.get("expires_in", 600)),
            interval_seconds=2.0,
        )
        if open_browser:
            try:
                webbrowser.open(state.verification_uri_complete)
            except Exception:
                # Opening a browser is a convenience only; the caller can
                # still show the URL and user code from the returned state.
                pass
        return state

    def poll_device_flow(self, state: DeviceAuthState) -> dict[str, Any] | None:
        """Poll the token endpoint once for the given device flow.

        Returns:
            The saved credential dict once the user has approved the request,
            or ``None`` if the caller should poll again after
            ``state.interval_seconds`` (which grows on ``slow_down``).

        Raises:
            OAuthError: If the flow has expired or the server reports any
                error other than "pending" or "slow down".
        """
        if time.time() >= state.expires_at:
            raise OAuthError("Device authorization expired")
        try:
            response = self._post_form(
                QWEN_TOKEN_ENDPOINT,
                {
                    "grant_type": QWEN_DEVICE_GRANT,
                    "client_id": QWEN_CLIENT_ID,
                    "device_code": state.device_code,
                    "code_verifier": state.code_verifier,
                },
            )
        except OAuthError as exc:
            # Prefer the structured error code: the message shows only the
            # description when the server sends one, hiding the code.
            error_code = getattr(exc, "error_code", None)
            text = str(exc)
            pending = (
                error_code == "authorization_pending"
                or "authorization_pending" in text
                or "not yet approved" in text.lower()
            )
            if pending:
                return None
            if error_code == "slow_down" or "slow_down" in text:
                state.interval_seconds = min(state.interval_seconds * 1.5, 10.0)
                return None
            raise
        creds = self._build_credentials(response)
        self.save_credentials(creds)
        return creds

    def refresh_credentials(self, creds: dict[str, Any]) -> dict[str, Any]:
        """Exchange the refresh token for new credentials and save them.

        Raises:
            OAuthError: If ``creds`` has no refresh token or the token
                endpoint returns an HTTP error.
        """
        refresh_token = creds.get("refresh_token")
        if not refresh_token:
            raise OAuthError("No refresh token available")
        response = self._post_form(
            QWEN_TOKEN_ENDPOINT,
            {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": QWEN_CLIENT_ID,
            },
        )
        # Keep the old refresh token / resource URL if the server omits them.
        updated = self._build_credentials(response, creds)
        self.save_credentials(updated)
        return updated

    def get_valid_credentials(self) -> dict[str, Any]:
        """Return stored credentials, refreshing them when close to expiry.

        Credentials are reused while more than 30 seconds remain before
        ``expiry_date``; otherwise they are refreshed.

        Raises:
            OAuthError: If no credentials are stored or the refresh fails.
        """
        creds = self.load_credentials()
        if not creds:
            raise OAuthError("Qwen OAuth is not configured. Start device flow first.")
        # A missing or null expiry counts as expired.
        expiry = int(creds.get("expiry_date") or 0)
        if expiry and expiry - int(time.time() * 1000) > 30_000:
            return creds
        return self.refresh_credentials(creds)

    def get_openai_base_url(self, creds: dict[str, Any]) -> str:
        """Return the base URL for ``creds``, always ending in ``/v1``.

        Uses ``creds["resource_url"]`` when set and the default resource URL
        otherwise, adding an ``https://`` scheme if none is present.
        """
        resource_url = creds.get("resource_url") or QWEN_DEFAULT_RESOURCE_URL
        if not resource_url.startswith(("http://", "https://")):
            resource_url = f"https://{resource_url}"
        # Strip trailing slashes first so ".../v1/" is not turned into
        # ".../v1/v1".
        resource_url = resource_url.rstrip("/")
        if resource_url.endswith("/v1"):
            return resource_url
        return resource_url + "/v1"

    def get_available_models(self) -> list[dict[str, str]]:
        """Return the list of available Qwen OAuth models (hard-coded, single source of truth)."""
        # Copy each entry so callers cannot mutate the shared model table.
        return [dict(model) for model in QWEN_OAUTH_MODELS]

    def get_model_name_for_id(self, model_id: str) -> str:
        """Get the actual model name for a given model ID. Returns default if not found."""
        for model in QWEN_OAUTH_MODELS:
            if model["id"] == model_id:
                return model["name"]
        # Default to coder-model if unknown
        return "coder-model"