diff --git a/a2a/weather_service/src/weather_service/agent.py b/a2a/weather_service/src/weather_service/agent.py index 3d83afe8..5ada7f76 100644 --- a/a2a/weather_service/src/weather_service/agent.py +++ b/a2a/weather_service/src/weather_service/agent.py @@ -1,5 +1,6 @@ import logging import os +import re from textwrap import dedent import uvicorn @@ -14,6 +15,7 @@ from a2a.server.tasks import InMemoryTaskStore, TaskUpdater from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task +from weather_service.configuration import Configuration from weather_service.graph import get_graph, get_mcpclient from weather_service.observability import ( create_tracing_middleware, @@ -21,6 +23,35 @@ set_span_output, ) + +class SecretRedactionFilter(logging.Filter): + """Redacts Bearer tokens and the configured API key from log messages.""" + + _BEARER_RE = re.compile(r"(Bearer\s+)\S+", re.IGNORECASE) + + def __init__(self): + super().__init__() + key = os.environ.get("LLM_API_KEY", "").strip() + self._key_re = re.compile(re.escape(key)) if len(key) > 8 else None + + def _redact(self, text: str) -> str: + text = self._BEARER_RE.sub(r"\1[REDACTED]", text) + if self._key_re: + text = self._key_re.sub("[REDACTED]", text) + return text + + def filter(self, record: logging.LogRecord) -> bool: + if isinstance(record.msg, str): + record.msg = self._redact(record.msg) + if isinstance(record.args, dict): + record.args = {k: self._redact(v) if isinstance(v, str) else v for k, v in record.args.items()} + elif isinstance(record.args, tuple): + record.args = tuple(self._redact(a) if isinstance(a, str) else a for a in record.args) + return True + + +logging.basicConfig(level=logging.INFO) +logging.getLogger().addFilter(SecretRedactionFilter()) logger = logging.getLogger(__name__) @@ -110,6 +141,15 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): task_updater = TaskUpdater(event_queue, task.id, task.context_id) event_emitter = A2AEvent(task_updater) + # Check API key before attempting LLM calls + config = Configuration() + if not config.has_valid_api_key: + await event_emitter.emit_event( + "Error: No LLM API key configured. Set the LLM_API_KEY environment variable.", + failed=True, + ) + return + # Get user input for the agent user_input = context.get_user_input() @@ -225,16 +265,4 @@ def run(): # Add tracing middleware - creates root span with MLflow/GenAI attributes app.add_middleware(BaseHTTPMiddleware, dispatch=create_tracing_middleware()) - class LogAuthorizationMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - auth_header = request.headers.get("authorization", "No Authorization header") - logger.info( - f"🔐 Incoming request to {request.url.path} with Authorization: {auth_header[:80] + '...' if len(auth_header) > 80 else auth_header}" - ) - response = await call_next(request) - return response - - # Add logging middleware - app.add_middleware(LogAuthorizationMiddleware) - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/a2a/weather_service/src/weather_service/configuration.py b/a2a/weather_service/src/weather_service/configuration.py index dd1c9f4c..6faf3158 100644 --- a/a2a/weather_service/src/weather_service/configuration.py +++ b/a2a/weather_service/src/weather_service/configuration.py @@ -1,7 +1,23 @@ +from urllib.parse import urlparse + from pydantic_settings import BaseSettings +_PLACEHOLDER_KEYS = {"dummy", "changeme", "your-api-key-here", ""} + class Configuration(BaseSettings): llm_model: str = "llama3.1" llm_api_base: str = "http://localhost:11434/v1" llm_api_key: str = "dummy" + + @property + def has_valid_api_key(self) -> bool: + """Placeholder keys are only invalid for remote APIs. + + Local LLMs (Ollama, vLLM) accept any key, so we skip validation + when the API base points to localhost. + """ + host = urlparse(self.llm_api_base).hostname or "" + if host in {"localhost", "127.0.0.1", "0.0.0.0"}: + return True + return self.llm_api_key.strip() not in _PLACEHOLDER_KEYS diff --git a/tests/a2a/test_weather_secret_redaction.py b/tests/a2a/test_weather_secret_redaction.py new file mode 100644 index 00000000..9c519698 --- /dev/null +++ b/tests/a2a/test_weather_secret_redaction.py @@ -0,0 +1,189 @@ +"""Tests for secret redaction and API key validation in the weather service. + +Loads agent.py and configuration.py in isolation (same approach as +test_weather_service.py) to avoid pulling in heavy deps like opentelemetry. +""" + +import importlib.util +import logging +import pathlib +import sys +from types import ModuleType +from unittest.mock import MagicMock + +# --- Isolation setup (must happen before any weather_service imports) --- +_fake_ws = ModuleType("weather_service") +_fake_ws.__path__ = [] # type: ignore[attr-defined] +sys.modules.setdefault("weather_service", _fake_ws) +sys.modules.setdefault("weather_service.observability", MagicMock()) + +_BASE = pathlib.Path(__file__).parent.parent.parent / "a2a" / "weather_service" / "src" / "weather_service" + + +def _load_module(name: str, path: pathlib.Path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) # type: ignore[arg-type] + sys.modules[name] = mod + spec.loader.exec_module(mod) # type: ignore[union-attr] + return mod + + +_config_mod = _load_module("weather_service.configuration", _BASE / "configuration.py") + +# Mock modules that agent.py imports but we don't need +for mod_name in [ + "uvicorn", + "langchain_core", + "langchain_core.messages", + "starlette", + "starlette.middleware", + "starlette.middleware.base", + "starlette.routing", + "a2a", + "a2a.server", + "a2a.server.agent_execution", + "a2a.server.apps", + "a2a.server.events", + "a2a.server.events.event_queue", + "a2a.server.request_handlers", + "a2a.server.tasks", + "a2a.types", + "a2a.utils", + "weather_service.graph", +]: + sys.modules.setdefault(mod_name, MagicMock()) + +_agent_mod = _load_module("weather_service.agent", _BASE / "agent.py") + +Configuration = _config_mod.Configuration +SecretRedactionFilter = _agent_mod.SecretRedactionFilter + + +# --- Tests --- + + +class TestSecretRedactionFilter: + """Test the logging filter that redacts Bearer tokens and API keys.""" + + def setup_method(self): + self.filt = SecretRedactionFilter() + + def _make_record(self, msg: str, args=None) -> logging.LogRecord: + return logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg=msg, + args=args, + exc_info=None, + ) + + def test_redacts_bearer_token(self): + record = self._make_record("Authorization: Bearer sk-abc123xyz789secret") + self.filt.filter(record) + assert "sk-abc123xyz789secret" not in record.msg + assert "[REDACTED]" in record.msg + + def test_bearer_case_insensitive(self): + record = self._make_record("header: bearer my-secret-token-value") + self.filt.filter(record) + assert "my-secret-token-value" not in record.msg + + def test_preserves_non_secret_messages(self): + record = self._make_record("Processing weather request for New York") + self.filt.filter(record) + assert record.msg == "Processing weather request for New York" + + def test_redacts_bearer_in_tuple_args(self): + record = self._make_record("Header: %s", ("Bearer secret123",)) + self.filt.filter(record) + assert "secret123" not in record.args[0] + + def test_redacts_bearer_in_dict_args(self): + record = self._make_record("%(auth)s") + record.args = {"auth": "Bearer secret123"} + self.filt.filter(record) + assert "secret123" not in record.args["auth"] + + def test_always_returns_true(self): + record = self._make_record("Bearer secret123") + assert self.filt.filter(record) is True + + def test_redacts_literal_configured_key(self, monkeypatch): + rhoai_key = "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6" + monkeypatch.setenv("LLM_API_KEY", rhoai_key) + filt = SecretRedactionFilter() + record = self._make_record(f"Sending request with api-key={rhoai_key}") + filt.filter(record) + assert rhoai_key not in record.msg + assert "[REDACTED]" in record.msg + + def test_literal_key_redaction_in_args(self, monkeypatch): + rhoai_key = "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6" + monkeypatch.setenv("LLM_API_KEY", rhoai_key) + filt = SecretRedactionFilter() + record = self._make_record("key=%s", (rhoai_key,)) + filt.filter(record) + assert rhoai_key not in record.args[0] + + def test_short_key_not_redacted(self, monkeypatch): + monkeypatch.setenv("LLM_API_KEY", "dummy") + filt = SecretRedactionFilter() + record = self._make_record("Using dummy config for testing dummy values") + filt.filter(record) + assert "dummy" in record.msg + + def test_no_crash_when_key_unset(self, monkeypatch): + monkeypatch.delenv("LLM_API_KEY", raising=False) + filt = SecretRedactionFilter() + record = self._make_record("Normal log message") + filt.filter(record) + assert record.msg == "Normal log message" + + +class TestConfigurationApiKeyValidation: + """Test API key validation logic.""" + + def test_dummy_key_with_remote_api_is_invalid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "dummy") + assert Configuration().has_valid_api_key is False + + def test_dummy_key_with_localhost_is_valid(self): + config = Configuration() # defaults: localhost + dummy + assert config.has_valid_api_key is True + + def test_empty_key_with_remote_api_is_invalid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "") + assert Configuration().has_valid_api_key is False + + def test_placeholder_keys_with_remote_are_invalid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + for key in ["changeme", "your-api-key-here"]: + monkeypatch.setenv("LLM_API_KEY", key) + assert Configuration().has_valid_api_key is False, f"'{key}' should be invalid" + + def test_placeholder_keys_with_local_llm_are_valid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "http://localhost:11434/v1") + for key in ["dummy", "changeme", ""]: + monkeypatch.setenv("LLM_API_KEY", key) + assert Configuration().has_valid_api_key is True + + def test_real_key_is_valid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "sk-proj-realkey123") + assert Configuration().has_valid_api_key is True + + def test_rhoai_maas_key_is_valid(self, monkeypatch): + monkeypatch.setenv( + "LLM_API_BASE", + "https://model--maas-apicast-production.apps.prod.rhoai.rh-aiservices-bu.com:443/v1", + ) + monkeypatch.setenv("LLM_API_KEY", "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6") + assert Configuration().has_valid_api_key is True + + def test_127_is_local(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "http://127.0.0.1:8080/v1") + assert Configuration().has_valid_api_key is True