213 lines
8.1 KiB
Python
213 lines
8.1 KiB
Python
from __future__ import annotations

import base64
import hashlib
import http.client
import json
import os
import secrets
import tempfile
import time
import uuid
import webbrowser
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib import parse
|
|
|
|
|
|
# Host that serves the Qwen OAuth 2.0 endpoints.
QWEN_OAUTH_BASE_URL = "https://chat.qwen.ai"
# Device-authorization endpoint: issues device_code / user_code pairs.
QWEN_DEVICE_CODE_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/device/code"
# Token endpoint: used both for device-code polling and refresh-token grants.
QWEN_TOKEN_ENDPOINT = QWEN_OAUTH_BASE_URL + "/api/v1/oauth2/token"
# Client identifier sent with every OAuth request.
QWEN_CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"
# Space-separated scope list requested during the device flow.
QWEN_SCOPE = " ".join(("openid", "profile", "email", "model.completion"))
# grant_type value for polling the token endpoint with a device code.
QWEN_DEVICE_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
# API base used when the token response carries no resource_url.
QWEN_DEFAULT_RESOURCE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
|
|
|
# Hard-coded Qwen OAuth models. This list is the single source of truth for
# the model catalogue and for the allow-list derived from it below.
QWEN_OAUTH_MODELS = [
    dict(
        id="coder-model",
        name="coder-model",
        description="Qwen 3.6 Plus — efficient hybrid model with leading coding performance",
    ),
]

