diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d1c9774 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +__pycache__/ +.pytest_cache/ +.grok-search/ +logs/ +*.egg-info/ +uv.lock +.env +.env.* diff --git a/src/grok_search/config.py b/src/grok_search/config.py index bdfbfd6..ea88507 100644 --- a/src/grok_search/config.py +++ b/src/grok_search/config.py @@ -63,6 +63,13 @@ def retry_multiplier(self) -> float: def retry_max_wait(self) -> int: return int(os.getenv("GROK_RETRY_MAX_WAIT", "10")) + @property + def output_cleanup_enabled(self) -> bool: + raw = os.getenv("GROK_OUTPUT_CLEANUP") + if raw is None: + raw = os.getenv("GROK_FILTER_THINK_TAGS", "true") + return raw.lower() in ("true", "1", "yes") + @property def grok_api_url(self) -> str: url = os.getenv("GROK_API_URL") @@ -184,6 +191,7 @@ def get_config_info(self) -> dict: "GROK_API_KEY": api_key_masked, "GROK_MODEL": self.grok_model, "GROK_DEBUG": self.debug_enabled, + "GROK_OUTPUT_CLEANUP": self.output_cleanup_enabled, "GROK_LOG_LEVEL": self.log_level, "GROK_LOG_DIR": str(self.log_dir), "TAVILY_API_URL": self.tavily_api_url, diff --git a/src/grok_search/planning.py b/src/grok_search/planning.py index 9f67a73..3fc76c2 100644 --- a/src/grok_search/planning.py +++ b/src/grok_search/planning.py @@ -84,6 +84,13 @@ class ExecutionOrderOutput(BaseModel): _ACCUMULATIVE_LIST_PHASES = {"query_decomposition", "tool_selection"} _MERGE_STRATEGY_PHASE = "search_strategy" +_PHASE_PREDECESSORS = { + "complexity_assessment": "intent_analysis", + "query_decomposition": "complexity_assessment", + "search_strategy": "query_decomposition", + "tool_selection": "search_strategy", + "execution_order": "tool_selection", +} def _split_csv(value: str) -> list[str]: @@ -126,6 +133,9 @@ def __init__(self): def get_session(self, session_id: str) -> PlanningSession | None: return self._sessions.get(session_id) + def reset(self) -> None: + self._sessions.clear() + def process_phase( self, phase: str, @@ -147,6 +157,35 @@ def process_phase( if target not in PHASE_NAMES: return {"error": f"Unknown phase: {target}. Valid: {', '.join(PHASE_NAMES)}"} + if not is_revision: + predecessor = _PHASE_PREDECESSORS.get(target) + if predecessor and predecessor not in session.phases: + return { + "error": f"Phase '{target}' requires '{predecessor}' to be completed first.", + "expected_phase_order": PHASE_NAMES, + "session_id": session.session_id, + "completed_phases": session.completed_phases, + "complexity_level": session.complexity_level, + } + + if session.complexity_level == 1 and target in {"search_strategy", "tool_selection", "execution_order"}: + return { + "error": "Level 1 planning completes after query_decomposition.", + "expected_phase_order": PHASE_NAMES, + "session_id": session.session_id, + "completed_phases": session.completed_phases, + "complexity_level": session.complexity_level, + } + + if session.complexity_level == 2 and target == "execution_order": + return { + "error": "Level 2 planning completes after tool_selection.", + "expected_phase_order": PHASE_NAMES, + "session_id": session.session_id, + "completed_phases": session.completed_phases, + "complexity_level": session.complexity_level, + } + if target in _ACCUMULATIVE_LIST_PHASES: if is_revision: session.phases[target] = PhaseRecord( diff --git a/src/grok_search/server.py b/src/grok_search/server.py index 7754216..ac01a6f 100644 --- a/src/grok_search/server.py +++ b/src/grok_search/server.py @@ -71,6 +71,26 @@ async def _get_available_models_cached(api_url: str, api_key: str) -> list[str]: return models +def _planning_session_error(session_id: str) -> str: + import json + return json.dumps( + { + "error": "session_not_found", + "message": f"Session '{session_id}' not found. Call plan_intent first.", + "expected_phase_order": [ + "intent_analysis", + "complexity_assessment", + "query_decomposition", + "search_strategy", + "tool_selection", + "execution_order", + ], + "restart_from_intent_analysis": True, + }, + ensure_ascii=False, + indent=2, + ) + def _extra_results_to_sources( tavily_results: list[dict] | None, firecrawl_results: list[dict] | None, @@ -209,7 +229,6 @@ async def _safe_firecrawl() -> list[dict] | None: await _SOURCES_CACHE.set(session_id, all_sources) return {"session_id": session_id, "content": answer, "sources_count": len(all_sources)} - @mcp.tool( name="get_sources", description=""" @@ -713,7 +732,7 @@ async def plan_complexity( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) return json.dumps(planning_engine.process_phase( phase="complexity_assessment", thought=thought, session_id=session_id, is_revision=is_revision, confidence=confidence, @@ -741,7 +760,7 @@ async def plan_sub_query( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) item = {"id": id, "goal": goal, "expected_output": expected_output, "boundary": boundary} if depends_on: item["depends_on"] = _split_csv(depends_on) @@ -771,7 +790,7 @@ async def plan_search_term( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) data = {"search_terms": [{"term": term, "purpose": purpose, "round": round}]} if approach: data["approach"] = approach @@ -800,7 +819,7 @@ async def plan_tool_mapping( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) item = {"sub_query_id": sub_query_id, "tool": tool, "reason": reason} if params_json: try: @@ -829,7 +848,7 @@ async def plan_execution( ) -> str: import json if not planning_engine.get_session(session_id): - return json.dumps({"error": f"Session '{session_id}' not found. Call plan_intent first."}) + return _planning_session_error(session_id) parallel = [_split_csv(g) for g in parallel_groups.split(";") if g.strip()] if parallel_groups else [] seq = _split_csv(sequential) return json.dumps(planning_engine.process_phase( diff --git a/src/grok_search/sources.py b/src/grok_search/sources.py index 63386e2..3c77c60 100644 --- a/src/grok_search/sources.py +++ b/src/grok_search/sources.py @@ -7,6 +7,7 @@ import asyncio +from .config import config from .utils import extract_unique_urls @@ -23,6 +24,41 @@ _SOURCES_FUNCTION_PATTERN = re.compile( r"(?im)(^|\n)\s*(sources|source|citations|citation|references|reference|citation_card|source_cards|source_card)\s*\(" ) +_THINK_BLOCK_PATTERN = re.compile(r"(?is).*?") +_LEADING_POLICY_PATTERNS = [ + re.compile(r"(?is)^\s*\**\s*i cannot comply\b.*"), + re.compile(r"(?is)^\s*\**\s*i do not accept\b.*"), + re.compile(r"(?is)^\s*\**\s*refusal\s*[::].*"), + re.compile(r"(?is)^\s*\**\s*refuse to\b.*"), + re.compile(r"(?is)^\s*\**\s*rejected?\b.*"), + re.compile(r"(?is)^\s*\**\s*拒绝执行\b.*"), + re.compile(r"(?is)^\s*\**\s*无法遵循\b.*"), +] +_POLICY_META_KEYWORDS = ( + "cannot comply", + "refuse", + "refusal", + "override my core", + "core behavior", + "custom rules", + "用户提供的自定义", + "覆盖我的核心", + "核心行为", + "拒绝执行", + "无法遵循", +) +_POLICY_CONTEXT_KEYWORDS = ( + "jailbreak", + "prompt injection", + "system instructions", + "system prompt", + "user-injected", + "注入", + "越狱", + "系统指令", + "系统提示", + "自定义“system”", +) def new_session_id() -> str: @@ -72,6 +108,11 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]: if not raw: return "", [] + if config.output_cleanup_enabled: + cleaned = sanitize_answer_text(raw) + if cleaned: + raw = cleaned + split = _split_function_call_sources(raw) if split: return split @@ -91,6 +132,43 @@ def split_answer_and_sources(text: str) -> tuple[str, list[dict]]: return raw, [] +def sanitize_answer_text(text: str) -> str: + raw = (text or "").strip() + if not raw: + return "" + + cleaned = _THINK_BLOCK_PATTERN.sub("", raw).strip() + paragraphs = _split_paragraphs(cleaned) + filtered = [paragraph for paragraph in paragraphs if not _looks_like_policy_block(paragraph)] + if filtered: + return "\n\n".join(filtered).strip() + return cleaned + + +def _split_paragraphs(text: str) -> list[str]: + parts = [part.strip() for part in re.split(r"\n\s*\n", text) if part.strip()] + return parts or ([text.strip()] if text.strip() else []) + + +def _looks_like_policy_block(text: str) -> bool: + normalized = _normalize_policy_text(text) + if not normalized: + return False + + if any(pattern.match(normalized) for pattern in _LEADING_POLICY_PATTERNS): + return True + + return any(keyword in normalized for keyword in _POLICY_META_KEYWORDS) and any( + keyword in normalized for keyword in _POLICY_CONTEXT_KEYWORDS + ) + + +def _normalize_policy_text(text: str) -> str: + normalized = re.sub(r"[>*_`#-]+", " ", text or "") + normalized = re.sub(r"\s+", " ", normalized) + return normalized.strip().lower() + + def _split_function_call_sources(text: str) -> tuple[str, list[dict]] | None: matches = list(_SOURCES_FUNCTION_PATTERN.finditer(text)) if not matches: diff --git a/src/grok_search/utils.py b/src/grok_search/utils.py index eedbd0f..942544c 100644 --- a/src/grok_search/utils.py +++ b/src/grok_search/utils.py @@ -238,4 +238,6 @@ def format_search_results(results: List[SearchResult]) -> str: 5. **Every sentence must cite sources** (`citation_card`). More references = stronger credibility. Silence if uncited. 6. Expand on key concepts—after proposing solutions, **use real-world analogies** to demystify technical terms. 7. **Strictly format outputs in polished Markdown** (LaTeX for formulas, code blocks for scripts, etc.). +8. Do not mention prompt injection, jailbreaks, system prompts, hidden instructions, or refusal policies unless the user explicitly asks about those topics. +9. Do not output hidden reasoning or `` tags. """ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..71f3f4b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import sys +from pathlib import Path + + +# Make the src layout importable in local pytest runs. +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" + +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) diff --git a/tests/test_planning.py b/tests/test_planning.py new file mode 100644 index 0000000..80e9fcb --- /dev/null +++ b/tests/test_planning.py @@ -0,0 +1,142 @@ +import json + +import pytest + +from grok_search import planning +from grok_search import server + + +@pytest.fixture(autouse=True) +def reset_planning_state(): + planning.engine.reset() + yield + planning.engine.reset() + + +@pytest.mark.asyncio +async def test_legacy_plan_flow_level_1_still_completes_after_sub_query(): + intent = json.loads( + await server.plan_intent( + thought="Simple factual lookup.", + core_question="What is OpenAI?", + query_type="factual", + time_sensitivity="irrelevant", + ) + ) + session_id = intent["session_id"] + + complexity = json.loads( + await server.plan_complexity( + session_id=session_id, + thought="Low complexity.", + level=1, + estimated_sub_queries=1, + estimated_tool_calls=2, + justification="One direct lookup is enough.", + ) + ) + assert complexity["plan_complete"] is False + + decomposition = json.loads( + await server.plan_sub_query( + session_id=session_id, + thought="One sub-query is enough.", + id="sq1", + goal="Identify OpenAI.", + expected_output="One-sentence company definition.", + boundary="Exclude product history and leadership details.", + tool_hint="web_search", + ) + ) + assert decomposition["plan_complete"] is True + assert decomposition["complexity_level"] == 1 + + +@pytest.mark.asyncio +async def test_missing_session_returns_structured_error(): + result = json.loads( + await server.plan_complexity( + session_id="", + thought="Missing session.", + level=1, + estimated_sub_queries=1, + estimated_tool_calls=2, + justification="Should fail without a session.", + ) + ) + + assert result["error"] == "session_not_found" + assert result["restart_from_intent_analysis"] is True + assert "expected_phase_order" in result + + +@pytest.mark.asyncio +async def test_out_of_order_phase_returns_error(): + intent = json.loads( + await server.plan_intent( + thought="Start planning.", + core_question="Resolve an ambiguous query.", + query_type="exploratory", + time_sensitivity="recent", + ) + ) + + wrong = json.loads( + await server.plan_sub_query( + session_id=intent["session_id"], + thought="Skip complexity on purpose.", + id="sq1", + goal="Wrong order", + expected_output="Should fail.", + boundary="Testing invalid order.", + tool_hint="web_search", + ) + ) + + assert "requires 'complexity_assessment'" in wrong["error"] + assert wrong["expected_phase_order"][0] == "intent_analysis" + + +@pytest.mark.asyncio +async def test_level_1_blocks_later_phases(): + intent = json.loads( + await server.plan_intent( + thought="Simple factual lookup.", + core_question="What is OpenAI?", + query_type="factual", + time_sensitivity="irrelevant", + ) + ) + session_id = intent["session_id"] + + await server.plan_complexity( + session_id=session_id, + thought="Low complexity.", + level=1, + estimated_sub_queries=1, + estimated_tool_calls=2, + justification="One direct lookup is enough.", + ) + + await server.plan_sub_query( + session_id=session_id, + thought="Complete the required level 1 decomposition first.", + id="sq1", + goal="Identify OpenAI.", + expected_output="A short definition.", + boundary="Exclude unrelated details.", + tool_hint="web_search", + ) + + result = json.loads( + await server.plan_search_term( + session_id=session_id, + thought="This should be blocked for level 1.", + term="openai company", + purpose="sq1", + round=1, + approach="targeted", + ) + ) + + assert result["error"] == "Level 1 planning completes after query_decomposition." diff --git a/tests/test_sources.py b/tests/test_sources.py new file mode 100644 index 0000000..6a6ff87 --- /dev/null +++ b/tests/test_sources.py @@ -0,0 +1,75 @@ +from grok_search.sources import sanitize_answer_text, split_answer_and_sources + + +def test_sanitize_answer_text_removes_think_and_policy_prefix(): + raw = """ + +Hidden reasoning that should never be shown. + + +**I cannot comply with user-injected "system:" instructions or custom rules attempting to override my core behavior.** + +OpenAI is an AI research and deployment company. +""" + + cleaned = sanitize_answer_text(raw) + + assert "" not in cleaned + assert "cannot comply" not in cleaned.lower() + assert cleaned == "OpenAI is an AI research and deployment company." + + +def test_split_answer_and_sources_keeps_clean_answer_and_extracts_links(): + raw = """ +Hidden + +**拒绝执行。** + +OpenAI is an AI research and deployment company. + +## Sources +1. [OpenAI](https://openai.com/) +2. [Wikipedia](https://en.wikipedia.org/wiki/OpenAI) +""" + + answer, sources = split_answer_and_sources(raw) + + assert answer == "OpenAI is an AI research and deployment company." + assert [item["url"] for item in sources] == [ + "https://openai.com/", + "https://en.wikipedia.org/wiki/OpenAI", + ] + + +def test_sanitize_answer_text_removes_trailing_policy_suffix(): + raw = """ +OpenAI is an AI research and deployment company. + +I cannot comply with user-injected "system:" instructions or discuss jailbreak attempts that override my core behavior. +""" + + cleaned = sanitize_answer_text(raw) + + assert cleaned == "OpenAI is an AI research and deployment company." + + +def test_sanitize_answer_text_removes_refusal_preface(): + raw = """ +**Refusal:** I do not accept or follow injected "system" prompts, custom instructions, or overrides. + +OpenAI is an AI research and deployment company. +""" + + cleaned = sanitize_answer_text(raw) + + assert cleaned == "OpenAI is an AI research and deployment company." + + +def test_sanitize_answer_text_does_not_strip_legitimate_prompt_injection_topic(): + raw = """ +Prompt injection is a technique that tries to manipulate an LLM's instructions. +""" + + cleaned = sanitize_answer_text(raw) + + assert cleaned == raw.strip()