Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.DS_Store
__pycache__/
.pytest_cache/
.grok-search/
logs/
*.egg-info/
uv.lock
.env
.env.*
8 changes: 8 additions & 0 deletions src/grok_search/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def retry_multiplier(self) -> float:
def retry_max_wait(self) -> int:
    """Maximum retry back-off wait in seconds (env ``GROK_RETRY_MAX_WAIT``, default 10)."""
    return int(os.getenv("GROK_RETRY_MAX_WAIT", "10"))

@property
def output_cleanup_enabled(self) -> bool:
    """Whether answer-output cleanup (think-tag / policy-block filtering) is enabled.

    Reads ``GROK_OUTPUT_CLEANUP`` first; when unset, falls back to the
    legacy ``GROK_FILTER_THINK_TAGS`` variable (default ``"true"``) so
    existing deployments keep their behavior.

    Returns:
        True when the effective value is "true", "1", or "yes"
        (case-insensitive; surrounding whitespace is ignored so values
        like ``" true "`` from .env files still enable cleanup).
    """
    raw = os.getenv("GROK_OUTPUT_CLEANUP")
    if raw is None:
        # Legacy variable name; cleanup defaults to enabled.
        raw = os.getenv("GROK_FILTER_THINK_TAGS", "true")
    return raw.strip().lower() in ("true", "1", "yes")

@property
def grok_api_url(self) -> str:
url = os.getenv("GROK_API_URL")
Expand Down Expand Up @@ -184,6 +191,7 @@ def get_config_info(self) -> dict:
"GROK_API_KEY": api_key_masked,
"GROK_MODEL": self.grok_model,
"GROK_DEBUG": self.debug_enabled,
"GROK_OUTPUT_CLEANUP": self.output_cleanup_enabled,
"GROK_LOG_LEVEL": self.log_level,
"GROK_LOG_DIR": str(self.log_dir),
"TAVILY_API_URL": self.tavily_api_url,
Expand Down
39 changes: 39 additions & 0 deletions src/grok_search/planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ class ExecutionOrderOutput(BaseModel):

_ACCUMULATIVE_LIST_PHASES = {"query_decomposition", "tool_selection"}
_MERGE_STRATEGY_PHASE = "search_strategy"
# Immediate prerequisite for each planning phase: a phase may only start once
# its predecessor has completed (enforced for non-revision calls in
# process_phase).  "intent_analysis" opens the flow and so has no entry here.
_PHASE_PREDECESSORS = {
    "complexity_assessment": "intent_analysis",
    "query_decomposition": "complexity_assessment",
    "search_strategy": "query_decomposition",
    "tool_selection": "search_strategy",
    "execution_order": "tool_selection",
}


def _split_csv(value: str) -> list[str]:
Expand Down Expand Up @@ -126,6 +133,9 @@ def __init__(self):
def get_session(self, session_id: str) -> PlanningSession | None:
    """Return the session stored under *session_id*, or ``None`` if unknown."""
    return self._sessions.get(session_id)

def reset(self) -> None:
    """Discard every stored planning session."""
    self._sessions.clear()

