# new-qwen/serv/web_search.py

from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any
from urllib import parse, request
from config import ServerConfig
from oauth import QwenOAuthManager
class WebSearchError(RuntimeError):
    """Raised when a web search cannot be carried out.

    Used both for "no usable provider" situations and for failures reported by
    (or occurring inside) an individual provider.
    """
@dataclass(slots=True)
class WebSearchResultItem:
title: str
url: str
content: str = ""
score: float | None = None
published_date: str | None = None
@dataclass(slots=True)
class WebSearchResult:
query: str
provider: str
answer: str
results: list[WebSearchResultItem]
class BaseWebSearchProvider:
    """Common base for web search backends.

    Subclasses set ``name`` and implement ``is_available`` and ``perform_search``.
    Callers go through ``search``, which turns every failure into a
    ``WebSearchError`` tagged with the provider name.
    """

    name = "base"

    def is_available(self) -> bool:
        """Report whether this provider is configured well enough to be used."""
        raise NotImplementedError

    def search(self, query: str) -> WebSearchResult:
        """Run *query* through this provider.

        Raises:
            WebSearchError: if the provider is unavailable, if ``perform_search``
                raises one itself (re-raised unchanged), or wrapping any other
                exception from ``perform_search`` (original kept as ``__cause__``).
        """
        if self.is_available():
            try:
                return self.perform_search(query)
            except WebSearchError:
                # Already a provider-tagged error; let it through untouched.
                raise
            except Exception as failure:
                raise WebSearchError(f"[{self.name}] Search failed: {failure}") from failure
        raise WebSearchError(f"[{self.name}] Provider is not available")

    def perform_search(self, query: str) -> WebSearchResult:
        """Provider-specific search implementation; must be overridden."""
        raise NotImplementedError
class DashScopeWebSearchProvider(BaseWebSearchProvider):
    """Search provider using the DashScope web-search plugin endpoint.

    The endpoint host comes from the Qwen OAuth credentials' ``resource_url``,
    and requests carry the credentials' ``access_token`` as a Bearer token.
    """

    name = "dashscope"

    def __init__(self, oauth: QwenOAuthManager) -> None:
        self.oauth = oauth

    def is_available(self) -> bool:
        """Return True when stored credentials include a ``resource_url``.

        Best-effort probe: any error while loading credentials is treated as
        "not available" rather than propagated.
        """
        try:
            creds = self.oauth.load_credentials()
        except Exception:
            return False
        return bool(creds and creds.get("resource_url"))

    def perform_search(self, query: str) -> WebSearchResult:
        """POST *query* to the DashScope web-search plugin and normalise the reply.

        Raises:
            WebSearchError: when the credentials lack a token or resource URL, or
                when the API responds with a non-zero ``status``.
        """
        # Guard a falsy return so callers get the explicit credentials error
        # below instead of an AttributeError on ``None.get``.
        creds = self.oauth.get_valid_credentials() or {}
        access_token = creds.get("access_token")
        resource_url = creds.get("resource_url")
        if not access_token or not resource_url:
            raise WebSearchError("[dashscope] Qwen OAuth credentials are not available")

        base_url = str(resource_url).strip()
        # Only treat a real scheme as a scheme: a bare host such as
        # "httpapi.example.com" still needs "https://" prepended.
        if not base_url.lower().startswith(("http://", "https://")):
            base_url = f"https://{base_url}"
        api_endpoint = base_url.rstrip("/") + "/api/v1/indices/plugin/web_search"

        payload = json.dumps({"uq": query, "page": 1, "rows": 10}).encode("utf-8")
        req = request.Request(
            api_endpoint,
            data=payload,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {access_token}",
            },
            method="POST",
        )
        with request.urlopen(req, timeout=60) as response:
            data = json.loads(response.read().decode("utf-8"))

        if data.get("status") != 0:
            raise WebSearchError(
                f"[dashscope] API error: {data.get('message') or 'unknown error'}"
            )

        docs = (data.get("data") or {}).get("docs") or []
        results = [
            WebSearchResultItem(
                title=item.get("title") or "Untitled",
                url=item.get("url") or "",
                content=item.get("snippet") or "",
                score=item.get("_score"),
                published_date=item.get("timestamp_format"),
            )
            for item in docs
        ]
        # This endpoint yields only a hit list; no synthesised answer is mapped.
        return WebSearchResult(
            query=query,
            provider=self.name,
            answer="",
            results=results,
        )
