diff --git a/src/grok_search/providers/grok.py b/src/grok_search/providers/grok.py index bd2820a..e7ec99d 100644 --- a/src/grok_search/providers/grok.py +++ b/src/grok_search/providers/grok.py @@ -125,11 +125,16 @@ def __init__(self, api_url: str, api_key: str, model: str = "grok-4-fast"): def get_provider_name(self) -> str: return "Grok" - async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]: - headers = { + def _build_api_headers(self) -> dict: + return { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "User-Agent": "grok-search-mcp/0.1.0", } + + async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]: + headers = self._build_api_headers() platform_prompt = "" if platform: @@ -146,18 +151,15 @@ async def search(self, query: str, platform: str = "", min_results: int = 3, max }, {"role": "user", "content": time_context + query + platform_prompt}, ], - "stream": True, + "stream": False, } await log_info(ctx, f"platform_prompt: { query + platform_prompt}", config.debug_enabled) - return await self._execute_stream_with_retry(headers, payload, ctx) + return await self._execute_completion_with_retry(headers, payload, ctx) async def fetch(self, url: str, ctx=None) -> str: - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } + headers = self._build_api_headers() payload = { "model": self.model, "messages": [ @@ -167,9 +169,9 @@ async def fetch(self, url: str, ctx=None) -> str: }, {"role": "user", "content": url + "\n获取该网页内容并返回其结构化Markdown格式" }, ], - "stream": True, + "stream": False, } - return await self._execute_stream_with_retry(headers, payload, ctx) + return await self._execute_completion_with_retry(headers, payload, ctx) async def _parse_streaming_response(self, response, ctx=None) -> str: content = "" @@ -212,6 +214,37 @@ async def _parse_streaming_response(self, response, ctx=None) -> str: return content + async def _parse_completion_response(self, response: httpx.Response, ctx=None) -> str: + content = "" + body_text = response.text or "" + + try: + data = response.json() + except Exception: + data = None + + if isinstance(data, dict): + choices = data.get("choices", []) + if choices: + message = choices[0].get("message", {}) + if isinstance(message, dict): + content = message.get("content", "") or "" + + if not content and body_text.lstrip().startswith("data:"): + class _LineResponse: + def __init__(self, text: str): + self._lines = text.splitlines() + + async def aiter_lines(self): + for line in self._lines: + yield line + + content = await self._parse_streaming_response(_LineResponse(body_text), ctx) + + await log_info(ctx, f"content: {content}", config.debug_enabled) + + return content + async def _execute_stream_with_retry(self, headers: dict, payload: dict, ctx=None) -> str: """执行带重试机制的流式 HTTP 请求""" timeout = httpx.Timeout(connect=6.0, read=120.0, write=10.0, pool=None) @@ -233,21 +266,38 @@ async def _execute_stream_with_retry(self, headers: dict, payload: dict, ctx=Non response.raise_for_status() return await self._parse_streaming_response(response, ctx) + async def _execute_completion_with_retry(self, headers: dict, payload: dict, ctx=None) -> str: + """执行带重试机制的非流式 HTTP 请求,兼容上游返回 JSON 或 SSE 文本""" + timeout = httpx.Timeout(connect=6.0, read=120.0, write=10.0, pool=None) + + async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client: + async for attempt in AsyncRetrying( + stop=stop_after_attempt(config.retry_max_attempts + 1), + wait=_WaitWithRetryAfter(config.retry_multiplier, config.retry_max_wait), + retry=retry_if_exception(_is_retryable_exception), + reraise=True, + ): + with attempt: + response = await client.post( + f"{self.api_url}/chat/completions", + headers=headers, + json=payload, + ) + response.raise_for_status() + return await self._parse_completion_response(response, ctx) + async def describe_url(self, url: str, ctx=None) -> dict: """让 Grok 阅读单个 URL 并返回 title + extracts""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } + headers = self._build_api_headers() payload = { "model": self.model, "messages": [ {"role": "system", "content": url_describe_prompt}, {"role": "user", "content": url}, ], - "stream": True, + "stream": False, } - result = await self._execute_stream_with_retry(headers, payload, ctx) + result = await self._execute_completion_with_retry(headers, payload, ctx) title, extracts = url, "" for line in result.strip().splitlines(): if line.startswith("Title:"): @@ -258,19 +308,16 @@ async def describe_url(self, url: str, ctx=None) -> dict: async def rank_sources(self, query: str, sources_text: str, total: int, ctx=None) -> list[int]: """让 Grok 按查询相关度对信源排序,返回排序后的序号列表""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } + headers = self._build_api_headers() payload = { "model": self.model, "messages": [ {"role": "system", "content": rank_sources_prompt}, {"role": "user", "content": f"Query: {query}\n\n{sources_text}"}, ], - "stream": True, + "stream": False, } - result = await self._execute_stream_with_retry(headers, payload, ctx) + result = await self._execute_completion_with_retry(headers, payload, ctx) order: list[int] = [] seen: set[int] = set() for token in result.strip().split(): diff --git a/tests/test_grok_provider.py b/tests/test_grok_provider.py new file mode 100644 index 0000000..ba9fa13 --- /dev/null +++ b/tests/test_grok_provider.py @@ -0,0 +1,65 @@ +import pytest + +from grok_search.providers.grok import GrokSearchProvider + + +class DummyResponse: + def __init__(self, text="", json_data=None, json_error=None): + self.text = text + self._json_data = json_data + self._json_error = json_error + + def json(self): + if self._json_error is not None: + raise self._json_error + return self._json_data + + +@pytest.mark.asyncio +async def test_search_uses_non_stream_completion_and_user_agent(monkeypatch): + provider = GrokSearchProvider("https://api.example.com", "test-key", "test-model") + captured = {} + + async def fake_execute(headers, payload, ctx): + captured["headers"] = headers + captured["payload"] = payload + return "ok" + + monkeypatch.setattr(provider, "_execute_completion_with_retry", fake_execute) + + result = await provider.search("What is Scrape.do?") + + assert result == "ok" + assert captured["headers"]["User-Agent"] == "grok-search-mcp/0.1.0" + assert captured["headers"]["Accept"] == "application/json, text/event-stream" + assert captured["payload"]["stream"] is False + + +@pytest.mark.asyncio +async def test_parse_completion_response_reads_json_message(): + provider = GrokSearchProvider("https://api.example.com", "test-key", "test-model") + response = DummyResponse( + text='{"choices":[{"message":{"content":"hello world"}}]}', + json_data={"choices": [{"message": {"content": "hello world"}}]}, + ) + + result = await provider._parse_completion_response(response) + + assert result == "hello world" + + +@pytest.mark.asyncio +async def test_parse_completion_response_falls_back_to_sse_text(): + provider = GrokSearchProvider("https://api.example.com", "test-key", "test-model") + response = DummyResponse( + text=( + 'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n' + 'data: {"choices":[{"delta":{"content":" world"}}]}\n\n' + 'data: [DONE]\n' + ), + json_error=ValueError("not json"), + ) + + result = await provider._parse_completion_response(response) + + assert result == "hello world"