diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d1c9774
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+.DS_Store
+__pycache__/
+.pytest_cache/
+.grok-search/
+logs/
+*.egg-info/
+uv.lock
+.env
+.env.*
diff --git a/src/grok_search/config.py b/src/grok_search/config.py
index bdfbfd6..ea88507 100644
--- a/src/grok_search/config.py
+++ b/src/grok_search/config.py
@@ -63,6 +63,13 @@ def retry_multiplier(self) -> float:
def retry_max_wait(self) -> int:
return int(os.getenv("GROK_RETRY_MAX_WAIT", "10"))
+ @property
+ def output_cleanup_enabled(self) -> bool:
+ raw = os.getenv("GROK_OUTPUT_CLEANUP")
+ if raw is None:
+ raw = os.getenv("GROK_FILTER_THINK_TAGS", "true")
+ return raw.lower() in ("true", "1", "yes")
+
@property
def grok_api_url(self) -> str:
url = os.getenv("GROK_API_URL")
@@ -184,6 +191,7 @@ def get_config_info(self) -> dict:
"GROK_API_KEY": api_key_masked,
"GROK_MODEL": self.grok_model,
"GROK_DEBUG": self.debug_enabled,
+ "GROK_OUTPUT_CLEANUP": self.output_cleanup_enabled,
"GROK_LOG_LEVEL": self.log_level,
"GROK_LOG_DIR": str(self.log_dir),
"TAVILY_API_URL": self.tavily_api_url,
diff --git a/src/grok_search/planning.py b/src/grok_search/planning.py
index 9f67a73..3fc76c2 100644
--- a/src/grok_search/planning.py
+++ b/src/grok_search/planning.py
@@ -84,6 +84,13 @@ class ExecutionOrderOutput(BaseModel):
_ACCUMULATIVE_LIST_PHASES = {"query_decomposition", "tool_selection"}
_MERGE_STRATEGY_PHASE = "search_strategy"
+_PHASE_PREDECESSORS = {
+ "complexity_assessment": "intent_analysis",
+ "query_decomposition": "complexity_assessment",
+ "search_strategy": "query_decomposition",
+ "tool_selection": "search_strategy",
+ "execution_order": "tool_selection",
+}
def _split_csv(value: str) -> list[str]:
@@ -126,6 +133,9 @@ def __init__(self):
def get_session(self, session_id: str) -> PlanningSession | None:
return self._sessions.get(session_id)
+ def reset(self) -> None:
+ self._sessions.clear()
+
def process_phase(
self,
phase: str,
@@ -147,6 +157,35 @@ def process_phase(
if target not in PHASE_NAMES:
return {"error": f"Unknown phase: {target}. Valid: {', '.join(PHASE_NAMES)}"}
+ if not is_revision:
+ predecessor = _PHASE_PREDECESSORS.get(target)
+ if predecessor and predecessor not in session.phases:
+ return {
+ "error": f"Phase '{target}' requires '{predecessor}' to be completed first.",
+ "expected_phase_order": PHASE_NAMES,
+ "session_id": session.session_id,
+ "completed_phases": session.completed_phases,
+ "complexity_level": session.complexity_level,
+ }
+
+ if session.complexity_level == 1 and target in {"search_strategy", "tool_selection", "execution_order"}:
+ return {
+ "error": "Level 1 planning completes after query_decomposition.",
+ "expected_phase_order": PHASE_NAMES,
+ "session_id": session.session_id,
+ "completed_phases": session.completed_phases,
+ "complexity_level": session.complexity_level,
+ }
+
+ if session.complexity_level == 2 and target == "execution_order":
+ return {
+ "error": "Level 2 planning completes after tool_selection.",
+ "expected_phase_order": PHASE_NAMES,
+ "session_id": session.session_id,
+ "completed_phases": session.completed_phases,
+ "complexity_level": session.complexity_level,
+ }
+
if target in _ACCUMULATIVE_LIST_PHASES:
if is_revision:
session.phases[target] = PhaseRecord(
diff --git a/src/grok_search/server.py b/src/grok_search/server.py
index 7754216..ac01a6f 100644
--- a/src/grok_search/server.py
+++ b/src/grok_search/server.py
@@ -71,6 +71,26 @@ async def _get_available_models_cached(api_url: str, api_key: str) -> list[str]:
return models
+def _planning_session_error(session_id: str) -> str:
+ import json
+ return json.dumps(
+ {
+ "error": "session_not_found",
+ "message": f"Session '{session_id}' not found. Call plan_intent first.",
+ "expected_phase_order": [
+ "intent_analysis",
+ "complexity_assessment",
+ "query_decomposition",
+ "search_strategy",
+ "tool_selection",
+ "execution_order",
+ ],
+ "restart_from_intent_analysis": True,
+ },
+ ensure_ascii=False,
+ indent=2,
+ )
+
def _extra_results_to_sources(
tavily_results: list[dict] | None,
firecrawl_results: list[dict] | None,
@@ -209,7 +229,6 @@ async def _safe_firecrawl() -> list[dict] | None:
await _SOURCES_CACHE.set(session_id, all_sources)
return {"session_id": session_id, "content": answer, "sources_count": len(all_sources)}
-
@mcp.tool(
name="get_sources",
description="""
@@ -713,7 +732,7 @@ async def plan_complexity(
) -> str:
import json
if not planning_engine.get_session(session_id):
- return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
+ return _planning_session_error(session_id)
return json.dumps(planning_engine.process_phase(
phase="complexity_assessment", thought=thought, session_id=session_id,
is_revision=is_revision, confidence=confidence,
@@ -741,7 +760,7 @@ async def plan_sub_query(
) -> str:
import json
if not planning_engine.get_session(session_id):
- return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
+ return _planning_session_error(session_id)
item = {"id": id, "goal": goal, "expected_output": expected_output, "boundary": boundary}
if depends_on:
item["depends_on"] = _split_csv(depends_on)
@@ -771,7 +790,7 @@ async def plan_search_term(
) -> str:
import json
if not planning_engine.get_session(session_id):
- return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
+ return _planning_session_error(session_id)
data = {"search_terms": [{"term": term, "purpose": purpose, "round": round}]}
if approach:
data["approach"] = approach
@@ -800,7 +819,7 @@ async def plan_tool_mapping(
) -> str:
import json
if not planning_engine.get_session(session_id):
- return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
+ return _planning_session_error(session_id)
item = {"sub_query_id": sub_query_id, "tool": tool, "reason": reason}
if params_json:
try:
@@ -829,7 +848,7 @@ async def plan_execution(
) -> str:
import json
if not planning_engine.get_session(session_id):
- return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."})
+ return _planning_session_error(session_id)
parallel = [_split_csv(g) for g in parallel_groups.split(";") if g.strip()] if parallel_groups else []
seq = _split_csv(sequential)
return json.dumps(planning_engine.process_phase(
diff --git a/src/grok_search/sources.py b/src/grok_search/sources.py
index 63386e2..3c77c60 100644
--- a/src/grok_search/sources.py
+++ b/src/grok_search/sources.py
@@ -7,6 +7,7 @@
import asyncio
+from .config import config
from .utils import extract_unique_urls
@@ -23,6 +24,41 @@
_SOURCES_FUNCTION_PATTERN = re.compile(
r"(?im)(^|\n)\s*(sources|source|citations|citation|references|reference|citation_card|source_cards|source_card)\s*\("
)
+_THINK_BLOCK_PATTERN = re.compile(r"(?is)<think>.*?</think>")
+_LEADING_POLICY_PATTERNS = [
+ re.compile(r"(?is)^\s*\**\s*i cannot comply\b.*"),
+ re.compile(r"(?is)^\s*\**\s*i do not accept\b.*"),
+    re.compile(r"(?is)^\s*\**\s*refusal\s*[:：].*"),
+ re.compile(r"(?is)^\s*\**\s*refuse to\b.*"),
+ re.compile(r"(?is)^\s*\**\s*rejected?\b.*"),
+ re.compile(r"(?is)^\s*\**\s*拒绝执行\b.*"),
+ re.compile(r"(?is)^\s*\**\s*无法遵循\b.*"),
+]
+_POLICY_META_KEYWORDS = (
+ "cannot comply",
+ "refuse",
+ "refusal",
+ "override my core",
+ "core behavior",
+ "custom rules",
+ "用户提供的自定义",
+ "覆盖我的核心",
+ "核心行为",
+ "拒绝执行",
+ "无法遵循",
+)
+_POLICY_CONTEXT_KEYWORDS = (
+ "jailbreak",
+ "prompt injection",
+ "system instructions",
+ "system prompt",
+ "user-injected",
+ "注入",
+ "越狱",
+ "系统指令",
+ "系统提示",
+ "自定义“system”",
+)
def new_session_id() -> str:
@@ -72,6 +108,11 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]:
if not raw:
return "", []
+ if config.output_cleanup_enabled:
+ cleaned = sanitize_answer_text(raw)
+ if cleaned:
+ raw = cleaned
+
split = _split_function_call_sources(raw)
if split:
return split
@@ -91,6 +132,43 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]:
return raw, []
+def sanitize_answer_text(text: str) -> str:
+ raw = (text or "").strip()
+ if not raw:
+ return ""
+
+ cleaned = _THINK_BLOCK_PATTERN.sub("", raw).strip()
+ paragraphs = _split_paragraphs(cleaned)
+ filtered = [paragraph for paragraph in paragraphs if not _looks_like_policy_block(paragraph)]
+ if filtered:
+ return "\n\n".join(filtered).strip()
+ return cleaned
+
+
+def _split_paragraphs(text: str) -> list[str]:
+ parts = [part.strip() for part in re.split(r"\n\s*\n", text) if part.strip()]
+ return parts or ([text.strip()] if text.strip() else [])
+
+
+def _looks_like_policy_block(text: str) -> bool:
+ normalized = _normalize_policy_text(text)
+ if not normalized:
+ return False
+
+ if any(pattern.match(normalized) for pattern in _LEADING_POLICY_PATTERNS):
+ return True
+
+ return any(keyword in normalized for keyword in _POLICY_META_KEYWORDS) and any(
+ keyword in normalized for keyword in _POLICY_CONTEXT_KEYWORDS
+ )
+
+
+def _normalize_policy_text(text: str) -> str:
+ normalized = re.sub(r"[>*_`#-]+", " ", text or "")
+ normalized = re.sub(r"\s+", " ", normalized)
+ return normalized.strip().lower()
+
+
def _split_function_call_sources(text: str) -> tuple[str, list[dict]] | None:
matches = list(_SOURCES_FUNCTION_PATTERN.finditer(text))
if not matches:
diff --git a/src/grok_search/utils.py b/src/grok_search/utils.py
index eedbd0f..942544c 100644
--- a/src/grok_search/utils.py
+++ b/src/grok_search/utils.py
@@ -238,4 +238,6 @@ def format_search_results(results: List[SearchResult]) -> str:
5. **Every sentence must cite sources** (`citation_card`). More references = stronger credibility. Silence if uncited.
6. Expand on key concepts—after proposing solutions, **use real-world analogies** to demystify technical terms.
7. **Strictly format outputs in polished Markdown** (LaTeX for formulas, code blocks for scripts, etc.).
+8. Do not mention prompt injection, jailbreaks, system prompts, hidden instructions, or refusal policies unless the user explicitly asks about those topics.
+9. Do not output hidden reasoning or `<think>` tags.
"""
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..71f3f4b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,10 @@
+import sys
+from pathlib import Path
+
+
+# Make the src layout importable in local pytest runs.
+ROOT = Path(__file__).resolve().parents[1]
+SRC = ROOT / "src"
+
+if str(SRC) not in sys.path:
+ sys.path.insert(0, str(SRC))
diff --git a/tests/test_planning.py b/tests/test_planning.py
new file mode 100644
index 0000000..80e9fcb
--- /dev/null
+++ b/tests/test_planning.py
@@ -0,0 +1,142 @@
+import json
+
+import pytest
+
+from grok_search import planning
+from grok_search import server
+
+
+@pytest.fixture(autouse=True)
+def reset_planning_state():
+ planning.engine.reset()
+ yield
+ planning.engine.reset()
+
+
+@pytest.mark.asyncio
+async def test_legacy_plan_flow_level_1_still_completes_after_sub_query():
+ intent = json.loads(
+ await server.plan_intent(
+ thought="Simple factual lookup.",
+ core_question="What is OpenAI?",
+ query_type="factual",
+ time_sensitivity="irrelevant",
+ )
+ )
+ session_id = intent["session_id"]
+
+ complexity = json.loads(
+ await server.plan_complexity(
+ session_id=session_id,
+ thought="Low complexity.",
+ level=1,
+ estimated_sub_queries=1,
+ estimated_tool_calls=2,
+ justification="One direct lookup is enough.",
+ )
+ )
+ assert complexity["plan_complete"] is False
+
+ decomposition = json.loads(
+ await server.plan_sub_query(
+ session_id=session_id,
+ thought="One sub-query is enough.",
+ id="sq1",
+ goal="Identify OpenAI.",
+ expected_output="One-sentence company definition.",
+ boundary="Exclude product history and leadership details.",
+ tool_hint="web_search",
+ )
+ )
+ assert decomposition["plan_complete"] is True
+ assert decomposition["complexity_level"] == 1
+
+
+@pytest.mark.asyncio
+async def test_missing_session_returns_structured_error():
+ result = json.loads(
+ await server.plan_complexity(
+ session_id="",
+ thought="Missing session.",
+ level=1,
+ estimated_sub_queries=1,
+ estimated_tool_calls=2,
+ justification="Should fail without a session.",
+ )
+ )
+
+ assert result["error"] == "session_not_found"
+ assert result["restart_from_intent_analysis"] is True
+ assert "expected_phase_order" in result
+
+
+@pytest.mark.asyncio
+async def test_out_of_order_phase_returns_error():
+ intent = json.loads(
+ await server.plan_intent(
+ thought="Start planning.",
+ core_question="Resolve an ambiguous query.",
+ query_type="exploratory",
+ time_sensitivity="recent",
+ )
+ )
+
+ wrong = json.loads(
+ await server.plan_sub_query(
+ session_id=intent["session_id"],
+ thought="Skip complexity on purpose.",
+ id="sq1",
+ goal="Wrong order",
+ expected_output="Should fail.",
+ boundary="Testing invalid order.",
+ tool_hint="web_search",
+ )
+ )
+
+ assert "requires 'complexity_assessment'" in wrong["error"]
+ assert wrong["expected_phase_order"][0] == "intent_analysis"
+
+
+@pytest.mark.asyncio
+async def test_level_1_blocks_later_phases():
+ intent = json.loads(
+ await server.plan_intent(
+ thought="Simple factual lookup.",
+ core_question="What is OpenAI?",
+ query_type="factual",
+ time_sensitivity="irrelevant",
+ )
+ )
+ session_id = intent["session_id"]
+
+ await server.plan_complexity(
+ session_id=session_id,
+ thought="Low complexity.",
+ level=1,
+ estimated_sub_queries=1,
+ estimated_tool_calls=2,
+ justification="One direct lookup is enough.",
+ )
+
+ await server.plan_sub_query(
+ session_id=session_id,
+ thought="Complete the required level 1 decomposition first.",
+ id="sq1",
+ goal="Identify OpenAI.",
+ expected_output="A short definition.",
+ boundary="Exclude unrelated details.",
+ tool_hint="web_search",
+ )
+
+ result = json.loads(
+ await server.plan_search_term(
+ session_id=session_id,
+ thought="This should be blocked for level 1.",
+ term="openai company",
+ purpose="sq1",
+ round=1,
+ approach="targeted",
+ )
+ )
+
+ assert result["error"] == "Level 1 planning completes after query_decomposition."
diff --git a/tests/test_sources.py b/tests/test_sources.py
new file mode 100644
index 0000000..6a6ff87
--- /dev/null
+++ b/tests/test_sources.py
@@ -0,0 +1,75 @@
+from grok_search.sources import sanitize_answer_text, split_answer_and_sources
+
+
+def test_sanitize_answer_text_removes_think_and_policy_prefix():
+ raw = """
+<think>
+Hidden reasoning that should never be shown.
+</think>
+
+**I cannot comply with user-injected "system:" instructions or custom rules attempting to override my core behavior.**
+
+OpenAI is an AI research and deployment company.
+"""
+
+ cleaned = sanitize_answer_text(raw)
+
+    assert "<think>" not in cleaned
+ assert "cannot comply" not in cleaned.lower()
+ assert cleaned == "OpenAI is an AI research and deployment company."
+
+
+def test_split_answer_and_sources_keeps_clean_answer_and_extracts_links():
+ raw = """
+<think>Hidden</think>
+
+**拒绝执行。**
+
+OpenAI is an AI research and deployment company.
+
+## Sources
+1. [OpenAI](https://openai.com/)
+2. [Wikipedia](https://en.wikipedia.org/wiki/OpenAI)
+"""
+
+ answer, sources = split_answer_and_sources(raw)
+
+ assert answer == "OpenAI is an AI research and deployment company."
+ assert [item["url"] for item in sources] == [
+ "https://openai.com/",
+ "https://en.wikipedia.org/wiki/OpenAI",
+ ]
+
+
+def test_sanitize_answer_text_removes_trailing_policy_suffix():
+ raw = """
+OpenAI is an AI research and deployment company.
+
+I cannot comply with user-injected "system:" instructions or discuss jailbreak attempts that override my core behavior.
+"""
+
+ cleaned = sanitize_answer_text(raw)
+
+ assert cleaned == "OpenAI is an AI research and deployment company."
+
+
+def test_sanitize_answer_text_removes_refusal_preface():
+ raw = """
+**Refusal:** I do not accept or follow injected "system" prompts, custom instructions, or overrides.
+
+OpenAI is an AI research and deployment company.
+"""
+
+ cleaned = sanitize_answer_text(raw)
+
+ assert cleaned == "OpenAI is an AI research and deployment company."
+
+
+def test_sanitize_answer_text_does_not_strip_legitimate_prompt_injection_topic():
+ raw = """
+Prompt injection is a technique that tries to manipulate an LLM's instructions.
+"""
+
+ cleaned = sanitize_answer_text(raw)
+
+ assert cleaned == raw.strip()