class TavilyWebSearchProvider(BaseWebSearchProvider):
    """Search provider backed by the Tavily HTTP search API.

    This provider asks Tavily for a synthesised answer (``include_answer``) in
    addition to the hit list.
    """

    name = "tavily"

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def is_available(self) -> bool:
        # Only checks that a key is configured; the key is not validated here.
        return bool(self.api_key)

    def perform_search(self, query: str) -> WebSearchResult:
        """POST *query* to Tavily and convert the JSON reply into a WebSearchResult."""
        body = {
            "api_key": self.api_key,
            "query": query,
            "search_depth": "advanced",
            "max_results": 5,
            "include_answer": True,
        }
        http_request = request.Request(
            "https://api.tavily.com/search",
            data=json.dumps(body).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with request.urlopen(http_request, timeout=60) as reply:
            raw = reply.read()
        reply_data = json.loads(raw.decode("utf-8"))

        hits: list[WebSearchResultItem] = []
        for entry in reply_data.get("results") or []:
            hits.append(
                WebSearchResultItem(
                    title=entry.get("title") or "Untitled",
                    url=entry.get("url") or "",
                    content=entry.get("content") or "",
                    score=entry.get("score"),
                    published_date=entry.get("published_date"),
                )
            )

        summary = reply_data.get("answer") or ""
        return WebSearchResult(
            query=query,
            provider=self.name,
            answer=summary.strip(),
            results=hits,
        )
class GoogleWebSearchProvider(BaseWebSearchProvider):
    """Search provider backed by the Google Custom Search JSON API.

    Requires both an API key and a search engine id (sent as ``cx``).
    """

    name = "google"

    def __init__(self, api_key: str, search_engine_id: str) -> None:
        self.api_key = api_key
        self.search_engine_id = search_engine_id

    def is_available(self) -> bool:
        # Both credentials must be non-empty for the provider to be usable.
        return bool(self.api_key) and bool(self.search_engine_id)

    def perform_search(self, query: str) -> WebSearchResult:
        """GET the Custom Search endpoint for *query* and map ``items`` to hits.

        Only title, link and snippet are mapped; score, publication date and
        answer are left at their defaults / empty.
        """
        query_string = parse.urlencode(
            dict(
                key=self.api_key,
                cx=self.search_engine_id,
                q=query,
                num="10",
                safe="medium",
            )
        )
        endpoint = "https://www.googleapis.com/customsearch/v1?" + query_string
        with request.urlopen(endpoint, timeout=60) as reply:
            raw = reply.read()
        reply_data = json.loads(raw.decode("utf-8"))

        hits = list(
            map(
                lambda entry: WebSearchResultItem(
                    title=entry.get("title") or "Untitled",
                    url=entry.get("link") or "",
                    content=entry.get("snippet") or "",
                ),
                reply_data.get("items") or [],
            )
        )
        return WebSearchResult(
            query=query,
            provider=self.name,
            answer="",
            results=hits,
        )
class WebSearchService:
    """Route web-search queries to a registered provider and format the outcome.

    Providers are registered once at construction. Availability is re-evaluated
    on every call through each provider's ``is_available()``.
    """

    def __init__(self, config: ServerConfig, oauth: QwenOAuthManager) -> None:
        # Insertion order matters: after "dashscope", it is the fallback order
        # ``search`` uses when no provider is requested.
        self.providers: dict[str, BaseWebSearchProvider] = {
            "dashscope": DashScopeWebSearchProvider(oauth),
            "tavily": TavilyWebSearchProvider(config.tavily_api_key),
            "google": GoogleWebSearchProvider(
                config.google_search_api_key,
                config.google_search_engine_id,
            ),
        }

    def _available_providers(self) -> dict[str, BaseWebSearchProvider]:
        """Return the registered providers that currently report availability, in registration order."""
        return {
            name: provider
            for name, provider in self.providers.items()
            if provider.is_available()
        }

    def list_available_providers(self) -> list[str]:
        """Return the names of currently available providers, in registration order."""
        return list(self._available_providers())

    def search(self, query: str, provider_name: str | None = None) -> dict[str, Any]:
        """Search *query* and return the formatted result dict.

        Args:
            query: the search query passed to the provider unchanged.
            provider_name: optional provider key; matched case-insensitively and
                ignoring surrounding whitespace. When omitted or empty, "dashscope"
                is preferred if available, otherwise the first available provider.

        Raises:
            WebSearchError: when no provider is available, when the requested
                provider is not available, or as raised by the provider itself.
        """
        available = self._available_providers()
        if not available:
            raise WebSearchError("No web search providers are available")
        if provider_name:
            provider = available.get(provider_name.strip().lower())
            if provider is None:
                raise WebSearchError(
                    f'Provider "{provider_name}" is not available. Available: {", ".join(sorted(available))}'
                )
        else:
            provider = available.get("dashscope") or next(iter(available.values()))
        return self._format_result(provider.search(query))

    def _format_result(self, result: WebSearchResult) -> dict[str, Any]:
        """Convert *result* into a dict with readable ``content`` plus structured fields.

        If the provider produced an answer, ``content`` is that answer followed
        by a numbered source list. Otherwise it lists up to the first five hits.
        """
        sources = [{"title": item.title, "url": item.url} for item in result.results]
        answer = result.answer.strip()
        if answer:
            content = answer
            if sources:
                source_lines = "\n".join(
                    f"[{index}] {source['title']} ({source['url']})"
                    for index, source in enumerate(sources, start=1)
                )
                content += f"\n\nSources:\n{source_lines}"
        else:
            blocks: list[str] = []
            for index, item in enumerate(result.results[:5], start=1):
                parts = [f"{index}. {item.title}"]
                # Test the stripped snippet: whitespace-only content would
                # otherwise add an empty line inside the block.
                snippet = (item.content or "").strip()
                if snippet:
                    parts.append(snippet)
                parts.append(f"Source: {item.url}")
                if item.published_date:
                    parts.append(f"Published: {item.published_date}")
                blocks.append("\n".join(parts))
            content = "\n\n".join(blocks)
        if content:
            content += "\n\nNote: Use web_fetch-like follow-up tooling for deeper page content."
        return {
            "query": result.query,
            "provider": result.provider,
            "content": content or f'No search results found for "{result.query}".',
            "sources": sources,
            "results": [
                {
                    "title": item.title,
                    "url": item.url,
                    "content": item.content,
                    "score": item.score,
                    "published_date": item.published_date,
                }
                for item in result.results
            ],
        }