# Model IDs permitted for Qwen OAuth, in catalogue order.
QWEN_OAUTH_ALLOWED_MODELS = [entry["id"] for entry in QWEN_OAUTH_MODELS]
|
|
|
|
|
|
class OAuthError(RuntimeError):
    """Raised when a Qwen OAuth request or credential operation fails."""
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class DeviceAuthState:
|
|
device_code: str
|
|
code_verifier: str
|
|
user_code: str
|
|
verification_uri: str
|
|
verification_uri_complete: str
|
|
expires_at: float
|
|
interval_seconds: float
|
|
|
|
|
|
class QwenOAuthManager:
    """Run the Qwen OAuth device flow and manage the cached credentials file.

    Credentials are stored as JSON at ``creds_path`` (default
    ``~/.qwen/oauth_creds.json``); ``expiry_date`` in that file is an absolute
    epoch timestamp in milliseconds.
    """

    def __init__(self, creds_path: Path | None = None) -> None:
        """Use ``creds_path`` (or ``~/.qwen/oauth_creds.json``) and ensure its directory exists."""
        self.creds_path = creds_path or Path.home() / ".qwen" / "oauth_creds.json"
        # The directory holds bearer/refresh tokens, so create it owner-only.
        # (``mode`` only applies when the directory is newly created.)
        self.creds_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True)

    @staticmethod
    def _format_http_error(status: int, raw: bytes) -> str:
        """Build the ``OAuthError`` message for an HTTP error response.

        Both the machine-readable OAuth ``error`` code and the human-readable
        ``error_description`` are kept, so callers can match on codes such as
        ``authorization_pending`` even when the server also sends a description.
        Non-JSON bodies are reported verbatim (undecodable bytes replaced).
        """
        text = raw.decode("utf-8", errors="replace")
        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            return f"HTTP {status}: {text}"
        if not isinstance(payload, dict):
            return f"HTTP {status}: {text}"
        parts = [str(part) for part in (payload.get("error"), payload.get("error_description")) if part]
        return f"HTTP {status}: {': '.join(parts) if parts else text}"

    @staticmethod
    def _require(payload: dict[str, Any], key: str) -> Any:
        """Return ``payload[key]``, raising ``OAuthError`` if it is missing or empty."""
        value = payload.get(key)
        if not value:
            raise OAuthError(f"Malformed OAuth response: missing '{key}'")
        return value

    @classmethod
    def _build_credentials(
        cls,
        response: dict[str, Any],
        *,
        fallback_refresh_token: str | None = None,
        fallback_resource_url: str | None = None,
    ) -> dict[str, Any]:
        """Convert a token-endpoint response into the stored credentials dict.

        ``expiry_date`` is computed as now + ``expires_in`` seconds, in epoch
        milliseconds; a missing or null ``expires_in`` counts as 3600 seconds.
        Missing ``refresh_token`` / ``resource_url`` fall back to the given values.

        Raises:
            OAuthError: if the response carries no ``access_token``.
        """
        expires_in = int(response.get("expires_in") or 3600)
        return {
            "access_token": cls._require(response, "access_token"),
            "refresh_token": response.get("refresh_token") or fallback_refresh_token,
            "token_type": response.get("token_type") or "Bearer",
            "resource_url": response.get("resource_url") or fallback_resource_url,
            "expiry_date": int(time.time() * 1000) + expires_in * 1000,
        }

    def _post_form(self, url: str, payload: dict[str, str]) -> dict[str, Any]:
        """POST ``payload`` form-encoded to ``url`` and return the decoded JSON object.

        ``http.client`` is used instead of ``urllib.request`` to avoid WAF blocking.

        Raises:
            OAuthError: on a transport failure, an HTTP status >= 400, or a
                response body that is not a JSON object.
        """
        parsed = parse.urlparse(url)
        # http.client wants the full request target, including any query string.
        target = parsed.path or "/"
        if parsed.query:
            target = f"{target}?{parsed.query}"
        # Standard form encoding (spaces become "+").
        body = parse.urlencode(payload)
        headers = {
            "Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
            "Accept": "application/json",
            "x-request-id": str(uuid.uuid4()),
        }

        conn = http.client.HTTPSConnection(parsed.netloc, timeout=60)
        try:
            conn.request("POST", target, body=body, headers=headers)
            response = conn.getresponse()
            status = response.status
            raw = response.read()
        except (OSError, http.client.HTTPException) as exc:
            raise OAuthError(f"Request to {url} failed: {exc}") from exc
        finally:
            conn.close()

        if status >= 400:
            raise OAuthError(self._format_http_error(status, raw))
        try:
            data = json.loads(raw.decode("utf-8"))
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            raise OAuthError(f"HTTP {status}: response is not valid JSON") from exc
        if not isinstance(data, dict):
            raise OAuthError(f"HTTP {status}: expected a JSON object, got {type(data).__name__}")
        return data

    def load_credentials(self) -> dict[str, Any] | None:
        """Return the stored credentials, or ``None`` if no credentials file exists."""
        try:
            text = self.creds_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def save_credentials(self, payload: dict[str, Any]) -> None:
        """Persist ``payload`` as JSON at ``creds_path``.

        The data is written to a temporary file in the same directory (created
        with owner-only permissions, since it holds tokens) and then moved into
        place with ``os.replace``, so an interrupted write never leaves a
        truncated credentials file behind.
        """
        # Serialize first so a non-serializable payload leaves no temp file.
        text = json.dumps(payload, ensure_ascii=True, indent=2)
        directory = self.creds_path.parent
        directory.mkdir(mode=0o700, parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(dir=directory, prefix=f".{self.creds_path.name}.", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as handle:
                handle.write(text)
            os.replace(tmp_name, self.creds_path)
        except BaseException:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    def clear_credentials(self) -> None:
        """Delete the credentials file if present (no error when it is absent)."""
        self.creds_path.unlink(missing_ok=True)

    def start_device_flow(self, open_browser: bool = False) -> DeviceAuthState:
        """Request a device code with a PKCE S256 challenge and return the flow state.

        If ``open_browser`` is true the verification URL is opened on a
        best-effort basis; a browser failure never aborts the flow.

        Raises:
            OAuthError: if the request fails or the response lacks required fields.
        """
        # PKCE: 32 random bytes -> unpadded base64url verifier; challenge = base64url(SHA-256).
        code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("ascii").rstrip("=")
        digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
        code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
        response = self._post_form(
            QWEN_DEVICE_CODE_ENDPOINT,
            {
                "client_id": QWEN_CLIENT_ID,
                "scope": QWEN_SCOPE,
                "code_challenge": code_challenge,
                "code_challenge_method": "S256",
            },
        )
        verification_uri = self._require(response, "verification_uri")
        state = DeviceAuthState(
            device_code=self._require(response, "device_code"),
            code_verifier=code_verifier,
            user_code=self._require(response, "user_code"),
            verification_uri=verification_uri,
            # verification_uri_complete is optional in RFC 8628; fall back to the plain URI.
            verification_uri_complete=response.get("verification_uri_complete") or verification_uri,
            expires_at=time.time() + float(response.get("expires_in") or 600),
            interval_seconds=2.0,
        )
        if open_browser:
            try:
                webbrowser.open(state.verification_uri_complete)
            except Exception:  # deliberate best-effort: the caller can still show the URL
                pass
        return state

    def poll_device_flow(self, state: DeviceAuthState) -> dict[str, Any] | None:
        """Poll the token endpoint once for ``state``.

        Returns:
            The saved credentials once the request is approved, or ``None`` while
            authorization is pending (including after ``slow_down``, which also
            lengthens ``state.interval_seconds`` up to 10 seconds).

        Raises:
            OAuthError: if the device code has expired or any other error occurs.
        """
        if time.time() >= state.expires_at:
            raise OAuthError("Device authorization expired")
        try:
            response = self._post_form(
                QWEN_TOKEN_ENDPOINT,
                {
                    "grant_type": QWEN_DEVICE_GRANT,
                    "client_id": QWEN_CLIENT_ID,
                    "device_code": state.device_code,
                    "code_verifier": state.code_verifier,
                },
            )
        except OAuthError as exc:
            text = str(exc)
            # _post_form keeps the OAuth "error" code in the message, so the
            # RFC 8628 polling codes can be matched even alongside a description.
            if "authorization_pending" in text or "not yet approved" in text.lower():
                return None
            if "slow_down" in text:
                state.interval_seconds = min(state.interval_seconds * 1.5, 10.0)
                return None
            raise

        creds = self._build_credentials(response)
        self.save_credentials(creds)
        return creds

    def refresh_credentials(self, creds: dict[str, Any]) -> dict[str, Any]:
        """Exchange the refresh token in ``creds`` for new credentials and save them.

        A refresh token or resource URL missing from the response is carried
        over from ``creds``.

        Raises:
            OAuthError: if ``creds`` has no refresh token or the request fails.
        """
        refresh_token = creds.get("refresh_token")
        if not refresh_token:
            raise OAuthError("No refresh token available")
        response = self._post_form(
            QWEN_TOKEN_ENDPOINT,
            {
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": QWEN_CLIENT_ID,
            },
        )
        updated = self._build_credentials(
            response,
            fallback_refresh_token=refresh_token,
            fallback_resource_url=creds.get("resource_url"),
        )
        self.save_credentials(updated)
        return updated

    def get_valid_credentials(self) -> dict[str, Any]:
        """Return stored credentials, refreshing them if they expire within 30 seconds.

        Raises:
            OAuthError: if no credentials are stored or the refresh fails.
        """
        creds = self.load_credentials()
        if not creds:
            raise OAuthError("Qwen OAuth is not configured. Start device flow first.")
        # expiry_date is epoch milliseconds; a null/missing value forces a refresh.
        expiry = int(creds.get("expiry_date") or 0)
        if expiry and expiry - int(time.time() * 1000) > 30_000:
            return creds
        return self.refresh_credentials(creds)

    def get_openai_base_url(self, creds: dict[str, Any]) -> str:
        """Return the OpenAI-compatible base URL (ending in ``/v1``) for ``creds``.

        Uses ``creds["resource_url"]`` when set, else the default DashScope URL;
        a bare host gets an ``https://`` scheme and trailing slashes are dropped
        before ``/v1`` is appended (so ``.../v1/`` does not become ``.../v1/v1``).
        """
        resource_url = (creds.get("resource_url") or QWEN_DEFAULT_RESOURCE_URL).strip().rstrip("/")
        if not resource_url.lower().startswith(("http://", "https://")):
            resource_url = f"https://{resource_url}"
        if not resource_url.endswith("/v1"):
            resource_url += "/v1"
        return resource_url

    def get_available_models(self) -> list[dict[str, str]]:
        """Return the list of available Qwen OAuth models (hard-coded, single source of truth).

        Each entry is copied so callers cannot mutate the shared module-level table.
        """
        return [dict(model) for model in QWEN_OAUTH_MODELS]

    def get_model_name_for_id(self, model_id: str) -> str:
        """Get the actual model name for a given model ID. Returns "coder-model" if not found."""
        return next(
            (model["name"] for model in QWEN_OAUTH_MODELS if model["id"] == model_id),
            "coder-model",
        )
|