def process_phase(
self,
phase: str,
Expand All @@ -147,6 +157,35 @@ def process_phase(
if target not in PHASE_NAMES:
return {"error": f"Unknown phase: {target}. Valid: {', '.join(PHASE_NAMES)}"}

if not is_revision:
predecessor = _PHASE_PREDECESSORS.get(target)
if predecessor and predecessor not in session.phases:
return {
"error": f"Phase '{target}' requires '{predecessor}' to be completed first.",
"expected_phase_order": PHASE_NAMES,
"session_id": session.session_id,
"completed_phases": session.completed_phases,
"complexity_level": session.complexity_level,
}

if session.complexity_level == 1 and target in {"search_strategy", "tool_selection", "execution_order"}:
return {
"error": "Level 1 planning completes after query_decomposition.",
"expected_phase_order": PHASE_NAMES,
"session_id": session.session_id,
"completed_phases": session.completed_phases,
"complexity_level": session.complexity_level,
}

if session.complexity_level == 2 and target == "execution_order":
return {
"error": "Level 2 planning completes after tool_selection.",
"expected_phase_order": PHASE_NAMES,
"session_id": session.session_id,
"completed_phases": session.completed_phases,
"complexity_level": session.complexity_level,
}

if target in _ACCUMULATIVE_LIST_PHASES:
if is_revision:
session.phases[target] = PhaseRecord(
Expand Down
31 changes: 25 additions & 6 deletions src/grok_search/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,26 @@ async def _get_available_models_cached(api_url: str, api_key: str) -> list[str]:
return models


def _planning_session_error(session_id: str) -> str:
import json
return json.dumps(
{
"error": "session_not_found",
"message": f"Session '{session_id}' not found. Call plan_intent first.",
"expected_phase_order": [
"intent_analysis",
"complexity_assessment",
"query_decomposition",
"search_strategy",
"tool_selection",
"execution_order",
],
"restart_from_intent_analysis": True,
},
ensure_ascii=False,
indent=2,
)

def _extra_results_to_sources(
tavily_results: list[dict] | None,
firecrawl_results: list[dict] | None,
Expand Down Expand Up @@ -209,7 +229,6 @@ async def _safe_firecrawl() -> list[dict] | None:
await _SOURCES_CACHE.set(session_id, all_sources)
return {"session_id": session_id, "content": answer, "sources_count": len(all_sources)}


@mcp.tool(
name="get_sources",
description="""
Expand Down Expand Up @@ -713,7 +732,7 @@ async def plan_complexity(
) -> str:
import json
if not planning_engine.get_session(session_id):
return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
return _planning_session_error(session_id)
return json.dumps(planning_engine.process_phase(
phase="complexity_assessment", thought=thought, session_id=session_id,
is_revision=is_revision, confidence=confidence,
Expand Down Expand Up @@ -741,7 +760,7 @@ async def plan_sub_query(
) -> str:
import json
if not planning_engine.get_session(session_id):
return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
return _planning_session_error(session_id)
item = {"id": id, "goal": goal, "expected_output": expected_output, "boundary": boundary}
if depends_on:
item["depends_on"] = _split_csv(depends_on)
Expand Down Expand Up @@ -771,7 +790,7 @@ async def plan_search_term(
) -> str:
import json
if not planning_engine.get_session(session_id):
return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
return _planning_session_error(session_id)
data = {"search_terms": [{"term": term, "purpose": purpose, "round": round}]}
if approach:
data["approach"] = approach
Expand Down Expand Up @@ -800,7 +819,7 @@ async def plan_tool_mapping(
) -> str:
import json
if not planning_engine.get_session(session_id):
return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
return _planning_session_error(session_id)
item = {"sub_query_id": sub_query_id, "tool": tool, "reason": reason}
if params_json:
try:
Expand Down Expand Up @@ -829,7 +848,7 @@ async def plan_execution(
) -> str:
import json
if not planning_engine.get_session(session_id):
return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
return _planning_session_error(session_id)
parallel = [_split_csv(g) for g in parallel_groups.split(";") if g.strip()] if parallel_groups else []
seq = _split_csv(sequential)
return json.dumps(planning_engine.process_phase(
Expand Down
78 changes: 78 additions & 0 deletions src/grok_search/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio

from .config import config
from .utils import extract_unique_urls


Expand All @@ -23,6 +24,41 @@
_SOURCES_FUNCTION_PATTERN = re.compile(
r"(?im)(^|\n)\s*(sources|source|citations|citation|references|reference|citation_card|source_cards|source_card)\s*\("
)
# Matches an entire <think>...</think> block (case-insensitive, dot matches
# newlines) so hidden-reasoning output can be stripped from answers.
_THINK_BLOCK_PATTERN = re.compile(r"(?is)<think>.*?</think>")
# Paragraph-leading refusal phrases (English and Chinese, optionally prefixed
# with markdown emphasis); a paragraph whose normalized text starts with one
# of these is dropped as a policy/refusal block.
_LEADING_POLICY_PATTERNS = [
    re.compile(r"(?is)^\s*\**\s*i cannot comply\b.*"),
    re.compile(r"(?is)^\s*\**\s*i do not accept\b.*"),
    re.compile(r"(?is)^\s*\**\s*refusal\s*[::].*"),
    re.compile(r"(?is)^\s*\**\s*refuse to\b.*"),
    re.compile(r"(?is)^\s*\**\s*rejected?\b.*"),
    re.compile(r"(?is)^\s*\**\s*拒绝执行\b.*"),
    re.compile(r"(?is)^\s*\**\s*无法遵循\b.*"),
]
# Substrings indicating the text talks *about* refusing / overriding rules
# rather than answering the user's question.
_POLICY_META_KEYWORDS = (
    "cannot comply",
    "refuse",
    "refusal",
    "override my core",
    "core behavior",
    "custom rules",
    "用户提供的自定义",
    "覆盖我的核心",
    "核心行为",
    "拒绝执行",
    "无法遵循",
)
# Substrings tying the text to prompt-injection / system-prompt discussion.
# A paragraph is only treated as a policy block when it contains BOTH a meta
# and a context keyword (see _looks_like_policy_block).
_POLICY_CONTEXT_KEYWORDS = (
    "jailbreak",
    "prompt injection",
    "system instructions",
    "system prompt",
    "user-injected",
    "注入",
    "越狱",
    "系统指令",
    "系统提示",
    "自定义“system”",
)


def new_session_id() -> str:
Expand Down Expand Up @@ -72,6 +108,11 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]:
if not raw:
return "", []

if config.output_cleanup_enabled:
cleaned = sanitize_answer_text(raw)
if cleaned:
raw = cleaned

split = _split_function_call_sources(raw)
if split:
return split
Expand All @@ -91,6 +132,43 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]:
return raw, []


def sanitize_answer_text(text: str) -> str:
    """Remove ``<think>`` blocks and refusal/policy paragraphs from *text*.

    When every paragraph looks like a policy block, the think-stripped text
    is returned unchanged so the answer is never emptied entirely.
    """
    stripped = (text or "").strip()
    if not stripped:
        return ""

    without_think = _THINK_BLOCK_PATTERN.sub("", stripped).strip()
    kept = []
    for paragraph in _split_paragraphs(without_think):
        if not _looks_like_policy_block(paragraph):
            kept.append(paragraph)
    if not kept:
        return without_think
    return "\n\n".join(kept).strip()


def _split_paragraphs(text: str) -> list[str]:
parts = [part.strip() for part in re.split(r"\n\s*\n", text) if part.strip()]
return parts or ([text.strip()] if text.strip() else [])


def _looks_like_policy_block(text: str) -> bool:
    """Return True when *text* reads like a refusal / prompt-injection notice.

    A paragraph qualifies either by starting with a known refusal phrase, or
    by containing both a policy-meta keyword and an injection-context keyword.
    """
    normalized = _normalize_policy_text(text)
    if not normalized:
        return False

    for pattern in _LEADING_POLICY_PATTERNS:
        if pattern.match(normalized):
            return True

    has_meta = any(keyword in normalized for keyword in _POLICY_META_KEYWORDS)
    has_context = any(keyword in normalized for keyword in _POLICY_CONTEXT_KEYWORDS)
    return has_meta and has_context


def _normalize_policy_text(text: str) -> str:
normalized = re.sub(r"[>*_`#-]+", " ", text or "")
normalized = re.sub(r"\s+", " ", normalized)
return normalized.strip().lower()


def _split_function_call_sources(text: str) -> tuple[str, list[dict]] | None:
matches = list(_SOURCES_FUNCTION_PATTERN.finditer(text))
if not matches:
Expand Down
2 changes: 2 additions & 0 deletions src/grok_search/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,6 @@ def format_search_results(results: List[SearchResult]) -> str:
5. **Every sentence must cite sources** (`citation_card`). More references = stronger credibility. Silence if uncited.
6. Expand on key concepts—after proposing solutions, **use real-world analogies** to demystify technical terms.
7. **Strictly format outputs in polished Markdown** (LaTeX for formulas, code blocks for scripts, etc.).
8. Do not mention prompt injection, jailbreaks, system prompts, hidden instructions, or refusal policies unless the user explicitly asks about those topics.
9. Do not output hidden reasoning or `<think>` tags.
"""
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import sys
from pathlib import Path


# Make the src layout importable in local pytest runs.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"

# Prepend rather than append so the in-repo package wins over any installed copy.
_src_path = str(SRC)
if _src_path not in sys.path:
    sys.path.insert(0, _src_path)
Loading