From 9b9c70781597243ae187ff4a6262be5138a1ea01 Mon Sep 17 00:00:00 2001 From: Szymon Bednorz Date: Sun, 2 Nov 2025 18:56:46 +0100 Subject: [PATCH 1/3] Fixed project versioning --- src/test_mcp/__init__.py | 13 ++----------- src/test_mcp/config.py | 10 +--------- src/test_mcp/testing/__init__.py | 3 --- 3 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/test_mcp/__init__.py b/src/test_mcp/__init__.py index bacb55e..b0aeec1 100644 --- a/src/test_mcp/__init__.py +++ b/src/test_mcp/__init__.py @@ -1,14 +1,5 @@ -""" -test-mcp: Comprehensive testing framework for MCP (Model Context Protocol) servers - -A sophisticated testing framework that combines AI agents with MCP server connectivity -for automated testing and CI/CD integration. -""" +"""test-mcp: Comprehensive testing framework for MCP servers""" __version__ = "0.1.0-beta.5" -__author__ = "MCP Testing Suite" -__email__ = "antoni@golf.dev" -__all__ = [ - "__version__", -] +__all__ = ["__version__"] diff --git a/src/test_mcp/config.py b/src/test_mcp/config.py index c1f1dd1..a239490 100644 --- a/src/test_mcp/config.py +++ b/src/test_mcp/config.py @@ -5,27 +5,19 @@ """ import os +import warnings -# API Keys for Local Testing ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") -# Validate required API keys if not ANTHROPIC_API_KEY: - import warnings - warnings.warn( "ANTHROPIC_API_KEY environment variable is not set. Agent functionality will be limited.", stacklevel=2, ) if not OPENAI_API_KEY: - import warnings - warnings.warn( "OPENAI_API_KEY environment variable is not set. 
Judge and user simulator functionality will be limited.", stacklevel=2, ) - -# Task Configuration -MAX_RESULT_SIZE_MB = int(os.getenv("MAX_TASK_RESULT_SIZE_MB", "50")) diff --git a/src/test_mcp/testing/__init__.py b/src/test_mcp/testing/__init__.py index da3bb4f..42053e4 100644 --- a/src/test_mcp/testing/__init__.py +++ b/src/test_mcp/testing/__init__.py @@ -4,6 +4,3 @@ A comprehensive testing framework for MCP (Model Context Protocol) servers with AI agents. Supports both single-response testing and multi-turn conversation testing. """ - -__version__ = "0.2.0" -__all__ = ["__version__"] From 0c5cd6a1df5b727a20b973bb465e1862af45b1d2 Mon Sep 17 00:00:00 2001 From: Szymon Bednorz Date: Sun, 2 Nov 2025 19:31:01 +0100 Subject: [PATCH 2/3] Utils refactor --- src/test_mcp/agent/agent.py | 4 +- src/test_mcp/cli/main.py | 64 ++- src/test_mcp/utils/__init__.py | 1 - src/test_mcp/utils/performance_monitor.py | 40 +- src/test_mcp/utils/rate_limiter.py | 231 +++++++---- src/test_mcp/utils/user_tracking.py | 104 +++-- src/test_mcp/utils/version_checker.py | 115 ++++-- tests/test_user_tracking.py | 459 ++++++++++++++++++++++ 8 files changed, 833 insertions(+), 185 deletions(-) create mode 100644 tests/test_user_tracking.py diff --git a/src/test_mcp/agent/agent.py b/src/test_mcp/agent/agent.py index 2d5390e..61701d0 100644 --- a/src/test_mcp/agent/agent.py +++ b/src/test_mcp/agent/agent.py @@ -194,14 +194,14 @@ async def _make_api_call_with_retry(self, api_params: dict) -> Any: total_tokens = ( response.usage.input_tokens + response.usage.output_tokens ) - self.rate_limiter.record_token_usage(correlation_id, total_tokens) + await self.rate_limiter.record_token_usage(correlation_id, total_tokens) return response except Exception as e: # Clean up pending request on error if self.rate_limiter and correlation_id: - self.rate_limiter.cleanup_pending_request(correlation_id) + await self.rate_limiter.cleanup_pending_request(correlation_id) # Check if this is the last attempt if attempt 
== max_retries: diff --git a/src/test_mcp/cli/main.py b/src/test_mcp/cli/main.py index 5d2591b..4fc01cd 100644 --- a/src/test_mcp/cli/main.py +++ b/src/test_mcp/cli/main.py @@ -59,7 +59,7 @@ def handle_bad_parameter(self, error: click.BadParameter) -> None: else: self.console.print_error(f"Invalid {param_name}: {error_msg}") - _handle_command_completion(self.start_time, exit_code=1) + handle_command_completion(self.start_time, exit_code=1) sys.exit(1) def handle_usage_error(self, error: click.UsageError) -> None: @@ -107,7 +107,7 @@ def handle_usage_error(self, error: click.UsageError) -> None: else: self.console.print_error(str(error)) - _handle_command_completion(self.start_time, exit_code=1) + handle_command_completion(self.start_time, exit_code=1) sys.exit(1) def handle_system_exit(self, error: SystemExit) -> None: @@ -120,7 +120,7 @@ def handle_system_exit(self, error: SystemExit) -> None: exit_code = int(exit_code) except ValueError: exit_code = 1 - _handle_command_completion(self.start_time, exit_code=exit_code) + handle_command_completion(self.start_time, exit_code=exit_code) raise @@ -171,6 +171,27 @@ def show_help(ctx, param, value): ctx.exit() +def handle_command_completion(start_time: float, exit_code: int) -> None: + """Track command completion and show suggestions""" + try: + # Track command for analytics + duration_ms = (time.time() - start_time) * 1000 + command_name = " ".join(sys.argv) if sys.argv else "mcp-t" + + command_tracker = get_command_tracker() + command_tracker.record_command(command_name, exit_code, duration_ms) + + # Show suggestions for all commands (not just failures) + # Skip for help commands and version commands + if not any(flag in sys.argv for flag in ["--help", "-h", "--version"]): + ctx = click.get_current_context(silent=True) + if ctx and hasattr(ctx, "obj") and ctx.obj: + trigger_post_command_hooks(ctx) + except Exception as e: + console = get_console() + console.print_error(f"Unexpected error: {e!s}") + + @click.group( 
invoke_without_command=True, name="mcp-t", @@ -216,54 +237,25 @@ def mcpt_main() -> None: try: mcpt_cli(standalone_mode=False) - _handle_command_completion(start_time, exit_code=0) + handle_command_completion(start_time, exit_code=0) except click.BadParameter as e: error_handler.handle_bad_parameter(e) except click.UsageError as e: error_handler.handle_usage_error(e) except SystemExit as e: error_handler.handle_system_exit(e) - except click.Abort: - # Handle user interruption (Ctrl+C) - Click converts KeyboardInterrupt to Abort + except (click.Abort, KeyboardInterrupt): console = get_console() console.print("\n[dim]Operation cancelled by user[/dim]") - _handle_command_completion(start_time, exit_code=130) - sys.exit(130) - except KeyboardInterrupt: - # Handle user interruption (fallback, though Click usually catches this first) - console = get_console() - console.print("\n[dim]Operation cancelled by user[/dim]") - _handle_command_completion(start_time, exit_code=130) + handle_command_completion(start_time, exit_code=130) sys.exit(130) except Exception as e: - # Only for truly unexpected errors - _handle_command_completion(start_time, exit_code=1) + handle_command_completion(start_time, exit_code=1) console = get_console() console.print(f"[red]Unexpected error: {e}[/red]") raise -def _handle_command_completion(start_time: float, exit_code: int) -> None: - """Track command completion and show suggestions""" - try: - # Track command for analytics - duration_ms = (time.time() - start_time) * 1000 - command_name = " ".join(sys.argv) if sys.argv else "mcp-t" - - command_tracker = get_command_tracker() - command_tracker.record_command(command_name, exit_code, duration_ms) - - # Show suggestions for all commands (not just failures) - # Skip for help commands and version commands - if not any(flag in sys.argv for flag in ["--help", "-h", "--version"]): - ctx = click.get_current_context(silent=True) - if ctx and hasattr(ctx, "obj") and ctx.obj: - 
trigger_post_command_hooks(ctx) - except Exception: - # Silent failure - don't break CLI for tracking/suggestion issues - pass - - # Register all commands from modules mcpt_cli.add_command(create_run_command()) mcpt_cli.add_command(create_generate_command()) diff --git a/src/test_mcp/utils/__init__.py b/src/test_mcp/utils/__init__.py index bc59b30..e69de29 100644 --- a/src/test_mcp/utils/__init__.py +++ b/src/test_mcp/utils/__init__.py @@ -1 +0,0 @@ -# Golf Test Utils Package diff --git a/src/test_mcp/utils/performance_monitor.py b/src/test_mcp/utils/performance_monitor.py index e4b1226..dbe5585 100644 --- a/src/test_mcp/utils/performance_monitor.py +++ b/src/test_mcp/utils/performance_monitor.py @@ -9,13 +9,29 @@ class TestExecutionMetrics: test_id: str start_time: float end_time: float | None = None - duration: float | None = None # Duration in seconds (optional for incomplete tests) turns_completed: int = 0 api_calls_made: int = 0 - # tokens_consumed removed - unreliable estimation success: bool = False error_message: str | None = None + @property + def duration(self) -> float | None: + """Calculate duration from start and end times""" + if self.end_time is None: + return None + return self.end_time - self.start_time + + def __post_init__(self): + """Validate metrics after initialization""" + if self.end_time is not None and self.end_time < self.start_time: + raise ValueError( + f"end_time ({self.end_time}) cannot be before start_time ({self.start_time})" + ) + if self.api_calls_made < 0: + raise ValueError("api_calls_made cannot be negative") + if self.turns_completed < 0: + raise ValueError("turns_completed cannot be negative") + @dataclass class SuiteExecutionMetrics: @@ -25,27 +41,29 @@ class SuiteExecutionMetrics: start_time: float test_metrics: list[TestExecutionMetrics] = field(default_factory=list) parallelism_used: int = 1 - total_duration: float | None = ( - None # Total duration in seconds (optional until completion) - ) + total_duration: float | None 
= None - def get_summary_stats(self) -> dict[str, str | int | float]: + def get_summary_stats(self) -> dict[str, str | int | float | None]: """Generate summary statistics for the test suite""" completed_tests = [t for t in self.test_metrics if t.duration is not None] if not completed_tests: return {"status": "no_completed_tests"} - durations = [t.duration for t in completed_tests if t.duration is not None] + # Duration is guaranteed to be not None for completed_tests + durations = [t.duration for t in completed_tests] return { "total_tests": len(self.test_metrics), "completed_tests": len(completed_tests), "success_rate": len([t for t in completed_tests if t.success]) / len(completed_tests), - "average_duration": statistics.mean(durations), # Duration in seconds - "median_duration": statistics.median(durations), # Duration in seconds + "average_duration": statistics.mean(durations), + "median_duration": statistics.median(durations), "total_api_calls": sum(t.api_calls_made for t in completed_tests), - # Token consumption removed for simplicity - "parallelism_efficiency": len(completed_tests) / (self.total_duration or 1), + "parallelism_efficiency": ( + len(completed_tests) / self.total_duration + if self.total_duration and self.total_duration > 0 + else None + ), } diff --git a/src/test_mcp/utils/rate_limiter.py b/src/test_mcp/utils/rate_limiter.py index 756547d..faa8786 100644 --- a/src/test_mcp/utils/rate_limiter.py +++ b/src/test_mcp/utils/rate_limiter.py @@ -1,133 +1,202 @@ import asyncio +import logging import time import uuid -from collections import defaultdict, deque +from collections import defaultdict +from typing import NamedTuple + +logger = logging.getLogger(__name__) + + +class RequestRecord(NamedTuple): + """Immutable record of a rate-limited request.""" + + timestamp: float + tokens_used: int + correlation_id: str class RateLimiter: - """Simple rate limiter with RPM and token limits""" + """ + Thread-safe async rate limiter with RPM and token limits. 
- # Entry tuple indices - TIMESTAMP_INDEX = 0 - TOKENS_INDEX = 1 - CORRELATION_ID_INDEX = 2 + Manages rate limiting across multiple providers with both request-per-minute + and token-per-minute constraints. Uses asyncio locks for coroutine safety. + """ - # Entry tuple minimum lengths - MIN_ENTRY_WITH_TOKENS = 2 - MIN_ENTRY_WITH_ID = 3 + TOKEN_USAGE_THRESHOLD = 0.8 + RATE_LIMIT_WINDOW_SECONDS = 60 + REQUEST_TIMEOUT_SECONDS = 300 + CLEANUP_CHECK_INTERVAL = 1.0 def __init__(self) -> None: - # Updated to realistic API limits (conservative defaults for reliable operation) self.providers = { "anthropic": {"requests_per_minute": 5000, "tokens_per_minute": 100000}, "openai": {"requests_per_minute": 5000, "tokens_per_minute": 100000}, "gemini": {"requests_per_minute": 60, "tokens_per_minute": 8000}, } - self.request_history: dict[str, deque] = defaultdict(deque) - # Add token usage tracking - self.token_usage: dict[str, int] = defaultdict(int) # Current window total - # Track correlation IDs to provider/timestamp mapping + + self.request_history: dict[str, dict[str, RequestRecord]] = defaultdict(dict) + self.token_usage: dict[str, int] = defaultdict(int) self._pending_requests: dict[str, tuple[str, float]] = {} + self._locks: dict[str, asyncio.Lock] = {} + self._global_lock = asyncio.Lock() + + async def _get_provider_lock(self, provider: str) -> asyncio.Lock: + """Get or create a lock for a specific provider.""" + async with self._global_lock: + if provider not in self._locks: + self._locks[provider] = asyncio.Lock() + return self._locks[provider] + async def acquire_request_slot(self, provider: str) -> str: - """Acquire permission to make API request and return correlation ID""" + """ + Acquire permission to make API request and return correlation ID. + + Thread-safe operation that waits if rate limits are exceeded. 
+ + Args: + provider: API provider name (e.g., 'anthropic', 'openai') + + Returns: + Correlation ID for tracking this request + + Raises: + ValueError: If provider limits are not configured + """ + if provider not in self.providers: + logger.warning(f"Unknown provider '{provider}', using default limits") + limits = self.providers.get(provider, {}) rpm_limit = limits.get("requests_per_minute", 500) tpm_limit = limits.get("tokens_per_minute", 100000) - now = time.time() - self._clean_old_requests(provider, now) + if rpm_limit <= 0 or tpm_limit <= 0: + raise ValueError(f"Invalid rate limits for provider '{provider}'") + + lock = await self._get_provider_lock(provider) - # Check both request and token limits - while ( - len(self.request_history[provider]) >= rpm_limit - or self.token_usage[provider] > tpm_limit * 0.8 - ): - await asyncio.sleep(1) + async with lock: now = time.time() self._clean_old_requests(provider, now) - # Generate correlation ID and record the request - correlation_id = f"{provider}_{int(now)}_{uuid.uuid4().hex[:8]}" - self.request_history[provider].append((now, 0, correlation_id)) - self._pending_requests[correlation_id] = (provider, now) + while ( + len(self.request_history[provider]) >= rpm_limit + or self.token_usage[provider] > tpm_limit * self.TOKEN_USAGE_THRESHOLD + ): + await asyncio.sleep(self.CLEANUP_CHECK_INTERVAL) + now = time.time() + self._clean_old_requests(provider, now) + + correlation_id = f"{provider}_{int(now)}_{uuid.uuid4().hex[:8]}" + record = RequestRecord( + timestamp=now, tokens_used=0, correlation_id=correlation_id + ) + self.request_history[provider][correlation_id] = record + self._pending_requests[correlation_id] = (provider, now) + + return correlation_id + + async def record_token_usage(self, correlation_id: str, tokens_used: int) -> None: + """ + Record actual token usage from API response using correlation ID. - return correlation_id + Thread-safe operation that updates token usage tracking. 
+ + Args: + correlation_id: ID returned from acquire_request_slot + tokens_used: Number of tokens consumed by the request + + Raises: + ValueError: If tokens_used is negative + """ + if tokens_used < 0: + raise ValueError(f"tokens_used must be non-negative, got {tokens_used}") - def record_token_usage(self, correlation_id: str, tokens_used: int) -> None: - """Record actual token usage from API response using correlation ID""" if correlation_id not in self._pending_requests: - print(f"Warning: Unknown correlation ID {correlation_id}") + logger.warning(f"Unknown correlation ID {correlation_id}") return provider, _timestamp = self._pending_requests[correlation_id] - self.token_usage[provider] += tokens_used + lock = await self._get_provider_lock(provider) - # Find and update the specific request entry - for i, entry in enumerate(self.request_history[provider]): - if ( - len(entry) >= self.MIN_ENTRY_WITH_ID - and entry[self.CORRELATION_ID_INDEX] == correlation_id - ): - req_time, _req_tokens, req_id = entry - self.request_history[provider][i] = (req_time, tokens_used, req_id) - break + async with lock: + self.token_usage[provider] += tokens_used - # Clean up pending request - del self._pending_requests[correlation_id] + if correlation_id in self.request_history[provider]: + old_record = self.request_history[provider][correlation_id] + updated_record = RequestRecord( + timestamp=old_record.timestamp, + tokens_used=tokens_used, + correlation_id=correlation_id, + ) + self.request_history[provider][correlation_id] = updated_record + + del self._pending_requests[correlation_id] def _clean_old_requests(self, provider: str, current_time: float) -> None: - """Remove requests older than 1 minute and their token usage""" - cutoff_time = current_time - 60 - timeout_cutoff = current_time - 300 # 5 minute absolute timeout + """ + Remove requests older than rate limit window and their token usage. - tokens_to_remove = 0 + Must be called while holding the provider lock. 
+ + Args: + provider: Provider name to clean requests for + current_time: Current timestamp + """ + cutoff_time = current_time - self.RATE_LIMIT_WINDOW_SECONDS + timeout_cutoff = current_time - self.REQUEST_TIMEOUT_SECONDS - # Use index-based removal to avoid O(n²) performance and race conditions - # Process from right to left to maintain valid indices during removal - for i in range(len(self.request_history[provider]) - 1, -1, -1): - entry = self.request_history[provider][i] + tokens_to_remove = 0 + correlation_ids_to_remove = [] - if entry[self.TIMESTAMP_INDEX] >= cutoff_time: - continue # Entry is still fresh, keep it + for correlation_id, record in list(self.request_history[provider].items()): + if record.timestamp >= cutoff_time: + continue should_remove = False - # Force cleanup of extremely old requests (5+ minutes) - if entry[self.TIMESTAMP_INDEX] < timeout_cutoff: + if record.timestamp < timeout_cutoff: should_remove = True - if len(entry) >= self.MIN_ENTRY_WITH_ID: - correlation_id = entry[self.CORRELATION_ID_INDEX] - if correlation_id in self._pending_requests: - print( - f"Warning: Timing out request {correlation_id} after 5 minutes" - ) - del self._pending_requests[correlation_id] - # Only remove if tokens have been recorded (not pending) - elif len(entry) >= self.MIN_ENTRY_WITH_ID: - correlation_id = entry[self.CORRELATION_ID_INDEX] - # Skip entries that are still pending token recording - if correlation_id not in self._pending_requests: - should_remove = True - else: - # Entry without correlation ID can be safely removed + if correlation_id in self._pending_requests: + logger.warning( + f"Timing out request {correlation_id} after " + f"{self.REQUEST_TIMEOUT_SECONDS}s" + ) + del self._pending_requests[correlation_id] + elif correlation_id not in self._pending_requests: should_remove = True if should_remove: - # Count tokens before removal - if len(entry) >= self.MIN_ENTRY_WITH_TOKENS: - tokens = entry[self.TOKENS_INDEX] - tokens_to_remove += tokens + 
tokens_to_remove += record.tokens_used + correlation_ids_to_remove.append(correlation_id) - # Remove entry using index (O(1) for deque) - del self.request_history[provider][i] + for correlation_id in correlation_ids_to_remove: + del self.request_history[provider][correlation_id] - # Remove old tokens from current usage self.token_usage[provider] = max( 0, self.token_usage[provider] - tokens_to_remove ) - def cleanup_pending_request(self, correlation_id: str) -> None: - """Clean up pending request on error""" - if correlation_id in self._pending_requests: - del self._pending_requests[correlation_id] + async def cleanup_pending_request(self, correlation_id: str) -> None: + """ + Clean up pending request on error. + + Thread-safe operation to remove a pending request and its history entry. + + Args: + correlation_id: ID of the request to clean up + """ + if correlation_id not in self._pending_requests: + return + + provider, _timestamp = self._pending_requests[correlation_id] + lock = await self._get_provider_lock(provider) + + async with lock: + if correlation_id in self._pending_requests: + del self._pending_requests[correlation_id] + + if correlation_id in self.request_history[provider]: + del self.request_history[provider][correlation_id] diff --git a/src/test_mcp/utils/user_tracking.py b/src/test_mcp/utils/user_tracking.py index 8fa272a..eb6f1b2 100644 --- a/src/test_mcp/utils/user_tracking.py +++ b/src/test_mcp/utils/user_tracking.py @@ -1,58 +1,114 @@ import json +import os +import tempfile import threading import uuid from datetime import datetime +from pathlib import Path from .. 
import __version__ from ..config.config_manager import ConfigManager class UserTracker: - """Manages anonymous user identification""" + """Manages anonymous user identification with thread-safe operations.""" - def __init__(self): + def __init__(self) -> None: self.config_manager = ConfigManager() self.user_id_file = ( self.config_manager.paths.get_system_paths()["cache_dir"] / "user_id.json" ) + self._lock = threading.Lock() def get_or_create_user_id(self) -> str: - """Get existing or create new anonymous user ID""" - if self.user_id_file.exists(): - try: - data = json.loads(self.user_id_file.read_text()) - user_id = data.get("user_id") - if user_id: - return user_id - except Exception: - pass # Create new ID if file corrupted - - # Generate new anonymous ID - user_id = str(uuid.uuid4()) - self._save_user_id(user_id) - return user_id - - def _save_user_id(self, user_id: str): - """Save user ID to cache file""" + """ + Get existing or create new anonymous user ID. + + Thread-safe operation that reads or creates a persistent user ID. + Returns a valid UUID v4 string. 
+ """ + with self._lock: + user_id = self._load_existing_user_id() + if user_id: + return user_id + + new_user_id = str(uuid.uuid4()) + self._save_user_id(new_user_id) + return new_user_id + + def _load_existing_user_id(self) -> str | None: + """Load and validate existing user ID from file.""" + if not self.user_id_file.exists(): + return None + + try: + data = json.loads(self.user_id_file.read_text(encoding="utf-8")) + user_id = data.get("user_id") + + if user_id and self._is_valid_uuid(user_id): + return user_id + + except (json.JSONDecodeError, OSError, KeyError): + pass + + return None + + def _is_valid_uuid(self, value: str) -> bool: + """Validate that a string is a proper UUID.""" + try: + uuid.UUID(value) + return True + except (ValueError, AttributeError, TypeError): + return False + + def _save_user_id(self, user_id: str) -> None: + """Save user ID to cache file atomically.""" self.user_id_file.parent.mkdir(parents=True, exist_ok=True) + data = { "user_id": user_id, "created_at": datetime.now().isoformat(), "version": __version__, } - self.user_id_file.write_text(json.dumps(data, indent=2)) + + self._atomic_write(self.user_id_file, json.dumps(data, indent=2)) + + def _atomic_write(self, target_path: Path, content: str) -> None: + """Write content to file atomically using temp file and rename.""" + temp_fd = None + temp_path = None + + try: + temp_fd, temp_path = tempfile.mkstemp( + dir=target_path.parent, prefix=f".{target_path.name}_", suffix=".tmp" + ) + + os.write(temp_fd, content.encode("utf-8")) + os.close(temp_fd) + temp_fd = None + + os.replace(temp_path, target_path) + temp_path = None + + except OSError as e: + if temp_fd is not None: + os.close(temp_fd) + if temp_path and os.path.exists(temp_path): + os.unlink(temp_path) + raise OSError(f"Failed to save user ID: {e}") from e -# Global instance _user_tracker: UserTracker | None = None _user_tracker_lock = threading.Lock() def get_user_tracker() -> UserTracker: - """Get shared user tracker 
instance""" + """Get shared user tracker instance using double-checked locking pattern.""" global _user_tracker - if _user_tracker is None: # First check (optimization) + + if _user_tracker is None: with _user_tracker_lock: - if _user_tracker is None: # Second check (safety) + if _user_tracker is None: _user_tracker = UserTracker() + return _user_tracker diff --git a/src/test_mcp/utils/version_checker.py b/src/test_mcp/utils/version_checker.py index 4fce4bf..469e9e5 100644 --- a/src/test_mcp/utils/version_checker.py +++ b/src/test_mcp/utils/version_checker.py @@ -1,5 +1,8 @@ import importlib.metadata import json +import logging +import os +import tempfile import threading from datetime import datetime, timedelta from pathlib import Path @@ -9,13 +12,19 @@ from ..config.config_manager import ConfigManager +logger = logging.getLogger(__name__) + class VersionChecker: """Handles version checking against PyPI with smart caching""" - def __init__(self, package_name: str = "mcp-testing", timeout: int = 5): + DEFAULT_TIMEOUT = 5 + DEFAULT_CACHE_TTL_DAYS = 7 + PYPI_BASE_URL = "https://pypi.org" + + def __init__(self, package_name: str = "mcp-testing", timeout: int | None = None): self.package_name = package_name - self.timeout = timeout + self.timeout = timeout or self.DEFAULT_TIMEOUT self.config_manager = ConfigManager() self.cache_file = self._get_cache_file() @@ -24,31 +33,42 @@ def _get_cache_file(self) -> Path: cache_dir = self.config_manager.paths.get_system_paths()["cache_dir"] return cache_dir / "version_check.json" - def get_current_version(self) -> str: - """Get currently installed package version""" + def get_current_version(self) -> str | None: + """Get currently installed package version.""" try: return importlib.metadata.version(self.package_name) except importlib.metadata.PackageNotFoundError: - return "0.0.0" # Development fallback + # Development mode - package not installed + return None def check_for_update_async(self, callback=None): - """Run version 
check in background thread""" + """Run version check in background thread.""" def check(): try: result = self.check_for_update() if callback: - callback(result) - except Exception: - # Silently fail - don't interrupt user workflow - pass - - thread = threading.Thread(target=check) + try: + callback(result) + except Exception as e: + logger.warning(f"Version check callback failed: {e}") + except Exception as e: + logger.debug(f"Background version check failed: {e}") + + thread = threading.Thread(target=check, name="version-checker") thread.daemon = True thread.start() - def check_for_update(self) -> dict | None: - """Check PyPI for newer version""" + def check_for_update(self) -> dict | None: # noqa: PLR0911 + """Check PyPI for newer version with caching.""" + # Skip check in development mode (package not installed) + current_version = self.get_current_version() + if current_version is None: + logger.debug( + f"Package '{self.package_name}' not installed, skipping version check" + ) + return None + # Check cache first cached_result = self._load_cache() if cached_result and not self._is_cache_expired(cached_result): @@ -57,13 +77,19 @@ def check_for_update(self) -> dict | None: try: # Fetch from PyPI response = httpx.get( - f"https://pypi.org/pypi/{self.package_name}/json", timeout=self.timeout + f"{self.PYPI_BASE_URL}/pypi/{self.package_name}/json", + timeout=self.timeout, ) response.raise_for_status() data = response.json() + + # Validate PyPI response structure + if "info" not in data or "version" not in data["info"]: + logger.warning(f"Invalid PyPI response for {self.package_name}") + return cached_result + latest_version = data["info"]["version"] - current_version = self.get_current_version() # Prepare result result = { @@ -72,41 +98,70 @@ def check_for_update(self) -> dict | None: "has_update": version.parse(latest_version) > version.parse(current_version), "last_check": datetime.now().isoformat(), - "package_url": 
f"https://pypi.org/project/{self.package_name}/", - "release_notes_url": f"https://pypi.org/project/{self.package_name}/{latest_version}/", + "package_url": f"{self.PYPI_BASE_URL}/project/{self.package_name}/", + "release_notes_url": f"{self.PYPI_BASE_URL}/project/{self.package_name}/{latest_version}/", } # Cache result self._save_cache(result) return result - except Exception: - # Return cached result on error, or None + except (httpx.HTTPError, httpx.TimeoutException) as e: + logger.debug(f"Network error checking PyPI: {e}") + return cached_result + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Invalid PyPI response format: {e}") + return cached_result + except Exception as e: + logger.error(f"Unexpected error checking for updates: {e}") return cached_result def _load_cache(self) -> dict | None: - """Load cached version check result""" + """Load cached version check result.""" try: if self.cache_file.exists(): with open(self.cache_file) as f: return json.load(f) - except Exception: - pass + except (json.JSONDecodeError, OSError) as e: + logger.debug(f"Failed to load version cache: {e}") return None def _save_cache(self, result: dict): - """Save version check result to cache""" + """Save version check result to cache using atomic write.""" try: self.cache_file.parent.mkdir(parents=True, exist_ok=True) - with open(self.cache_file, "w") as f: + + # Atomic write: write to temp file, then rename + with tempfile.NamedTemporaryFile( + mode="w", + dir=self.cache_file.parent, + delete=False, + prefix=".version_check_", + suffix=".tmp", + ) as f: json.dump(result, f, indent=2) - except Exception: - pass # Fail silently + temp_path = f.name + + # Atomic rename (overwrites existing file) + os.replace(temp_path, self.cache_file) + + except OSError as e: + logger.debug(f"Failed to save version cache: {e}") + # Clean up temp file if it exists + try: + if "temp_path" in locals(): + os.unlink(temp_path) + except OSError: + pass - def _is_cache_expired(self, 
cached_result: dict, ttl_days: int = 7) -> bool: - """Check if cached result has expired""" + def _is_cache_expired( + self, cached_result: dict, ttl_days: int | None = None + ) -> bool: + """Check if cached result has expired.""" + ttl_days = ttl_days or self.DEFAULT_CACHE_TTL_DAYS try: last_check = datetime.fromisoformat(cached_result["last_check"]) return datetime.now() - last_check > timedelta(days=ttl_days) - except Exception: + except (KeyError, ValueError, TypeError) as e: + logger.debug(f"Invalid cache timestamp: {e}") return True # Treat invalid cache as expired diff --git a/tests/test_user_tracking.py b/tests/test_user_tracking.py new file mode 100644 index 0000000..b5fcd81 --- /dev/null +++ b/tests/test_user_tracking.py @@ -0,0 +1,459 @@ +""" +Comprehensive tests for UserTracker functionality. + +Tests cover: +- User ID creation and persistence +- UUID validation +- Thread safety +- Error handling +- Atomic file operations +- Edge cases +""" + +import json +import os +import threading +import uuid +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock, patch + +import pytest + +from src.test_mcp.utils.user_tracking import UserTracker, get_user_tracker + + +@pytest.fixture +def temp_cache_dir(tmp_path): + """Create a temporary cache directory for testing.""" + cache_dir = tmp_path / "cache" + cache_dir.mkdir() + return cache_dir + + +@pytest.fixture +def mock_config_manager(temp_cache_dir): + """Mock ConfigManager to use temporary directory.""" + mock_manager = MagicMock() + mock_manager.paths.get_system_paths.return_value = {"cache_dir": temp_cache_dir} + return mock_manager + + +@pytest.fixture +def user_tracker(mock_config_manager): + """Create UserTracker instance with mocked config.""" + with patch( + "src.test_mcp.utils.user_tracking.ConfigManager", + return_value=mock_config_manager, + ): + tracker = UserTracker() + return tracker + + +class TestUserTrackerBasics: + """Test basic UserTracker functionality.""" + + 
def test_initialization(self, user_tracker, temp_cache_dir): + """Test UserTracker initializes correctly.""" + assert user_tracker.user_id_file == temp_cache_dir / "user_id.json" + assert hasattr(user_tracker, "_lock") + assert isinstance(user_tracker._lock, type(threading.Lock())) + + def test_create_new_user_id(self, user_tracker): + """Test creating a new user ID when none exists.""" + user_id = user_tracker.get_or_create_user_id() + + assert user_id is not None + assert isinstance(user_id, str) + assert user_tracker._is_valid_uuid(user_id) + + def test_user_id_persisted_to_file(self, user_tracker): + """Test user ID is saved to file correctly.""" + user_id = user_tracker.get_or_create_user_id() + + assert user_tracker.user_id_file.exists() + + data = json.loads(user_tracker.user_id_file.read_text()) + assert data["user_id"] == user_id + assert "created_at" in data + assert "version" in data + + def test_load_existing_user_id(self, user_tracker): + """Test loading existing user ID from file.""" + first_id = user_tracker.get_or_create_user_id() + second_id = user_tracker.get_or_create_user_id() + + assert first_id == second_id + + def test_user_id_consistent_across_instances(self, mock_config_manager): + """Test same user ID is returned across different tracker instances.""" + with patch( + "src.test_mcp.utils.user_tracking.ConfigManager", + return_value=mock_config_manager, + ): + tracker1 = UserTracker() + user_id1 = tracker1.get_or_create_user_id() + + tracker2 = UserTracker() + user_id2 = tracker2.get_or_create_user_id() + + assert user_id1 == user_id2 + + +class TestUUIDValidation: + """Test UUID validation functionality.""" + + def test_valid_uuid_v4(self, user_tracker): + """Test valid UUID v4 strings are accepted.""" + valid_uuid = str(uuid.uuid4()) + assert user_tracker._is_valid_uuid(valid_uuid) is True + + def test_valid_uuid_v1(self, user_tracker): + """Test valid UUID v1 strings are accepted.""" + valid_uuid = str(uuid.uuid1()) + assert 
user_tracker._is_valid_uuid(valid_uuid) is True + + def test_invalid_uuid_string(self, user_tracker): + """Test invalid UUID strings are rejected.""" + invalid_uuids = [ + "not-a-uuid", + "12345", + "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "123e4567-e89b-12d3-a456-42661417400", # Too short + "", + "123e4567-e89b-12d3-a456-426614174000-extra", # Too long + ] + + for invalid in invalid_uuids: + assert user_tracker._is_valid_uuid(invalid) is False + + def test_non_string_uuid(self, user_tracker): + """Test non-string values are rejected.""" + assert user_tracker._is_valid_uuid(None) is False + assert user_tracker._is_valid_uuid(123) is False + assert user_tracker._is_valid_uuid([]) is False + + def test_uuid_with_uppercase(self, user_tracker): + """Test UUID with uppercase letters is valid.""" + valid_uuid = "550E8400-E29B-41D4-A716-446655440000" + assert user_tracker._is_valid_uuid(valid_uuid) is True + + +class TestFileHandling: + """Test file operations and error handling.""" + + def test_creates_parent_directory(self, user_tracker): + """Test parent directory is created if it doesn't exist.""" + parent_dir = user_tracker.user_id_file.parent + if parent_dir.exists(): + import shutil + + shutil.rmtree(parent_dir) + + user_tracker.get_or_create_user_id() + + assert parent_dir.exists() + assert user_tracker.user_id_file.exists() + + def test_handles_corrupted_json(self, user_tracker): + """Test handles corrupted JSON file gracefully.""" + user_tracker.user_id_file.parent.mkdir(parents=True, exist_ok=True) + user_tracker.user_id_file.write_text("{ invalid json }") + + user_id = user_tracker.get_or_create_user_id() + + assert user_id is not None + assert user_tracker._is_valid_uuid(user_id) + + def test_handles_empty_file(self, user_tracker): + """Test handles empty file gracefully.""" + user_tracker.user_id_file.parent.mkdir(parents=True, exist_ok=True) + user_tracker.user_id_file.write_text("") + + user_id = user_tracker.get_or_create_user_id() + + assert user_id is 
not None + assert user_tracker._is_valid_uuid(user_id) + + def test_handles_missing_user_id_field(self, user_tracker): + """Test handles JSON without user_id field.""" + user_tracker.user_id_file.parent.mkdir(parents=True, exist_ok=True) + data = {"created_at": "2024-01-01", "version": "1.0.0"} + user_tracker.user_id_file.write_text(json.dumps(data)) + + user_id = user_tracker.get_or_create_user_id() + + assert user_id is not None + assert user_tracker._is_valid_uuid(user_id) + + def test_handles_invalid_uuid_in_file(self, user_tracker): + """Test handles invalid UUID in file.""" + user_tracker.user_id_file.parent.mkdir(parents=True, exist_ok=True) + data = { + "user_id": "not-a-valid-uuid", + "created_at": "2024-01-01", + "version": "1.0.0", + } + user_tracker.user_id_file.write_text(json.dumps(data)) + + user_id = user_tracker.get_or_create_user_id() + + assert user_id is not None + assert user_tracker._is_valid_uuid(user_id) + assert user_id != "not-a-valid-uuid" + + def test_handles_null_user_id(self, user_tracker): + """Test handles null user_id in file.""" + user_tracker.user_id_file.parent.mkdir(parents=True, exist_ok=True) + data = {"user_id": None, "created_at": "2024-01-01", "version": "1.0.0"} + user_tracker.user_id_file.write_text(json.dumps(data)) + + user_id = user_tracker.get_or_create_user_id() + + assert user_id is not None + assert user_tracker._is_valid_uuid(user_id) + + +class TestAtomicWrite: + """Test atomic write operations.""" + + def test_atomic_write_creates_file(self, user_tracker, temp_cache_dir): + """Test atomic write creates file correctly.""" + test_file = temp_cache_dir / "test.json" + content = '{"test": "data"}' + + user_tracker._atomic_write(test_file, content) + + assert test_file.exists() + assert test_file.read_text() == content + + def test_atomic_write_overwrites_existing(self, user_tracker, temp_cache_dir): + """Test atomic write overwrites existing file.""" + test_file = temp_cache_dir / "test.json" + 
test_file.write_text("old content") + + new_content = '{"test": "new data"}' + user_tracker._atomic_write(test_file, new_content) + + assert test_file.read_text() == new_content + + def test_atomic_write_cleanup_on_error(self, user_tracker, temp_cache_dir): + """Test temporary file is cleaned up on error.""" + test_file = temp_cache_dir / "test.json" + + with patch("os.write", side_effect=OSError("Disk full")): + with pytest.raises(OSError, match="Failed to save user ID"): + user_tracker._atomic_write(test_file, "content") + + temp_files = list(temp_cache_dir.glob(".test.json_*.tmp")) + assert len(temp_files) == 0 + + def test_atomic_write_with_unicode(self, user_tracker, temp_cache_dir): + """Test atomic write handles unicode content.""" + test_file = temp_cache_dir / "test.json" + content = '{"emoji": "🎉", "chinese": "你好"}' + + user_tracker._atomic_write(test_file, content) + + assert test_file.read_text(encoding="utf-8") == content + + def test_no_partial_writes(self, user_tracker, temp_cache_dir): + """Test file is not partially written on failure.""" + test_file = temp_cache_dir / "test.json" + original_content = "original" + test_file.write_text(original_content) + + call_count = 0 + + def failing_write(fd, data): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise OSError("Write failed") + + with patch("os.write", side_effect=failing_write): + with pytest.raises(OSError): + user_tracker._atomic_write(test_file, "new content") + + assert test_file.read_text() == original_content + + +class TestThreadSafety: + """Test thread safety of UserTracker operations.""" + + def test_concurrent_get_or_create(self, user_tracker): + """Test concurrent get_or_create_user_id calls return same ID.""" + user_ids = [] + + def get_user_id(): + user_id = user_tracker.get_or_create_user_id() + user_ids.append(user_id) + + with ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(get_user_id) for _ in range(50)] + for future in futures: + 
future.result() + + assert len(set(user_ids)) == 1 + + def test_single_instance_no_race_condition(self, user_tracker): + """Test single tracker instance prevents race conditions.""" + user_ids = set() + + def get_id_multiple_times(): + for _ in range(10): + user_id = user_tracker.get_or_create_user_id() + user_ids.add(user_id) + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(get_id_multiple_times) for _ in range(10)] + for future in futures: + future.result() + + assert len(user_ids) == 1 + + def test_concurrent_read_write(self, user_tracker): + """Test concurrent reads during write operations.""" + results = [] + errors = [] + + def read_or_write(i): + try: + user_id = user_tracker.get_or_create_user_id() + results.append(user_id) + except Exception as e: + errors.append(e) + + with ThreadPoolExecutor(max_workers=30) as executor: + futures = [executor.submit(read_or_write, i) for i in range(100)] + for future in futures: + future.result() + + assert len(errors) == 0 + assert len(set(results)) == 1 + assert all(user_tracker._is_valid_uuid(uid) for uid in results) + + +class TestSingletonPattern: + """Test singleton pattern for get_user_tracker().""" + + def test_returns_same_instance(self): + """Test get_user_tracker returns same instance.""" + tracker1 = get_user_tracker() + tracker2 = get_user_tracker() + + assert tracker1 is tracker2 + + def test_singleton_thread_safe(self): + """Test singleton pattern is thread-safe.""" + instances = [] + + def get_instance(): + tracker = get_user_tracker() + instances.append(id(tracker)) + + with ThreadPoolExecutor(max_workers=50) as executor: + futures = [executor.submit(get_instance) for _ in range(100)] + for future in futures: + future.result() + + assert len(set(instances)) == 1 + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_very_long_path(self, mock_config_manager, tmp_path): + """Test handles very long file paths.""" + long_dir = tmp_path / 
("x" * 100) / ("y" * 100) + mock_config_manager.paths.get_system_paths.return_value = { + "cache_dir": long_dir + } + + with patch( + "src.test_mcp.utils.user_tracking.ConfigManager", + return_value=mock_config_manager, + ): + tracker = UserTracker() + user_id = tracker.get_or_create_user_id() + + assert tracker._is_valid_uuid(user_id) + + def test_file_with_special_characters_in_path(self, mock_config_manager, tmp_path): + """Test handles paths with special characters.""" + special_dir = tmp_path / "test-dir_123" / "sub.dir" + mock_config_manager.paths.get_system_paths.return_value = { + "cache_dir": special_dir + } + + with patch( + "src.test_mcp.utils.user_tracking.ConfigManager", + return_value=mock_config_manager, + ): + tracker = UserTracker() + user_id = tracker.get_or_create_user_id() + + assert tracker._is_valid_uuid(user_id) + + def test_readonly_parent_directory(self, user_tracker): + """Test handles read-only parent directory.""" + user_tracker.user_id_file.parent.mkdir(parents=True, exist_ok=True) + + try: + os.chmod(user_tracker.user_id_file.parent, 0o444) + + with pytest.raises(OSError): + user_tracker.get_or_create_user_id() + finally: + os.chmod(user_tracker.user_id_file.parent, 0o755) # noqa: S103 + + def test_existing_file_with_extra_fields(self, user_tracker): + """Test handles file with extra fields gracefully.""" + user_tracker.user_id_file.parent.mkdir(parents=True, exist_ok=True) + valid_uuid = str(uuid.uuid4()) + data = { + "user_id": valid_uuid, + "created_at": "2024-01-01", + "version": "1.0.0", + "extra_field": "should be ignored", + "another_field": 12345, + } + user_tracker.user_id_file.write_text(json.dumps(data)) + + user_id = user_tracker.get_or_create_user_id() + + assert user_id == valid_uuid + + +class TestMetadata: + """Test metadata storage in user ID file.""" + + def test_stores_created_at_timestamp(self, user_tracker): + """Test created_at timestamp is stored.""" + from datetime import datetime + + before = datetime.now() + 
user_tracker.get_or_create_user_id() + after = datetime.now() + + data = json.loads(user_tracker.user_id_file.read_text()) + created_at = datetime.fromisoformat(data["created_at"]) + + assert before <= created_at <= after + + def test_stores_version(self, user_tracker): + """Test version is stored in file.""" + from src.test_mcp import __version__ + + user_tracker.get_or_create_user_id() + + data = json.loads(user_tracker.user_id_file.read_text()) + assert data["version"] == __version__ + + def test_json_formatting(self, user_tracker): + """Test JSON file is properly formatted with indentation.""" + user_tracker.get_or_create_user_id() + + content = user_tracker.user_id_file.read_text() + + assert content.count("\n") > 1 + assert " " in content From 3d6674573c9b462671d227e11a124f21fcf5da6b Mon Sep 17 00:00:00 2001 From: Szymon Bednorz Date: Sun, 2 Nov 2025 22:47:05 +0100 Subject: [PATCH 3/3] Add tests to utils --- tests/test_command_tracker.py | 292 ++++++++++++++++++++++++++ tests/test_performance_monitor.py | 325 +++++++++++++++++++++++++++++ tests/test_rate_limiter.py | 330 ++++++++++++++++++++++++++++++ 3 files changed, 947 insertions(+) create mode 100644 tests/test_command_tracker.py create mode 100644 tests/test_performance_monitor.py create mode 100644 tests/test_rate_limiter.py diff --git a/tests/test_command_tracker.py b/tests/test_command_tracker.py new file mode 100644 index 0000000..d6ed3be --- /dev/null +++ b/tests/test_command_tracker.py @@ -0,0 +1,292 @@ +"""Tests for command tracking functionality""" + +import json +import pytest +import tempfile +from pathlib import Path +from datetime import datetime +from unittest.mock import Mock, patch + +from src.test_mcp.utils.command_tracker import CommandTracker, get_command_tracker +from src.test_mcp.models.reporting import CommandHistoryEntry + + +class TestCommandTracker: + """Test CommandTracker class""" + + @pytest.fixture + def temp_cache_dir(self): + """Create a temporary cache directory for 
testing""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + @pytest.fixture + def tracker(self, temp_cache_dir): + """Create a CommandTracker with temporary cache directory""" + with patch("src.test_mcp.utils.command_tracker.ConfigManager") as mock_config: + mock_instance = Mock() + mock_instance.paths.get_system_paths.return_value = { + "cache_dir": temp_cache_dir + } + mock_config.return_value = mock_instance + return CommandTracker() + + def test_initialization(self, tracker): + """Test CommandTracker initialization""" + assert tracker.max_history == 25 + assert tracker.history_file.name == "command_history.json" + + def test_initialization_custom_max_history(self, temp_cache_dir): + """Test initialization with custom max_history""" + with patch("src.test_mcp.utils.command_tracker.ConfigManager") as mock_config: + mock_instance = Mock() + mock_instance.paths.get_system_paths.return_value = { + "cache_dir": temp_cache_dir + } + mock_config.return_value = mock_instance + tracker = CommandTracker(max_history=50) + assert tracker.max_history == 50 + + def test_sanitize_command_basic(self, tracker): + """Test basic command sanitization""" + command = "mcp-t run suite.json" + sanitized = tracker._sanitize_command(command) + assert sanitized == "mcp-t run suite.json" + + def test_sanitize_command_full_path(self, tracker): + """Test sanitization of full path to mcp-t""" + command = "/usr/local/bin/mcp-t run suite.json" + sanitized = tracker._sanitize_command(command) + assert sanitized == "mcp-t run suite.json" + + def test_sanitize_command_windows_path(self, tracker): + """Test sanitization of Windows path to mcp-t""" + command = "C:\\Users\\test\\bin\\mcp-t run suite.json" + sanitized = tracker._sanitize_command(command) + assert sanitized == "mcp-t run suite.json" + + def test_sanitize_command_user_home(self, tracker): + """Test sanitization of user home directory""" + command = "mcp-t run /Users/john/config/suite.json" + sanitized = 
tracker._sanitize_command(command) + assert "john" not in sanitized + assert "~" in sanitized + + def test_sanitize_command_empty(self, tracker): + """Test sanitization of empty command""" + assert tracker._sanitize_command("") == "" + + def test_sanitize_command_complex(self, tracker): + """Test sanitization of complex command with multiple arguments""" + command = "/usr/local/bin/mcp-t run --parallel 4 /Users/test/suite.json" + sanitized = tracker._sanitize_command(command) + assert sanitized.startswith("mcp-t") + assert "--parallel" in sanitized + assert "4" in sanitized + assert "/usr/local/bin" not in sanitized + + def test_record_command_basic(self, tracker): + """Test recording a basic command""" + tracker.record_command("mcp-t run suite.json") + history = tracker.get_recent_history() + + assert len(history) == 1 + assert history[0].command == "mcp-t run suite.json" + assert history[0].exit_code is None + assert history[0].duration_ms is None + assert isinstance(history[0].timestamp, datetime) + + def test_record_command_with_exit_code(self, tracker): + """Test recording command with exit code""" + tracker.record_command("mcp-t run suite.json", exit_code=0) + history = tracker.get_recent_history() + + assert len(history) == 1 + assert history[0].exit_code == 0 + + def test_record_command_with_duration(self, tracker): + """Test recording command with duration""" + tracker.record_command("mcp-t run suite.json", duration_ms=1234.5) + history = tracker.get_recent_history() + + assert len(history) == 1 + assert history[0].duration_ms == 1234.5 + + def test_record_command_with_all_fields(self, tracker): + """Test recording command with all fields""" + tracker.record_command( + "mcp-t run suite.json", exit_code=0, duration_ms=1234.5 + ) + history = tracker.get_recent_history() + + assert len(history) == 1 + assert history[0].command == "mcp-t run suite.json" + assert history[0].exit_code == 0 + assert history[0].duration_ms == 1234.5 + + def 
test_record_multiple_commands(self, tracker): + """Test recording multiple commands""" + tracker.record_command("mcp-t run suite1.json") + tracker.record_command("mcp-t run suite2.json") + tracker.record_command("mcp-t run suite3.json") + + history = tracker.get_recent_history() + assert len(history) == 3 + assert history[0].command == "mcp-t run suite1.json" + assert history[1].command == "mcp-t run suite2.json" + assert history[2].command == "mcp-t run suite3.json" + + def test_get_recent_history_limit(self, tracker): + """Test getting recent history with limit""" + for i in range(15): + tracker.record_command(f"mcp-t run suite{i}.json") + + history = tracker.get_recent_history(limit=5) + assert len(history) == 5 + # Should get the most recent 5 + assert history[0].command == "mcp-t run suite10.json" + assert history[4].command == "mcp-t run suite14.json" + + def test_max_history_enforcement(self, tracker): + """Test that history is limited to max_history entries""" + tracker.max_history = 10 + + # Record more commands than max_history + for i in range(20): + tracker.record_command(f"mcp-t run suite{i}.json") + + history = tracker.get_recent_history(limit=100) + assert len(history) == 10 + # Should only keep the last 10 + assert history[0].command == "mcp-t run suite10.json" + assert history[9].command == "mcp-t run suite19.json" + + def test_load_history_nonexistent_file(self, tracker): + """Test loading history when file doesn't exist""" + history = tracker._load_history() + assert history == [] + + def test_save_and_load_history(self, tracker): + """Test saving and loading history from file""" + tracker.record_command("mcp-t run suite.json", exit_code=0) + + # Create a new tracker instance with same cache dir + with patch("src.test_mcp.utils.command_tracker.ConfigManager") as mock_config: + mock_instance = Mock() + mock_instance.paths.get_system_paths.return_value = { + "cache_dir": tracker.history_file.parent + } + mock_config.return_value = mock_instance + 
new_tracker = CommandTracker() + + # Should load existing history + history = new_tracker.get_recent_history() + assert len(history) == 1 + assert history[0].command == "mcp-t run suite.json" + assert history[0].exit_code == 0 + + def test_load_history_corrupted_file(self, tracker): + """Test loading history from corrupted file returns empty list""" + # Create corrupted JSON file + tracker.history_file.parent.mkdir(parents=True, exist_ok=True) + tracker.history_file.write_text("corrupted{json") + + history = tracker._load_history() + assert history == [] + + def test_save_history_creates_directory(self, tracker): + """Test that saving history creates parent directory if needed""" + # Ensure directory doesn't exist + if tracker.history_file.parent.exists(): + tracker.history_file.unlink(missing_ok=True) + tracker.history_file.parent.rmdir() + + tracker.record_command("mcp-t run suite.json") + + assert tracker.history_file.exists() + assert tracker.history_file.parent.exists() + + def test_get_recent_history_empty(self, tracker): + """Test getting history when none exists""" + history = tracker.get_recent_history() + assert history == [] + + def test_command_sanitization_preserves_arguments(self, tracker): + """Test that sanitization preserves command arguments""" + command = "mcp-t run --parallel 4 --timeout 30 suite.json" + tracker.record_command(command) + history = tracker.get_recent_history() + + assert "--parallel" in history[0].command + assert "4" in history[0].command + assert "--timeout" in history[0].command + assert "30" in history[0].command + + def test_history_persistence(self, tracker): + """Test that history persists across tracker instances""" + # Record in first tracker + tracker.record_command("mcp-t run suite1.json", exit_code=0) + tracker.record_command("mcp-t run suite2.json", exit_code=1) + + # Create second tracker with same cache dir + with patch("src.test_mcp.utils.command_tracker.ConfigManager") as mock_config: + mock_instance = Mock() + 
mock_instance.paths.get_system_paths.return_value = { + "cache_dir": tracker.history_file.parent + } + mock_config.return_value = mock_instance + tracker2 = CommandTracker() + + # Add more commands + tracker2.record_command("mcp-t run suite3.json", exit_code=0) + + # Should have all commands + history = tracker2.get_recent_history() + assert len(history) == 3 + + def test_history_file_format(self, tracker): + """Test that history file has correct JSON format""" + tracker.record_command("mcp-t run suite.json", exit_code=0, duration_ms=100.5) + + # Read raw file content + content = json.loads(tracker.history_file.read_text()) + + assert isinstance(content, list) + assert len(content) == 1 + assert "command" in content[0] + assert "timestamp" in content[0] + assert "exit_code" in content[0] + assert "duration_ms" in content[0] + + +class TestGetCommandTracker: + """Test get_command_tracker singleton function""" + + def test_returns_same_instance(self): + """Test that get_command_tracker returns the same instance""" + tracker1 = get_command_tracker() + tracker2 = get_command_tracker() + assert tracker1 is tracker2 + + def test_thread_safe_initialization(self): + """Test that singleton initialization is thread-safe""" + import threading + + # Reset global tracker + import src.test_mcp.utils.command_tracker as ct_module + ct_module._command_tracker = None + + trackers = [] + + def get_tracker(): + trackers.append(get_command_tracker()) + + threads = [threading.Thread(target=get_tracker) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should get the same instance + assert len(set(id(t) for t in trackers)) == 1 + diff --git a/tests/test_performance_monitor.py b/tests/test_performance_monitor.py new file mode 100644 index 0000000..5f8dbc1 --- /dev/null +++ b/tests/test_performance_monitor.py @@ -0,0 +1,325 @@ +"""Tests for performance monitoring utilities""" + +import pytest +import time + +from 
src.test_mcp.utils.performance_monitor import ( + TestExecutionMetrics, + SuiteExecutionMetrics, +) + + +class TestTestExecutionMetrics: + """Test TestExecutionMetrics dataclass""" + + def test_basic_creation(self): + """Test creating a basic metrics object""" + metrics = TestExecutionMetrics( + test_id="test-1", + start_time=100.0, + ) + assert metrics.test_id == "test-1" + assert metrics.start_time == 100.0 + assert metrics.end_time is None + assert metrics.turns_completed == 0 + assert metrics.api_calls_made == 0 + assert metrics.success is False + assert metrics.error_message is None + + def test_creation_with_all_fields(self): + """Test creating metrics with all fields populated""" + metrics = TestExecutionMetrics( + test_id="test-2", + start_time=100.0, + end_time=150.0, + turns_completed=5, + api_calls_made=10, + success=True, + error_message="Error occurred", + ) + assert metrics.test_id == "test-2" + assert metrics.start_time == 100.0 + assert metrics.end_time == 150.0 + assert metrics.turns_completed == 5 + assert metrics.api_calls_made == 10 + assert metrics.success is True + assert metrics.error_message == "Error occurred" + + def test_duration_calculation(self): + """Test duration calculation from start and end times""" + metrics = TestExecutionMetrics( + test_id="test-3", + start_time=100.0, + end_time=150.0, + ) + assert metrics.duration == 50.0 + + def test_duration_none_when_not_ended(self): + """Test that duration is None when end_time is not set""" + metrics = TestExecutionMetrics( + test_id="test-4", + start_time=100.0, + ) + assert metrics.duration is None + + def test_duration_zero_for_instant_completion(self): + """Test duration can be zero for instant completion""" + metrics = TestExecutionMetrics( + test_id="test-5", + start_time=100.0, + end_time=100.0, + ) + assert metrics.duration == 0.0 + + def test_invalid_end_time_raises_error(self): + """Test that end_time before start_time raises ValueError""" + with pytest.raises(ValueError, 
match="end_time.*cannot be before start_time"): + TestExecutionMetrics( + test_id="test-6", + start_time=100.0, + end_time=50.0, + ) + + def test_negative_api_calls_raises_error(self): + """Test that negative api_calls_made raises ValueError""" + with pytest.raises(ValueError, match="api_calls_made cannot be negative"): + TestExecutionMetrics( + test_id="test-7", + start_time=100.0, + api_calls_made=-1, + ) + + def test_negative_turns_raises_error(self): + """Test that negative turns_completed raises ValueError""" + with pytest.raises(ValueError, match="turns_completed cannot be negative"): + TestExecutionMetrics( + test_id="test-8", + start_time=100.0, + turns_completed=-5, + ) + + def test_realistic_timing(self): + """Test with realistic timing values""" + start = time.time() + time.sleep(0.01) # Small delay + end = time.time() + + metrics = TestExecutionMetrics( + test_id="test-9", + start_time=start, + end_time=end, + ) + assert metrics.duration is not None + assert metrics.duration > 0.01 + assert metrics.duration < 1.0 # Should complete quickly + + +class TestSuiteExecutionMetrics: + """Test SuiteExecutionMetrics dataclass""" + + def test_basic_creation(self): + """Test creating a basic suite metrics object""" + suite = SuiteExecutionMetrics( + suite_id="suite-1", + start_time=100.0, + ) + assert suite.suite_id == "suite-1" + assert suite.start_time == 100.0 + assert suite.test_metrics == [] + assert suite.parallelism_used == 1 + assert suite.total_duration is None + + def test_creation_with_tests(self): + """Test creating suite with test metrics""" + test1 = TestExecutionMetrics( + test_id="test-1", start_time=100.0, end_time=110.0, success=True + ) + test2 = TestExecutionMetrics( + test_id="test-2", start_time=105.0, end_time=115.0, success=True + ) + + suite = SuiteExecutionMetrics( + suite_id="suite-2", + start_time=100.0, + test_metrics=[test1, test2], + parallelism_used=2, + total_duration=15.0, + ) + assert len(suite.test_metrics) == 2 + assert 
suite.parallelism_used == 2 + assert suite.total_duration == 15.0 + + def test_summary_stats_empty_suite(self): + """Test summary stats for suite with no completed tests""" + suite = SuiteExecutionMetrics(suite_id="suite-3", start_time=100.0) + stats = suite.get_summary_stats() + assert stats == {"status": "no_completed_tests"} + + def test_summary_stats_with_incomplete_tests(self): + """Test summary stats when tests haven't ended""" + test1 = TestExecutionMetrics(test_id="test-1", start_time=100.0) + + suite = SuiteExecutionMetrics( + suite_id="suite-4", start_time=100.0, test_metrics=[test1] + ) + stats = suite.get_summary_stats() + assert stats == {"status": "no_completed_tests"} + + def test_summary_stats_single_successful_test(self): + """Test summary stats with one successful test""" + test1 = TestExecutionMetrics( + test_id="test-1", + start_time=100.0, + end_time=120.0, + api_calls_made=5, + success=True, + ) + + suite = SuiteExecutionMetrics( + suite_id="suite-5", + start_time=100.0, + test_metrics=[test1], + total_duration=20.0, + ) + stats = suite.get_summary_stats() + + assert stats["total_tests"] == 1 + assert stats["completed_tests"] == 1 + assert stats["success_rate"] == 1.0 + assert stats["average_duration"] == 20.0 + assert stats["median_duration"] == 20.0 + assert stats["total_api_calls"] == 5 + assert stats["parallelism_efficiency"] == 1 / 20.0 + + def test_summary_stats_multiple_tests(self): + """Test summary stats with multiple tests""" + test1 = TestExecutionMetrics( + test_id="test-1", + start_time=100.0, + end_time=110.0, + api_calls_made=3, + success=True, + ) + test2 = TestExecutionMetrics( + test_id="test-2", + start_time=105.0, + end_time=125.0, + api_calls_made=7, + success=False, + ) + test3 = TestExecutionMetrics( + test_id="test-3", + start_time=110.0, + end_time=140.0, + api_calls_made=5, + success=True, + ) + + suite = SuiteExecutionMetrics( + suite_id="suite-6", + start_time=100.0, + test_metrics=[test1, test2, test3], + 
parallelism_used=2, + total_duration=40.0, + ) + stats = suite.get_summary_stats() + + assert stats["total_tests"] == 3 + assert stats["completed_tests"] == 3 + assert stats["success_rate"] == 2.0 / 3.0 # 2 out of 3 successful + assert stats["average_duration"] == (10.0 + 20.0 + 30.0) / 3.0 + assert stats["median_duration"] == 20.0 + assert stats["total_api_calls"] == 15 + assert stats["parallelism_efficiency"] == 3 / 40.0 + + def test_summary_stats_mixed_completion(self): + """Test summary stats with mix of completed and incomplete tests""" + test1 = TestExecutionMetrics( + test_id="test-1", + start_time=100.0, + end_time=110.0, + success=True, + ) + test2 = TestExecutionMetrics( + test_id="test-2", + start_time=105.0, + ) # Not completed + + suite = SuiteExecutionMetrics( + suite_id="suite-7", + start_time=100.0, + test_metrics=[test1, test2], + ) + stats = suite.get_summary_stats() + + assert stats["total_tests"] == 2 + assert stats["completed_tests"] == 1 + assert stats["success_rate"] == 1.0 + + def test_summary_stats_zero_total_duration(self): + """Test parallelism efficiency when total_duration is zero""" + test1 = TestExecutionMetrics( + test_id="test-1", start_time=100.0, end_time=110.0, success=True + ) + + suite = SuiteExecutionMetrics( + suite_id="suite-8", + start_time=100.0, + test_metrics=[test1], + total_duration=0.0, + ) + stats = suite.get_summary_stats() + + assert stats["parallelism_efficiency"] is None + + def test_summary_stats_none_total_duration(self): + """Test parallelism efficiency when total_duration is None""" + test1 = TestExecutionMetrics( + test_id="test-1", start_time=100.0, end_time=110.0, success=True + ) + + suite = SuiteExecutionMetrics( + suite_id="suite-9", start_time=100.0, test_metrics=[test1] + ) + stats = suite.get_summary_stats() + + assert stats["parallelism_efficiency"] is None + + def test_summary_stats_all_failed_tests(self): + """Test summary stats when all tests failed""" + test1 = TestExecutionMetrics( + 
test_id="test-1", start_time=100.0, end_time=110.0, success=False + ) + test2 = TestExecutionMetrics( + test_id="test-2", start_time=105.0, end_time=115.0, success=False + ) + + suite = SuiteExecutionMetrics( + suite_id="suite-10", + start_time=100.0, + test_metrics=[test1, test2], + ) + stats = suite.get_summary_stats() + + assert stats["success_rate"] == 0.0 + + def test_add_test_metrics_dynamically(self): + """Test adding test metrics to suite after creation""" + suite = SuiteExecutionMetrics(suite_id="suite-11", start_time=100.0) + + # Add tests dynamically + suite.test_metrics.append( + TestExecutionMetrics( + test_id="test-1", start_time=100.0, end_time=110.0, success=True + ) + ) + suite.test_metrics.append( + TestExecutionMetrics( + test_id="test-2", start_time=105.0, end_time=115.0, success=True + ) + ) + + stats = suite.get_summary_stats() + assert stats["total_tests"] == 2 + assert stats["completed_tests"] == 2 + diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py new file mode 100644 index 0000000..3605473 --- /dev/null +++ b/tests/test_rate_limiter.py @@ -0,0 +1,330 @@ +"""Tests for rate limiter functionality""" + +import asyncio +import pytest + +from src.test_mcp.utils.rate_limiter import RateLimiter, RequestRecord + + +class TestRequestRecord: + """Test RequestRecord namedtuple""" + + def test_create_request_record(self): + """Test creating a request record""" + record = RequestRecord( + timestamp=123.456, + tokens_used=1000, + correlation_id="test-id-123", + ) + assert record.timestamp == 123.456 + assert record.tokens_used == 1000 + assert record.correlation_id == "test-id-123" + + def test_request_record_immutability(self): + """Test that RequestRecord is immutable""" + record = RequestRecord( + timestamp=123.456, + tokens_used=1000, + correlation_id="test-id-123", + ) + with pytest.raises(AttributeError): + record.timestamp = 999.0 + + +class TestRateLimiter: + """Test RateLimiter class""" + + def test_initialization(self): + 
"""Test RateLimiter initialization""" + limiter = RateLimiter() + assert "anthropic" in limiter.providers + assert "openai" in limiter.providers + assert "gemini" in limiter.providers + assert limiter.providers["anthropic"]["requests_per_minute"] == 5000 + assert limiter.providers["anthropic"]["tokens_per_minute"] == 100000 + + @pytest.mark.asyncio + async def test_acquire_request_slot_basic(self): + """Test acquiring a basic request slot""" + limiter = RateLimiter() + correlation_id = await limiter.acquire_request_slot("anthropic") + + assert correlation_id is not None + assert correlation_id.startswith("anthropic_") + assert "anthropic" in limiter.request_history + assert correlation_id in limiter.request_history["anthropic"] + + @pytest.mark.asyncio + async def test_acquire_request_slot_unknown_provider(self): + """Test acquiring slot for unknown provider uses defaults""" + limiter = RateLimiter() + correlation_id = await limiter.acquire_request_slot("unknown_provider") + + assert correlation_id is not None + assert correlation_id.startswith("unknown_provider_") + + @pytest.mark.asyncio + async def test_acquire_request_slot_invalid_limits(self): + """Test that invalid rate limits raise ValueError""" + limiter = RateLimiter() + limiter.providers["test_provider"] = { + "requests_per_minute": 0, + "tokens_per_minute": 100, + } + + with pytest.raises(ValueError, match="Invalid rate limits"): + await limiter.acquire_request_slot("test_provider") + + @pytest.mark.asyncio + async def test_record_token_usage_basic(self): + """Test recording token usage""" + limiter = RateLimiter() + correlation_id = await limiter.acquire_request_slot("anthropic") + + await limiter.record_token_usage(correlation_id, 500) + + assert limiter.token_usage["anthropic"] == 500 + record = limiter.request_history["anthropic"][correlation_id] + assert record.tokens_used == 500 + + @pytest.mark.asyncio + async def test_record_token_usage_negative_raises_error(self): + """Test that negative token 
usage raises ValueError""" + limiter = RateLimiter() + correlation_id = await limiter.acquire_request_slot("anthropic") + + with pytest.raises(ValueError, match="tokens_used must be non-negative"): + await limiter.record_token_usage(correlation_id, -100) + + @pytest.mark.asyncio + async def test_record_token_usage_unknown_correlation_id(self): + """Test recording tokens for unknown correlation ID logs warning""" + limiter = RateLimiter() + # Should not raise, just log warning + await limiter.record_token_usage("unknown-id-123", 100) + + @pytest.mark.asyncio + async def test_multiple_requests_same_provider(self): + """Test acquiring multiple request slots for same provider""" + limiter = RateLimiter() + + id1 = await limiter.acquire_request_slot("anthropic") + id2 = await limiter.acquire_request_slot("anthropic") + id3 = await limiter.acquire_request_slot("anthropic") + + assert len({id1, id2, id3}) == 3 # chained != skips id1 vs id3 + assert len(limiter.request_history["anthropic"]) == 3 + + @pytest.mark.asyncio + async def test_multiple_providers(self): + """Test acquiring slots from different providers""" + limiter = RateLimiter() + + id_anthropic = await limiter.acquire_request_slot("anthropic") + id_openai = await limiter.acquire_request_slot("openai") + id_gemini = await limiter.acquire_request_slot("gemini") + + assert id_anthropic.startswith("anthropic_") + assert id_openai.startswith("openai_") + assert id_gemini.startswith("gemini_") + + @pytest.mark.asyncio + async def test_cleanup_old_requests(self): + """Test that old requests are cleaned up""" + limiter = RateLimiter() + + # Override window for faster testing + original_window = limiter.RATE_LIMIT_WINDOW_SECONDS + limiter.RATE_LIMIT_WINDOW_SECONDS = 0.1 + + try: + # Create a request + correlation_id = await limiter.acquire_request_slot("anthropic") + await limiter.record_token_usage(correlation_id, 500) + + # Verify it exists + assert limiter.token_usage["anthropic"] == 500 + assert len(limiter.request_history["anthropic"]) == 1 + + # Wait
for cleanup window to pass + await asyncio.sleep(0.2) + + # Acquire another request (should trigger cleanup) + await limiter.acquire_request_slot("anthropic") + + # Old tokens should be cleaned + assert limiter.token_usage["anthropic"] < 500 + finally: + limiter.RATE_LIMIT_WINDOW_SECONDS = original_window + + @pytest.mark.asyncio + async def test_cleanup_pending_request(self): + """Test cleaning up a pending request""" + limiter = RateLimiter() + + correlation_id = await limiter.acquire_request_slot("anthropic") + assert correlation_id in limiter._pending_requests + + await limiter.cleanup_pending_request(correlation_id) + + assert correlation_id not in limiter._pending_requests + assert correlation_id not in limiter.request_history["anthropic"] + + @pytest.mark.asyncio + async def test_cleanup_pending_request_unknown_id(self): + """Test cleaning up unknown pending request does nothing""" + limiter = RateLimiter() + # Should not raise + await limiter.cleanup_pending_request("unknown-id-456") + + @pytest.mark.asyncio + async def test_token_accumulation(self): + """Test that token usage accumulates correctly""" + limiter = RateLimiter() + + id1 = await limiter.acquire_request_slot("anthropic") + await limiter.record_token_usage(id1, 100) + + id2 = await limiter.acquire_request_slot("anthropic") + await limiter.record_token_usage(id2, 200) + + id3 = await limiter.acquire_request_slot("anthropic") + await limiter.record_token_usage(id3, 300) + + assert limiter.token_usage["anthropic"] == 600 + + @pytest.mark.asyncio + async def test_concurrent_requests(self): + """Test handling concurrent requests""" + limiter = RateLimiter() + + async def make_request(provider): + correlation_id = await limiter.acquire_request_slot(provider) + await limiter.record_token_usage(correlation_id, 100) + return correlation_id + + # Make 10 concurrent requests + tasks = [make_request("anthropic") for _ in range(10)] + results = await asyncio.gather(*tasks) + + # All should complete 
successfully + assert len(results) == 10 + assert len(set(results)) == 10 # All unique IDs + assert limiter.token_usage["anthropic"] == 1000 + + @pytest.mark.asyncio + async def test_rate_limiting_enforcement(self): + """Test that rate limiting actually enforces limits""" + limiter = RateLimiter() + + # Set very restrictive limits for testing + limiter.providers["test_provider"] = { + "requests_per_minute": 2, + "tokens_per_minute": 1000, + } + limiter.RATE_LIMIT_WINDOW_SECONDS = 0.1 + limiter.CLEANUP_CHECK_INTERVAL = 0.01 + + # Acquire max allowed requests + id1 = await limiter.acquire_request_slot("test_provider") + id2 = await limiter.acquire_request_slot("test_provider") + + # Verify we hit the limit + assert len(limiter.request_history["test_provider"]) == 2 + + # Complete first two requests to free up slots + await limiter.record_token_usage(id1, 100) + await limiter.record_token_usage(id2, 100) + + # Wait for cleanup window to pass + await asyncio.sleep(0.12) + + # Third request should now succeed after cleanup + id3 = await limiter.acquire_request_slot("test_provider") + + assert id3 is not None + assert id3.startswith("test_provider_") + + @pytest.mark.asyncio + async def test_token_limit_enforcement(self): + """Test that token limits are enforced""" + limiter = RateLimiter() + + # Set very low token limit + limiter.providers["test_provider"] = { + "requests_per_minute": 100, + "tokens_per_minute": 100, + } + limiter.RATE_LIMIT_WINDOW_SECONDS = 0.1 + limiter.CLEANUP_CHECK_INTERVAL = 0.01 + + # Fill up to 80% (threshold) + id1 = await limiter.acquire_request_slot("test_provider") + await limiter.record_token_usage(id1, 80) + + # Verify we're at the threshold + assert limiter.token_usage["test_provider"] == 80 + + # Wait for tokens to clear (window + small buffer) + await asyncio.sleep(0.12) + + # Second request should now succeed after token cleanup + id2 = await limiter.acquire_request_slot("test_provider") + + assert id2 is not None + assert 
id2.startswith("test_provider_") + + @pytest.mark.asyncio + async def test_provider_isolation(self): + """Test that different providers are isolated""" + limiter = RateLimiter() + + # Use up one provider + id1 = await limiter.acquire_request_slot("anthropic") + await limiter.record_token_usage(id1, 50000) + + # Other provider should be unaffected + id2 = await limiter.acquire_request_slot("openai") + await limiter.record_token_usage(id2, 100) + + assert limiter.token_usage["anthropic"] == 50000 + assert limiter.token_usage["openai"] == 100 + + @pytest.mark.asyncio + async def test_correlation_id_format(self): + """Test that correlation IDs have expected format""" + limiter = RateLimiter() + correlation_id = await limiter.acquire_request_slot("anthropic") + + parts = correlation_id.split("_") + assert len(parts) >= 3 + assert parts[0] == "anthropic" + assert parts[1].isdigit() # timestamp + assert len(parts[2]) == 8 # UUID hex + + @pytest.mark.asyncio + async def test_request_timeout_cleanup(self): + """Test that timed-out requests are cleaned up""" + limiter = RateLimiter() + + # Override timeout for faster testing + original_timeout = limiter.REQUEST_TIMEOUT_SECONDS + limiter.REQUEST_TIMEOUT_SECONDS = 0.1 + limiter.RATE_LIMIT_WINDOW_SECONDS = 0.2 + + try: + # Create a request but don't complete it + correlation_id = await limiter.acquire_request_slot("anthropic") + assert correlation_id in limiter._pending_requests + + # Wait for timeout + await asyncio.sleep(0.15) + + # Trigger cleanup by acquiring another request + await limiter.acquire_request_slot("anthropic") + + # The timed-out request should be cleaned up + # (It should be removed from pending, but exact cleanup depends on implementation) + finally: + limiter.REQUEST_TIMEOUT_SECONDS = original_timeout + limiter.RATE_LIMIT_WINDOW_SECONDS = 60