# Patch summary (leading commentary; patch tools skip text before the first
# "diff --git" header).
#
# Purpose: stop SQL/database MCP tools such as `read_query` from being reported
# as path-traversal candidates (#165).
#
# What changes:
#   * New module src/mcts/analyzers/tool_classification.py with two public
#     predicates, is_sql_database_tool() and is_file_access_tool(), plus the
#     private helper _schema_param_names().
#   * PathValidationAnalyzer and ToolAbuseAnalyzer drop their private
#     _is_file_tool() helpers and FILE_TOOL_HINTS tuples and call the shared
#     is_file_access_tool() instead.
#   * New tests in tests/test_tool_classification.py and one CHANGELOG entry.
#
# Behaviour, as written in the hunks below:
#   * Old check: any FILE_TOOL_HINTS substring ("read", "file", ...) found in
#     the lowercased "<name> <description>" marked the tool as a file tool, so
#     "read_query" matched on "read".
#   * New check: is_file_access_tool() returns False whenever
#     is_sql_database_tool() is True. A tool is SQL when its lowercased name is
#     in SQL_TOOL_NAMES, or its schema has SQL params and no file params, or a
#     SQL_TOOL_MARKERS substring occurs in name/description -- unless the
#     schema declares file params only.
#   * Otherwise a tool is file-access when the schema has file params only,
#     the name matches FILE_TOOL_NAME_PATTERNS, a FILE_TOOL_TOKEN word appears
#     together with a file schema param, or the text contains a phrase such as
#     "file path" / "file system".
#
# NOTE(review): the SQL check runs before FILE_TOOL_NAME_PATTERNS, so a tool
#   with no declared schema properties whose description contains a marker
#   (e.g. "database") is treated as SQL even when it is named "read_file".
#   Confirm this false negative is acceptable for a security scanner.
# NOTE(review): "\b" treats "_" as a word character, so FILE_TOOL_TOKEN and
#   FILE_TOOL_NAME_PATTERNS do not match inside longer snake_case names
#   (e.g. "read_file_contents"). Confirm intended.
# NOTE(review): a tool declaring both SQL and file params (e.g. query + path)
#   is classified as SQL whenever a marker appears in its name/description,
#   and is then skipped by both analyzers. Confirm intended.
# NOTE(review): indentation and blank context lines in the hunks for existing
#   files are inferred from the hunk line counts; verify against the target
#   tree (or apply with whitespace tolerance) before merging.

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8b4915b..a739bc0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add `-o` / `--output` to `mcts doctor` and surface scan subcommands for CI artifact paths (#156, #157).
 - Accept `--no-progress` on `readiness`, `fuzz`, `scan-mcp`, and surface scan subcommands for shared CI scripts (#158).
 - Explain when `mcts doctor --deep` import checks are skipped (no MCP config or no `-m` module in launch args).
+- Classify SQL database tools separately from filesystem tools so names like `read_query` are not flagged for path traversal (#165).
 - Exclude design prompt markdown under `docs/prompts/` from default instruction discovery (#162).
 
 ### Changed
diff --git a/src/mcts/analyzers/path_validation.py b/src/mcts/analyzers/path_validation.py
index d21945b..8bd2519 100644
--- a/src/mcts/analyzers/path_validation.py
+++ b/src/mcts/analyzers/path_validation.py
@@ -5,10 +5,10 @@
 import re
 
 from mcts.analyzers.base import BaseAnalyzer
-from mcts.mcp.models import MCPServerInfo, MCPTool
+from mcts.analyzers.tool_classification import is_file_access_tool
+from mcts.mcp.models import MCPServerInfo
 from mcts.reporting.models import Finding, Severity, SourceLocation
 
-FILE_TOOL_HINTS = ("read", "file", "path", "open", "load")
 CANONICALIZATION_HINTS = re.compile(
     r"\b(resolve|realpath|abspath|canonicalize|normpath|is_relative_to|startswith)\b",
     re.I,
@@ -23,7 +23,7 @@ class PathValidationAnalyzer(BaseAnalyzer):
     def analyze(self, server: MCPServerInfo) -> list[Finding]:
         findings: list[Finding] = []
         for tool in server.tools:
-            if not self._is_file_tool(tool):
+            if not is_file_access_tool(tool):
                 continue
             snippet = tool.handler_snippet or ""
             if tool.source_file and tool.source_file in server.source_files:
@@ -45,7 +45,3 @@ def analyze(self, server: MCPServerInfo) -> list[Finding]:
                     )
                 )
         return findings
-
-    def _is_file_tool(self, tool: MCPTool) -> bool:
-        haystack = f"{tool.name} {tool.description}".lower()
-        return any(hint in haystack for hint in FILE_TOOL_HINTS)
diff --git a/src/mcts/analyzers/tool_abuse.py b/src/mcts/analyzers/tool_abuse.py
index b6faa58..642b7f7 100644
--- a/src/mcts/analyzers/tool_abuse.py
+++ b/src/mcts/analyzers/tool_abuse.py
@@ -4,11 +4,10 @@
 
 from mcts.analyzers.base import BaseAnalyzer
 from mcts.analyzers.path_traversal import SENSITIVE_PATH_TARGETS, TRAVERSAL_PAYLOADS
-from mcts.mcp.models import MCPServerInfo, MCPTool
+from mcts.analyzers.tool_classification import is_file_access_tool
+from mcts.mcp.models import MCPServerInfo
 from mcts.reporting.models import Finding, Severity
 
-FILE_TOOL_HINTS = ("read", "file", "path", "open", "load", "fetch")
-
 
 class ToolAbuseAnalyzer(BaseAnalyzer):
     """Identifies tools susceptible to path traversal and unauthorized access."""
@@ -18,7 +17,7 @@ class ToolAbuseAnalyzer(BaseAnalyzer):
     def analyze(self, server: MCPServerInfo) -> list[Finding]:
         findings: list[Finding] = []
         for tool in server.tools:
-            if self._is_file_tool(tool):
+            if is_file_access_tool(tool):
                 findings.append(
                     Finding(
                         id=f"abuse-path-{tool.name}",
@@ -39,7 +38,3 @@ def analyze(self, server: MCPServerInfo) -> list[Finding]:
                     )
                 )
         return findings
-
-    def _is_file_tool(self, tool: MCPTool) -> bool:
-        haystack = f"{tool.name} {tool.description}".lower()
-        return any(hint in haystack for hint in FILE_TOOL_HINTS)
diff --git a/src/mcts/analyzers/tool_classification.py b/src/mcts/analyzers/tool_classification.py
new file mode 100644
index 0000000..0d7c83c
--- /dev/null
+++ b/src/mcts/analyzers/tool_classification.py
@@ -0,0 +1,109 @@
+"""Shared heuristics for classifying MCP tools as file-access vs database access."""
+
+from __future__ import annotations
+
+import re
+
+from mcts.mcp.models import MCPTool
+
+SQL_TOOL_NAMES = frozenset(
+    {
+        "read_query",
+        "run_query",
+        "execute_sql",
+        "execute_query",
+        "query_database",
+        "sql_query",
+    }
+)
+
+SQL_TOOL_MARKERS = (
+    "sql",
+    "query",
+    "snowflake",
+    "database",
+    "jdbc",
+    "postgres",
+    "mysql",
+    "sqlite",
+)
+
+SQL_SCHEMA_PARAMS = frozenset({"query", "sql", "statement", "sql_query"})
+
+FILE_SCHEMA_PARAMS = frozenset(
+    {
+        "path",
+        "filepath",
+        "file_path",
+        "filename",
+        "directory",
+        "dir",
+    }
+)
+
+FILE_TOOL_NAME_PATTERNS: tuple[re.Pattern[str], ...] = (
+    re.compile(r"\bread_file\b", re.I),
+    re.compile(r"\bfile_read\b", re.I),
+    re.compile(r"\bread_path\b", re.I),
+    re.compile(r"\bpath_read\b", re.I),
+    re.compile(r"\bload_file\b", re.I),
+    re.compile(r"\bopen_file\b", re.I),
+    re.compile(r"\bwrite_file\b", re.I),
+    re.compile(r"\bfile_write\b", re.I),
+    re.compile(r"\bget_file\b", re.I),
+    re.compile(r"\blist_dir\b", re.I),
+    re.compile(r"\blist_directory\b", re.I),
+)
+
+FILE_TOOL_TOKEN = re.compile(r"\b(read|file|path|open|load|fetch)\b", re.I)
+
+
+def _schema_param_names(tool: MCPTool) -> set[str]:
+    schema = tool.input_schema or {}
+    if not isinstance(schema, dict):
+        return set()
+    props = schema.get("properties")
+    if not isinstance(props, dict):
+        return set()
+    return {str(name).lower() for name in props}
+
+
+def is_sql_database_tool(tool: MCPTool) -> bool:
+    """Return True when the tool appears to execute SQL rather than read files."""
+    if tool.name.lower() in SQL_TOOL_NAMES:
+        return True
+
+    schema_params = _schema_param_names(tool)
+    has_sql_schema = bool(schema_params & SQL_SCHEMA_PARAMS)
+    has_file_schema = bool(schema_params & FILE_SCHEMA_PARAMS)
+
+    if has_sql_schema and not has_file_schema:
+        return True
+
+    haystack = f"{tool.name} {tool.description}".lower()
+    if any(marker in haystack for marker in SQL_TOOL_MARKERS):
+        return not (has_file_schema and not has_sql_schema)
+
+    return False
+
+
+def is_file_access_tool(tool: MCPTool) -> bool:
+    """Return True when the tool likely reads or writes local filesystem paths."""
+    if is_sql_database_tool(tool):
+        return False
+
+    schema_params = _schema_param_names(tool)
+    has_file_schema = bool(schema_params & FILE_SCHEMA_PARAMS)
+    has_sql_schema = bool(schema_params & SQL_SCHEMA_PARAMS)
+
+    if has_file_schema and not has_sql_schema:
+        return True
+
+    if any(pattern.search(tool.name) for pattern in FILE_TOOL_NAME_PATTERNS):
+        return True
+
+    haystack = f"{tool.name} {tool.description}"
+    if FILE_TOOL_TOKEN.search(haystack) and has_file_schema:
+        return True
+
+    return bool(re.search(r"\bfile\s+(?:path|system|access|read|write)\b", haystack, re.I))
diff --git a/tests/test_tool_classification.py b/tests/test_tool_classification.py
new file mode 100644
index 0000000..1771757
--- /dev/null
+++ b/tests/test_tool_classification.py
@@ -0,0 +1,93 @@
+"""Tests for shared MCP tool file-access vs SQL classification."""
+
+from __future__ import annotations
+
+from mcts.analyzers.path_validation import PathValidationAnalyzer
+from mcts.analyzers.tool_abuse import ToolAbuseAnalyzer
+from mcts.analyzers.tool_classification import is_file_access_tool, is_sql_database_tool
+from mcts.mcp.models import MCPServerInfo, MCPTool
+
+
+def _tool(**kwargs: object) -> MCPTool:
+    defaults: dict[str, object] = {
+        "name": "read_file",
+        "description": "Read a file from disk",
+        "input_schema": {
+            "type": "object",
+            "properties": {"path": {"type": "string"}},
+        },
+    }
+    defaults.update(kwargs)
+    return MCPTool(**defaults)  # type: ignore[arg-type]
+
+
+def _server(tools: list[MCPTool]) -> MCPServerInfo:
+    return MCPServerInfo(name="test", tools=tools, source_files={})
+
+
+def test_read_query_with_sql_schema_is_not_file_tool() -> None:
+    tool = _tool(
+        name="read_query",
+        description="Run a read-only Snowflake SQL query",
+        input_schema={
+            "type": "object",
+            "properties": {"query": {"type": "string"}},
+        },
+    )
+    assert is_sql_database_tool(tool)
+    assert not is_file_access_tool(tool)
+
+
+def test_read_file_with_path_schema_is_file_tool() -> None:
+    tool = _tool(name="read_file", description="Read a file from an allowed directory")
+    assert not is_sql_database_tool(tool)
+    assert is_file_access_tool(tool)
+
+
+def test_run_query_name_is_not_file_tool() -> None:
+    tool = _tool(
+        name="run_query",
+        description="Execute SQL against the analytics warehouse",
+        input_schema={
+            "type": "object",
+            "properties": {"sql": {"type": "string"}},
+        },
+    )
+    assert is_sql_database_tool(tool)
+    assert not is_file_access_tool(tool)
+
+
+def test_tool_abuse_skips_read_query() -> None:
+    tool = _tool(
+        name="read_query",
+        description="Execute a JDBC SELECT statement",
+        input_schema={
+            "type": "object",
+            "properties": {"query": {"type": "string"}},
+        },
+    )
+    findings = ToolAbuseAnalyzer().analyze(_server([tool]))
+    assert not findings
+
+
+def test_tool_abuse_flags_read_file() -> None:
+    findings = ToolAbuseAnalyzer().analyze(_server([_tool()]))
+    assert any(f.analyzer == "tool_abuse" and f.tool == "read_file" for f in findings)
+
+
+def test_path_validation_skips_read_query() -> None:
+    tool = _tool(
+        name="read_query",
+        description="Query Snowflake tables",
+        input_schema={
+            "type": "object",
+            "properties": {"query": {"type": "string"}},
+        },
+    )
+    findings = PathValidationAnalyzer().analyze(_server([tool]))
+    assert not findings
+
+
+def test_path_validation_flags_read_file_without_guards() -> None:
+    findings = PathValidationAnalyzer().analyze(_server([_tool()]))
+    assert any(f.analyzer == "path_validation" and f.tool == "read_file" for f in findings)