Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions src/grok_search/providers/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,16 @@ def __init__(self, api_url: str, api_key: str, model: str = "grok-4-fast"):
def get_provider_name(self) -> str:
return "Grok"

async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]:
headers = {
def _build_api_headers(self) -> dict:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
"User-Agent": "grok-search-mcp/0.1.0",
}

async def search(self, query: str, platform: str = "", min_results: int = 3, max_results: int = 10, ctx=None) -> List[SearchResult]:
headers = self._build_api_headers()
platform_prompt = ""

if platform:
Expand All @@ -146,18 +151,15 @@ async def search(self, query: str, platform: str = "", min_results: int = 3, max
},
{"role": "user", "content": time_context + query + platform_prompt},
],
"stream": True,
"stream": False,
}

await log_info(ctx, f"platform_prompt: { query + platform_prompt}", config.debug_enabled)

return await self._execute_stream_with_retry(headers, payload, ctx)
return await self._execute_completion_with_retry(headers, payload, ctx)

async def fetch(self, url: str, ctx=None) -> str:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
headers = self._build_api_headers()
payload = {
"model": self.model,
"messages": [
Expand All @@ -167,9 +169,9 @@ async def fetch(self, url: str, ctx=None) -> str:
},
{"role": "user", "content": url + "\n获取该网页内容并返回其结构化Markdown格式" },
],
"stream": True,
"stream": False,
}
return await self._execute_stream_with_retry(headers, payload, ctx)
return await self._execute_completion_with_retry(headers, payload, ctx)

async def _parse_streaming_response(self, response, ctx=None) -> str:
content = ""
Expand Down Expand Up @@ -212,6 +214,37 @@ async def _parse_streaming_response(self, response, ctx=None) -> str:

return content

async def _parse_completion_response(self, response: httpx.Response, ctx=None) -> str:
    """Extract assistant text from a non-streaming chat-completions reply.

    Handles two upstream body shapes:
      1. A normal JSON body: ``choices[0].message.content``.
      2. An SSE ("data: ...") text body that some gateways return even
         when ``stream`` is false — delegated to the streaming parser.

    Args:
        response: The fully-read httpx response.
        ctx: Optional context forwarded to debug logging.

    Returns:
        The extracted content string; empty string when nothing usable.
    """
    content = ""
    body_text = response.text or ""

    try:
        data = response.json()
    except Exception:
        # Body is not valid JSON (e.g. SSE text); fall through to the
        # SSE detection below.
        data = None

    if isinstance(data, dict):
        choices = data.get("choices", [])
        # Guard each level explicitly: a gateway returning a null or
        # non-dict choice entry would otherwise raise AttributeError
        # instead of falling back gracefully.
        if choices and isinstance(choices[0], dict):
            message = choices[0].get("message", {})
            if isinstance(message, dict):
                content = message.get("content", "") or ""

    # Fallback: detect an SSE body by the presence of any "data:" line,
    # not only the first token — SSE streams may legitimately begin with
    # ":" comment or "event:" lines, which a bare startswith() misses.
    if not content and any(
        line.lstrip().startswith("data:") for line in body_text.splitlines()
    ):
        class _LineResponse:
            # Minimal adapter exposing aiter_lines() over static text so
            # the existing streaming parser can be reused unchanged.
            def __init__(self, text: str):
                self._lines = text.splitlines()

            async def aiter_lines(self):
                for line in self._lines:
                    yield line

        content = await self._parse_streaming_response(_LineResponse(body_text), ctx)

    await log_info(ctx, f"content: {content}", config.debug_enabled)

    return content

async def _execute_stream_with_retry(self, headers: dict, payload: dict, ctx=None) -> str:
"""执行带重试机制的流式 HTTP 请求"""
timeout = httpx.Timeout(connect=6.0, read=120.0, write=10.0, pool=None)
Expand All @@ -233,21 +266,38 @@ async def _execute_stream_with_retry(self, headers: dict, payload: dict, ctx=Non
response.raise_for_status()
return await self._parse_streaming_response(response, ctx)

async def _execute_completion_with_retry(self, headers: dict, payload: dict, ctx=None) -> str:
    """Execute a non-streaming HTTP request with retries; tolerates upstreams
    that answer with either plain JSON or SSE-formatted text (the parsing is
    delegated to _parse_completion_response).

    Args:
        headers: Request headers for the chat-completions call.
        payload: JSON request body (model, messages, stream flag).
        ctx: Optional context forwarded to the response parser for logging.

    Returns:
        The assistant content string extracted from the response.

    Raises:
        The final exception when all attempts are exhausted (tenacity
        reraise=True), or any non-retryable error immediately.
    """
    # Long read timeout: completions can take a while to generate.
    timeout = httpx.Timeout(connect=6.0, read=120.0, write=10.0, pool=None)

    async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
        # +1 because the initial call counts as attempt 1, so the config
        # value is the number of *retries*. _WaitWithRetryAfter presumably
        # honors Retry-After response headers — name suggests so; confirm.
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(config.retry_max_attempts + 1),
            wait=_WaitWithRetryAfter(config.retry_multiplier, config.retry_max_wait),
            retry=retry_if_exception(_is_retryable_exception),
            reraise=True,
        ):
            with attempt:
                response = await client.post(
                    f"{self.api_url}/chat/completions",
                    headers=headers,
                    json=payload,
                )
                response.raise_for_status()
                # Returning inside the attempt context exits the retry loop
                # on the first success.
                return await self._parse_completion_response(response, ctx)

async def describe_url(self, url: str, ctx=None) -> dict:
"""让 Grok 阅读单个 URL 并返回 title + extracts"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
headers = self._build_api_headers()
payload = {
"model": self.model,
"messages": [
{"role": "system", "content": url_describe_prompt},
{"role": "user", "content": url},
],
"stream": True,
"stream": False,
}
result = await self._execute_stream_with_retry(headers, payload, ctx)
result = await self._execute_completion_with_retry(headers, payload, ctx)
title, extracts = url, ""
for line in result.strip().splitlines():
if line.startswith("Title:"):
Expand All @@ -258,19 +308,16 @@ async def describe_url(self, url: str, ctx=None) -> dict:

async def rank_sources(self, query: str, sources_text: str, total: int, ctx=None) -> list[int]:
"""让 Grok 按查询相关度对信源排序,返回排序后的序号列表"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
headers = self._build_api_headers()
payload = {
"model": self.model,
"messages": [
{"role": "system", "content": rank_sources_prompt},
{"role": "user", "content": f"Query: {query}\n\n{sources_text}"},
],
"stream": True,
"stream": False,
}
result = await self._execute_stream_with_retry(headers, payload, ctx)
result = await self._execute_completion_with_retry(headers, payload, ctx)
order: list[int] = []
seen: set[int] = set()
for token in result.strip().split():
Expand Down
65 changes: 65 additions & 0 deletions tests/test_grok_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest

from grok_search.providers.grok import GrokSearchProvider


class DummyResponse:
    """Stand-in for httpx.Response exposing only .text and .json()."""

    def __init__(self, text="", json_data=None, json_error=None):
        self.text = text
        self._json_data = json_data
        self._json_error = json_error

    def json(self):
        # Mimic httpx: either raise the configured parse error or
        # hand back the canned payload.
        error = self._json_error
        if error is not None:
            raise error
        return self._json_data


@pytest.mark.asyncio
async def test_search_uses_non_stream_completion_and_user_agent(monkeypatch):
    """search() must use the non-streaming path with the full header set."""
    provider = GrokSearchProvider("https://api.example.com", "test-key", "test-model")
    seen = {}

    async def record_call(headers, payload, ctx):
        seen["headers"] = headers
        seen["payload"] = payload
        return "ok"

    monkeypatch.setattr(provider, "_execute_completion_with_retry", record_call)

    outcome = await provider.search("What is Scrape.do?")

    assert outcome == "ok"
    sent_headers = seen["headers"]
    assert sent_headers["User-Agent"] == "grok-search-mcp/0.1.0"
    assert sent_headers["Accept"] == "application/json, text/event-stream"
    assert seen["payload"]["stream"] is False


@pytest.mark.asyncio
async def test_parse_completion_response_reads_json_message():
    """A plain JSON completion body yields choices[0].message.content."""
    provider = GrokSearchProvider("https://api.example.com", "test-key", "test-model")
    canned = {"choices": [{"message": {"content": "hello world"}}]}
    response = DummyResponse(
        text='{"choices":[{"message":{"content":"hello world"}}]}',
        json_data=canned,
    )

    assert await provider._parse_completion_response(response) == "hello world"


@pytest.mark.asyncio
async def test_parse_completion_response_falls_back_to_sse_text():
    """When the body is SSE text, the streaming parser is reused."""
    provider = GrokSearchProvider("https://api.example.com", "test-key", "test-model")
    sse_body = (
        'data: {"choices":[{"delta":{"content":"hello"}}]}\n\n'
        'data: {"choices":[{"delta":{"content":" world"}}]}\n\n'
        'data: [DONE]\n'
    )
    response = DummyResponse(text=sse_body, json_error=ValueError("not json"))

    result = await provider._parse_completion_response(response)

    assert result == "hello world"