diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
new file mode 100644
index 00000000..7abc5ac6
--- /dev/null
+++ b/.github/workflows/benchmark.yml
@@ -0,0 +1,79 @@
+name: EP CLI Benchmark
+
+# Run daily at 6 AM UTC (10 PM PST / 11 PM PDT)
+on:
+  schedule:
+    - cron: '0 6 * * *'
+  workflow_dispatch:  # Allow manual triggering
+
+jobs:
+  benchmark:
+    name: CLI Startup & Import Benchmark
+    runs-on: ubuntu-latest
+
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+        with:
+          enable-cache: true
+
+      - name: Install the project
+        run: uv sync --locked --all-extras --dev
+
+      - name: Run Benchmark Tests
+        id: benchmark
+        env:
+          RUN_BENCHMARK_TESTS: "1"
+          PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
+        run: |
+          echo "Running EP CLI benchmark tests..."
+          set +e
+
+          uv run pytest tests/test_cli_startup_benchmark.py -v --tb=short --durations=10 2>&1 | tee benchmark_output.log
+
+          # After a pipeline, $? is tee's exit code; PIPESTATUS[0] holds pytest's.
+          BENCHMARK_EXIT_CODE=${PIPESTATUS[0]}
+          echo "benchmark_exit_code=$BENCHMARK_EXIT_CODE" >> $GITHUB_OUTPUT
+
+          if [ $BENCHMARK_EXIT_CODE -eq 0 ]; then
+            echo "āœ… Benchmark tests passed"
+          else
+            echo "āŒ Benchmark tests failed"
+          fi
+
+          exit $BENCHMARK_EXIT_CODE
+
+      - name: Upload benchmark results
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: benchmark-results-py${{ matrix.python-version }}-${{ github.run_number }}
+          path: benchmark_output.log
+          retention-days: 30
+
+      - name: Send failure notification to Slack
+        uses: act10ns/slack@v1
+        if: failure()
+        with:
+          status: failure
+          message: |
+            🐌 EP CLI Benchmark Failed (Python ${{ matrix.python-version }})
+            CLI startup or import times exceeded thresholds.
+            Check for new imports or heavy dependencies.
+            Job: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
+        env:
+          SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
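Note: not part of this patch, but a useful companion when the job above flags a regression. CPython's `-X importtime` flag (3.7+) prints a per-import cost table to stderr, which pinpoints which dependency got heavy. A minimal triage sketch — the parsing follows CPython's documented output format, but the script itself is illustrative, not shipped with this PR:

# find_heavy_imports.py -- illustrative local helper, not part of this PR
import subprocess
import sys


def heaviest_imports(module: str, top: int = 10) -> list[tuple[int, str]]:
    """Run `python -X importtime -c 'import <module>'` and rank modules
    by cumulative import time in microseconds."""
    proc = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", f"import {module}"],
        capture_output=True,
        text=True,
    )
    rows = []
    for line in proc.stderr.splitlines():
        # Data lines look like: "import time:      123 |      456 | package.name"
        # The header line contains the word "cumulative", so skip it.
        if not line.startswith("import time:") or "cumulative" in line:
            continue
        _, timings = line.split(":", 1)
        _self_us, cumulative, name = (part.strip() for part in timings.split("|"))
        rows.append((int(cumulative), name))
    return sorted(rows, reverse=True)[:top]


if __name__ == "__main__":
    for cumulative_us, name in heaviest_imports("eval_protocol"):
        print(f"{cumulative_us / 1000:8.1f} ms  {name}")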
diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py
index 9c26bdb1..da1760de 100644
--- a/eval_protocol/__init__.py
+++ b/eval_protocol/__init__.py
@@ -8,102 +8,159 @@
 tool-augmented models using self-contained task bundles.
 """
 
+import importlib
+import sys
 import warnings
+from typing import TYPE_CHECKING
 
-from .auth import get_fireworks_account_id, get_fireworks_api_key
-from .common_utils import load_jsonl
-from .config import RewardKitConfig, get_config, load_config
-from .mcp_env import (
-    AnthropicPolicy,
-    FireworksPolicy,
-    LiteLLMPolicy,
-    OpenAIPolicy,
-    make,
-    rollout,
-    test_mcp,
-)
-from .data_loader import DynamicDataLoader, InlineDataLoader
-from . import mcp, rewards
-from .models import EvaluateResult, Message, MetricResult, EvaluationRow, InputMetadata, Status
-from .playback_policy import PlaybackPolicyBase
-from .resources import create_llm_resource
-from .reward_function import RewardFunction
-from .typed_interface import reward_function
-from .quickstart.aha_judge import aha_judge
-from .utils.evaluation_row_utils import (
-    multi_turn_assistant_to_ground_truth,
-    assistant_to_ground_truth,
-    filter_longest_conversation,
-)
-from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
-from .pytest.parameterize import DefaultParameterIdGenerator
-from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
-from .log_utils.rollout_id_filter import RolloutIdFilter
-from .log_utils.util import setup_rollout_logging_for_elasticsearch_handler
-from .log_utils.fireworks_tracing_http_handler import FireworksTracingHttpHandler
-from .log_utils.elasticsearch_client import ElasticsearchConfig
-
-
-from .types.remote_rollout_processor import (
-    InitRequest,
-    RolloutMetadata,
-    StatusResponse,
-    create_langfuse_config_tags,
-    DataLoaderConfig,
-)
-
-try:
-    from .adapters import OpenAIResponsesAdapter
-except ImportError:
-    OpenAIResponsesAdapter = None
-
-try:
-    from .adapters import LangfuseAdapter, create_langfuse_adapter
-except ImportError:
-    LangfuseAdapter = None
-
-try:
-    from .adapters import BraintrustAdapter, create_braintrust_adapter
-except ImportError:
-    BraintrustAdapter = None
-
-try:
-    from .adapters import LangSmithAdapter
-except ImportError:
-    LangSmithAdapter = None
-
-
-try:
-    from .adapters import WeaveAdapter
-except ImportError:
-    WeaveAdapter = None
-
-try:
-    from .proxy import create_app, AuthProvider, AccountInfo  # pyright: ignore[reportAssignmentType]
-except ImportError:
-
-    def create_app(*args, **kwargs):
-        raise ImportError(
-            "Proxy functionality requires additional dependencies. "
-            "Please install with: pip install eval-protocol[proxy]"
-        )
-
-    class AuthProvider:
-        def __init__(self, *args, **kwargs):
-            raise ImportError(
-                "Proxy functionality requires additional dependencies. "
-                "Please install with: pip install eval-protocol[proxy]"
-            )
-
-    class AccountInfo:
-        def __init__(self, *args, **kwargs):
-            raise ImportError(
-                "Proxy functionality requires additional dependencies. "
-                "Please install with: pip install eval-protocol[proxy]"
-            )
+warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
+
+# Eager imports for symbols that conflict with module names - ONLY when pytest is running.
+# The reward_function.py module exports the RewardFunction class, and we also export the
+# reward_function decorator from typed_interface. When pytest's AssertionRewritingHook
+# imports eval_protocol.reward_function as a module, Python would set
+# eval_protocol.reward_function to the module, shadowing our function export.
+#
+# We detect pytest by checking whether _pytest or pytest is already loaded. This avoids
+# the ~500ms import overhead for non-test scenarios like the CLI.
+_running_under_pytest = "_pytest" in sys.modules or "pytest" in sys.modules
+if _running_under_pytest:
+    from .reward_function import RewardFunction  # noqa: E402
+    from .typed_interface import reward_function  # noqa: E402
+
+# Lazy import mappings: name -> (module_path, attribute_name or None for module import)
+_LAZY_IMPORTS = {
+    # From .auth
+    "get_fireworks_account_id": (".auth", "get_fireworks_account_id"),
+    "get_fireworks_api_key": (".auth", "get_fireworks_api_key"),
+    # From .common_utils
+    "load_jsonl": (".common_utils", "load_jsonl"),
+    # From .config
+    "RewardKitConfig": (".config", "RewardKitConfig"),
+    "get_config": (".config", "get_config"),
+    "load_config": (".config", "load_config"),
+    # From .mcp_env
+    "AnthropicPolicy": (".mcp_env", "AnthropicPolicy"),
+    "FireworksPolicy": (".mcp_env", "FireworksPolicy"),
+    "LiteLLMPolicy": (".mcp_env", "LiteLLMPolicy"),
+    "OpenAIPolicy": (".mcp_env", "OpenAIPolicy"),
+    "make": (".mcp_env", "make"),
+    "rollout": (".mcp_env", "rollout"),
+    "test_mcp": (".mcp_env", "test_mcp"),
+    # From .data_loader
+    "DynamicDataLoader": (".data_loader", "DynamicDataLoader"),
+    "InlineDataLoader": (".data_loader", "InlineDataLoader"),
+    # Submodules (accessible as eval_protocol.submodule)
+    "mcp": (".mcp", None),
+    "rewards": (".rewards", None),
+    "models": (".models", None),
+    "auth": (".auth", None),
+    "config": (".config", None),
+    # From .models
+    "EvaluateResult": (".models", "EvaluateResult"),
+    "Message": (".models", "Message"),
+    "MetricResult": (".models", "MetricResult"),
+    "EvaluationRow": (".models", "EvaluationRow"),
+    "InputMetadata": (".models", "InputMetadata"),
+    "Status": (".models", "Status"),
+    # From .playback_policy
+    "PlaybackPolicyBase": (".playback_policy", "PlaybackPolicyBase"),
+    # From .resources
+    "create_llm_resource": (".resources", "create_llm_resource"),
+    # From .reward_function
+    "RewardFunction": (".reward_function", "RewardFunction"),
+    # From .typed_interface
+    "reward_function": (".typed_interface", "reward_function"),
+    # From .quickstart.aha_judge
+    "aha_judge": (".quickstart.aha_judge", "aha_judge"),
+    # From .utils.evaluation_row_utils
+    "multi_turn_assistant_to_ground_truth": (".utils.evaluation_row_utils", "multi_turn_assistant_to_ground_truth"),
+    "assistant_to_ground_truth": (".utils.evaluation_row_utils", "assistant_to_ground_truth"),
+    "filter_longest_conversation": (".utils.evaluation_row_utils", "filter_longest_conversation"),
+    # From .pytest
+    "evaluation_test": (".pytest", "evaluation_test"),
+    "SingleTurnRolloutProcessor": (".pytest", "SingleTurnRolloutProcessor"),
+    "RemoteRolloutProcessor": (".pytest", "RemoteRolloutProcessor"),
+    "GithubActionRolloutProcessor": (".pytest", "GithubActionRolloutProcessor"),
+    # From .pytest.parameterize
+    "DefaultParameterIdGenerator": (".pytest.parameterize", "DefaultParameterIdGenerator"),
+    # From .log_utils
+    "ElasticsearchDirectHttpHandler": (
+        ".log_utils.elasticsearch_direct_http_handler",
+        "ElasticsearchDirectHttpHandler",
+    ),
+    "RolloutIdFilter": (".log_utils.rollout_id_filter", "RolloutIdFilter"),
+    "setup_rollout_logging_for_elasticsearch_handler": (
+        ".log_utils.util",
+        "setup_rollout_logging_for_elasticsearch_handler",
+    ),
+    "FireworksTracingHttpHandler": (".log_utils.fireworks_tracing_http_handler", "FireworksTracingHttpHandler"),
+    "ElasticsearchConfig": (".log_utils.elasticsearch_client", "ElasticsearchConfig"),
+    # From .types.remote_rollout_processor
+    "InitRequest": (".types.remote_rollout_processor", "InitRequest"),
+    "RolloutMetadata": (".types.remote_rollout_processor", "RolloutMetadata"),
+    "StatusResponse": (".types.remote_rollout_processor", "StatusResponse"),
+    "create_langfuse_config_tags": (".types.remote_rollout_processor", "create_langfuse_config_tags"),
+    "DataLoaderConfig": (".types.remote_rollout_processor", "DataLoaderConfig"),
+}
+
+# Optional imports that may not be available
+_OPTIONAL_IMPORTS = {
+    "OpenAIResponsesAdapter": (".adapters", "OpenAIResponsesAdapter"),
+    "LangfuseAdapter": (".adapters", "LangfuseAdapter"),
+    "create_langfuse_adapter": (".adapters", "create_langfuse_adapter"),
+    "BraintrustAdapter": (".adapters", "BraintrustAdapter"),
+    "create_braintrust_adapter": (".adapters", "create_braintrust_adapter"),
+    "LangSmithAdapter": (".adapters", "LangSmithAdapter"),
+    "WeaveAdapter": (".adapters", "WeaveAdapter"),
+    "create_app": (".proxy", "create_app"),
+    "AuthProvider": (".proxy", "AuthProvider"),
+    "AccountInfo": (".proxy", "AccountInfo"),
+}
+
+
+def __getattr__(name: str):
+    """Lazy import handler for module-level attributes."""
+    # Check regular lazy imports
+    if name in _LAZY_IMPORTS:
+        module_path, attr_name = _LAZY_IMPORTS[name]
+        module = importlib.import_module(module_path, package="eval_protocol")
+        if attr_name is None:
+            # It's a submodule import; the import system binds it onto the
+            # package, so later accesses bypass __getattr__ entirely.
+            return module
+        return getattr(module, attr_name)
+
+    # Check optional imports
+    if name in _OPTIONAL_IMPORTS:
+        module_path, attr_name = _OPTIONAL_IMPORTS[name]
+        try:
+            module = importlib.import_module(module_path, package="eval_protocol")
+            return getattr(module, attr_name)
+        except ImportError:
+            # Return None or a placeholder for optional imports
+            if name in ("create_app",):
+
+                def create_app(*args, **kwargs):
+                    raise ImportError(
+                        "Proxy functionality requires additional dependencies. "
+                        "Please install with: pip install eval-protocol[proxy]"
+                    )
+
+                return create_app
+            elif name in ("AuthProvider", "AccountInfo"):
+
+                class _Placeholder:
+                    def __init__(self, *args, **kwargs):
+                        raise ImportError(
+                            "Proxy functionality requires additional dependencies. "
+                            "Please install with: pip install eval-protocol[proxy]"
+                        )
+
+                return _Placeholder
+            return None
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
-warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
 
 __all__ = [
     "ElasticsearchConfig",
@@ -173,6 +230,85 @@ def __init__(self, *args, **kwargs):
     "AccountInfo",
 ]
 
-from . import _version
+# Version is loaded lazily too
+_version_info = None
+
+
+def _get_version():
+    global _version_info
+    if _version_info is None:
+        from . import _version
+
+        _version_info = _version.get_versions()["version"]
+    return _version_info
+
+
+# For TYPE_CHECKING, we provide type hints so IDEs can see the exports
+if TYPE_CHECKING:
+    from .auth import get_fireworks_account_id, get_fireworks_api_key
+    from .common_utils import load_jsonl
+    from .config import RewardKitConfig, get_config, load_config
+    from .mcp_env import (
+        AnthropicPolicy,
+        FireworksPolicy,
+        LiteLLMPolicy,
+        OpenAIPolicy,
+        make,
+        rollout,
+        test_mcp,
+    )
+    from .data_loader import DynamicDataLoader, InlineDataLoader
+    from . import mcp, rewards
+    from .models import EvaluateResult, Message, MetricResult, EvaluationRow, InputMetadata, Status
+    from .playback_policy import PlaybackPolicyBase
+    from .resources import create_llm_resource
+    from .reward_function import RewardFunction
+    from .typed_interface import reward_function
+    from .quickstart.aha_judge import aha_judge
+    from .utils.evaluation_row_utils import (
+        multi_turn_assistant_to_ground_truth,
+        assistant_to_ground_truth,
+        filter_longest_conversation,
+    )
+    from .pytest import (
+        evaluation_test,
+        SingleTurnRolloutProcessor,
+        RemoteRolloutProcessor,
+        GithubActionRolloutProcessor,
+    )
+    from .pytest.parameterize import DefaultParameterIdGenerator
+    from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
+    from .log_utils.rollout_id_filter import RolloutIdFilter
+    from .log_utils.util import setup_rollout_logging_for_elasticsearch_handler
+    from .log_utils.fireworks_tracing_http_handler import FireworksTracingHttpHandler
+    from .log_utils.elasticsearch_client import ElasticsearchConfig
+    from .types.remote_rollout_processor import (
+        InitRequest,
+        RolloutMetadata,
+        StatusResponse,
+        create_langfuse_config_tags,
+        DataLoaderConfig,
+    )
+    from .adapters import (
+        OpenAIResponsesAdapter,
+        LangfuseAdapter,
+        create_langfuse_adapter,
+        BraintrustAdapter,
+        create_braintrust_adapter,
+        LangSmithAdapter,
+        WeaveAdapter,
+    )
+    from .proxy import create_app, AuthProvider, AccountInfo
+
+
+# Expose __version__ lazily by swapping in a module subclass whose __version__
+# is a property; `eval_protocol.__version__` stays a plain attribute access
+# while the versioneer machinery loads only on first read.
+_this_module = sys.modules[__name__]
 
-__version__ = _version.get_versions()["version"]
+_this_module.__class__ = type("module", (type(_this_module),), {"__version__": property(lambda self: _get_version())})
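Background for reviewers: the table-driven laziness above is PEP 562's module-level `__getattr__`, which Python consults only when normal attribute lookup on the package fails. A stripped-down sketch of the same pattern — the package name and table entries here are illustrative, not the real eval_protocol table:

# mypkg/__init__.py -- minimal PEP 562 lazy-import sketch (illustrative names)
import importlib
from typing import Any

_LAZY = {
    "load_jsonl": (".common_utils", "load_jsonl"),  # attribute import
    "models": (".models", None),                    # submodule import
}


def __getattr__(name: str) -> Any:
    if name in _LAZY:
        module_path, attr = _LAZY[name]
        module = importlib.import_module(module_path, package=__name__)
        value = module if attr is None else getattr(module, attr)
        # Optional: cache so later lookups skip __getattr__ entirely.
        globals()[name] = value
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

The eager pytest branch above exists because PEP 562 loses to real submodules: `import pkg.reward_function` binds the submodule object onto the package, silently shadowing a same-named function served by `__getattr__` — exactly what pytest's AssertionRewritingHook would otherwise trigger for `eval_protocol.reward_function`.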
diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py
index b6827dd9..1990042e 100644
--- a/eval_protocol/pytest/__init__.py
+++ b/eval_protocol/pytest/__init__.py
@@ -1,44 +1,103 @@
-from .default_agent_rollout_processor import AgentRolloutProcessor
-from .default_dataset_adapter import default_dataset_adapter
-from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
-from .default_no_op_rollout_processor import NoOpRolloutProcessor
-from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
-from .remote_rollout_processor import RemoteRolloutProcessor
-from .github_action_rollout_processor import GithubActionRolloutProcessor
-from .evaluation_test import evaluation_test
-from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
-from .rollout_processor import RolloutProcessor
-from .rollout_result_post_processor import RolloutResultPostProcessor, NoOpRolloutResultPostProcessor
-from .types import RolloutProcessorConfig
-
-# Conditional import for optional Klavis dependency
-try:
-    from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor
-
-    KLAVIS_AVAILABLE = True
-except ImportError:
-    KLAVIS_AVAILABLE = False
-    KlavisSandboxRolloutProcessor = None
-
-# Conditional import for optional dependencies
-try:
-    from .default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
-
-    PYDANTIC_AI_AVAILABLE = True
-except ImportError:
-    PYDANTIC_AI_AVAILABLE = False
-    PydanticAgentRolloutProcessor = None
-
-# Conditional import for optional LangChain dependency
-try:
-    from .default_langchain_rollout_processor import LangGraphRolloutProcessor
-
-    LANGCHAIN_AVAILABLE = True
-except ImportError:
-    LANGCHAIN_AVAILABLE = False
-    LangGraphRolloutProcessor = None
+"""
+eval_protocol.pytest - Pytest integration for evaluation testing.
+
+This module uses lazy loading to minimize import time.
+Heavy dependencies (litellm, torch, etc.) are only loaded when needed.
+"""
+
+from __future__ import annotations
+
+import importlib
+from typing import TYPE_CHECKING
+
+# Lazy imports mapping: name -> (module_path, attr_name)
+# These are loaded on-demand when accessed
+_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
+    # Rollout processors
+    "AgentRolloutProcessor": (".default_agent_rollout_processor", "AgentRolloutProcessor"),
+    "MCPGymRolloutProcessor": (".default_mcp_gym_rollout_processor", "MCPGymRolloutProcessor"),
+    "NoOpRolloutProcessor": (".default_no_op_rollout_processor", "NoOpRolloutProcessor"),
+    "SingleTurnRolloutProcessor": (".default_single_turn_rollout_process", "SingleTurnRolloutProcessor"),
+    "RemoteRolloutProcessor": (".remote_rollout_processor", "RemoteRolloutProcessor"),
+    "GithubActionRolloutProcessor": (".github_action_rollout_processor", "GithubActionRolloutProcessor"),
+    "RolloutProcessor": (".rollout_processor", "RolloutProcessor"),
+    # Dataset adapter
+    "default_dataset_adapter": (".default_dataset_adapter", "default_dataset_adapter"),
+    # Core decorator
+    "evaluation_test": (".evaluation_test", "evaluation_test"),
+    # Exception handling
+    "ExceptionHandlerConfig": (".exception_config", "ExceptionHandlerConfig"),
+    "BackoffConfig": (".exception_config", "BackoffConfig"),
+    "get_default_exception_handler_config": (".exception_config", "get_default_exception_handler_config"),
+    # Post processors
+    "RolloutResultPostProcessor": (".rollout_result_post_processor", "RolloutResultPostProcessor"),
+    "NoOpRolloutResultPostProcessor": (".rollout_result_post_processor", "NoOpRolloutResultPostProcessor"),
+    # Types
+    "RolloutProcessorConfig": (".types", "RolloutProcessorConfig"),
+}
+
+# Optional imports that may not be available
+_OPTIONAL_IMPORTS: dict[str, tuple[str, str]] = {
+    "KlavisSandboxRolloutProcessor": (".default_klavis_sandbox_rollout_processor", "KlavisSandboxRolloutProcessor"),
+    "PydanticAgentRolloutProcessor": (".default_pydantic_ai_rollout_processor", "PydanticAgentRolloutProcessor"),
+    "LangGraphRolloutProcessor": (".default_langchain_rollout_processor", "LangGraphRolloutProcessor"),
+}
+
+# Track which optional imports are available (set on first access)
+_optional_availability: dict[str, bool] = {}
+
+
+def __getattr__(name: str):
+    """Lazy load attributes on first access."""
+    # Handle lazy imports
+    if name in _LAZY_IMPORTS:
+        module_path, attr_name = _LAZY_IMPORTS[name]
+        module = importlib.import_module(module_path, package="eval_protocol.pytest")
+        value = getattr(module, attr_name)
+        # Cache in module namespace for future access
+        globals()[name] = value
+        return value
+
+    # Handle optional imports
+    if name in _OPTIONAL_IMPORTS:
+        module_path, attr_name = _OPTIONAL_IMPORTS[name]
+        try:
+            module = importlib.import_module(module_path, package="eval_protocol.pytest")
+            value = getattr(module, attr_name)
+            globals()[name] = value
+            _optional_availability[name] = True
+            return value
+        except ImportError:
+            _optional_availability[name] = False
+            return None
+
+    # Handle availability flags
+    if name == "KLAVIS_AVAILABLE":
+        if "KlavisSandboxRolloutProcessor" not in _optional_availability:
+            # Trigger the import to check availability
+            __getattr__("KlavisSandboxRolloutProcessor")
+        return _optional_availability.get("KlavisSandboxRolloutProcessor", False)
+
+    if name == "PYDANTIC_AI_AVAILABLE":
+        if "PydanticAgentRolloutProcessor" not in _optional_availability:
+            __getattr__("PydanticAgentRolloutProcessor")
+        return _optional_availability.get("PydanticAgentRolloutProcessor", False)
+
+    if name == "LANGCHAIN_AVAILABLE":
+        if "LangGraphRolloutProcessor" not in _optional_availability:
+            __getattr__("LangGraphRolloutProcessor")
+        return _optional_availability.get("LangGraphRolloutProcessor", False)
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__():
+    """List available attributes for tab completion."""
+    return list(__all__) + ["KLAVIS_AVAILABLE", "PYDANTIC_AI_AVAILABLE", "LANGCHAIN_AVAILABLE"]
+
 
 __all__ = [
+    # Rollout processors
     "AgentRolloutProcessor",
     "MCPGymRolloutProcessor",
     "RolloutProcessor",
@@ -46,23 +105,53 @@
     "RemoteRolloutProcessor",
     "GithubActionRolloutProcessor",
     "NoOpRolloutProcessor",
+    # Dataset
     "default_dataset_adapter",
+    # Types
     "RolloutProcessorConfig",
+    # Core
     "evaluation_test",
+    # Exception handling
     "ExceptionHandlerConfig",
     "BackoffConfig",
     "get_default_exception_handler_config",
+    # Post processors
     "RolloutResultPostProcessor",
     "NoOpRolloutResultPostProcessor",
+    # Optional (may be None if dependencies are not installed)
+    "KlavisSandboxRolloutProcessor",
+    "PydanticAgentRolloutProcessor",
+    "LangGraphRolloutProcessor",
 ]
 
-# Only add to __all__ if available
-if KLAVIS_AVAILABLE:
-    __all__.append("KlavisSandboxRolloutProcessor")
-
-# Only add to __all__ if available
-if PYDANTIC_AI_AVAILABLE:
-    __all__.append("PydanticAgentRolloutProcessor")
-if LANGCHAIN_AVAILABLE:
-    __all__.append("LangGraphRolloutProcessor")
+# Type hints for IDE support (not executed at runtime)
+if TYPE_CHECKING:
+    from .default_agent_rollout_processor import AgentRolloutProcessor as AgentRolloutProcessor
+    from .default_dataset_adapter import default_dataset_adapter as default_dataset_adapter
+    from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor as MCPGymRolloutProcessor
+    from .default_no_op_rollout_processor import NoOpRolloutProcessor as NoOpRolloutProcessor
+    from .default_single_turn_rollout_process import SingleTurnRolloutProcessor as SingleTurnRolloutProcessor
+    from .remote_rollout_processor import RemoteRolloutProcessor as RemoteRolloutProcessor
+    from .github_action_rollout_processor import GithubActionRolloutProcessor as GithubActionRolloutProcessor
+    from .evaluation_test import evaluation_test as evaluation_test
+    from .exception_config import (
+        ExceptionHandlerConfig as ExceptionHandlerConfig,
+        BackoffConfig as BackoffConfig,
+        get_default_exception_handler_config as get_default_exception_handler_config,
+    )
+    from .rollout_processor import RolloutProcessor as RolloutProcessor
+    from .rollout_result_post_processor import (
+        RolloutResultPostProcessor as RolloutResultPostProcessor,
+        NoOpRolloutResultPostProcessor as NoOpRolloutResultPostProcessor,
+    )
+    from .types import RolloutProcessorConfig as RolloutProcessorConfig
+    from .default_klavis_sandbox_rollout_processor import (
+        KlavisSandboxRolloutProcessor as KlavisSandboxRolloutProcessor,
+    )
+    from .default_pydantic_ai_rollout_processor import (
+        PydanticAgentRolloutProcessor as PydanticAgentRolloutProcessor,
+    )
+    from .default_langchain_rollout_processor import (
+        LangGraphRolloutProcessor as LangGraphRolloutProcessor,
+    )
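A quick way to sanity-check that the façade stays lazy, in the spirit of the benchmark tests added below. It assumes — per the comments elsewhere in this PR — that litellm is what the single-turn processor module drags in:

# Run in a fresh interpreter: heavy deps load only on first real use.
import subprocess
import sys

probe = """
import sys
import eval_protocol.pytest
assert "litellm" not in sys.modules, "lazy import leaked: litellm"
from eval_protocol.pytest import SingleTurnRolloutProcessor  # triggers the real import
assert "litellm" in sys.modules
print("ok")
"""
result = subprocess.run([sys.executable, "-c", probe], capture_output=True, text=True)
print(result.stdout or result.stderr)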
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index eb05aa35..2256b35f 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -34,9 +34,9 @@
 from eval_protocol.pytest.parameterize import pytest_parametrize, create_dynamically_parameterized_wrapper
 from eval_protocol.pytest.validate_signature import validate_signature
 from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
-from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
+
+# Note: MCPGymRolloutProcessor and SingleTurnRolloutProcessor are imported lazily to avoid loading litellm (~1300ms)
 from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
-from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
 from eval_protocol.pytest.exception_config import ExceptionHandlerConfig
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import (
@@ -188,6 +188,9 @@ def evaluation_test(
     if os.environ.get("EP_USE_NO_OP_ROLLOUT_PROCESSOR") == "1":
         rollout_processor = NoOpRolloutProcessor()
     elif rollout_processor is None:
+        # Lazy import to avoid loading litellm at decorator definition time
+        from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor
+
         rollout_processor = SingleTurnRolloutProcessor()
 
     active_logger: DatasetLogger = logger if logger else default_logger
@@ -410,6 +413,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
 
             rollout_processor.setup()
 
+            # Lazy import to avoid loading litellm at module load time
+            from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
+
             use_priority_scheduler = os.environ.get(
                 "EP_USE_PRIORITY_SCHEDULER", "0"
             ) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor)
@@ -688,6 +694,9 @@ async def _collect_result(config, lst):
 
             # if rollout_processor is McpGymRolloutProcessor, we execute runs sequentially since McpGym does not support concurrent runs
             # else, we execute runs in parallel
+            # Lazy import (cached in sys.modules after the first import above)
+            from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
+
             if isinstance(rollout_processor, MCPGymRolloutProcessor):
                 # For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
                 for run_idx in range(num_runs):
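These call-site imports are cheap to repeat because `sys.modules` acts as a cache: only the first `from ... import MCPGymRolloutProcessor` anywhere in the process pays the litellm cost; every later one is a dict lookup plus an attribute fetch. A sketch demonstrating the cost profile — it assumes the optional heavy dependencies are installed, and the timings are illustrative:

import sys
import time


def handle_rollout(processor) -> bool:
    # Lazy import: the heavy module is loaded on the first call only.
    from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor

    return isinstance(processor, MCPGymRolloutProcessor)


start = time.perf_counter()
handle_rollout(object())  # first call: real import (pulls litellm, per this PR)
first = time.perf_counter() - start

start = time.perf_counter()
handle_rollout(object())  # second call: sys.modules hit
second = time.perf_counter() - start

cached = "eval_protocol.pytest.default_mcp_gym_rollout_processor" in sys.modules
print(f"first={first:.3f}s second={second:.6f}s cached={cached}")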
diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py
index b0ebd235..5ab80364 100644
--- a/eval_protocol/pytest/evaluation_test_utils.py
+++ b/eval_protocol/pytest/evaluation_test_utils.py
@@ -6,9 +6,11 @@
 from dataclasses import replace
 from typing import Any, Literal, Callable, AsyncGenerator, Optional
 
-from litellm.cost_calculator import cost_per_token
 from tqdm import tqdm
 
+# Note: litellm.cost_calculator.cost_per_token is imported lazily in add_cost_metrics()
+# to avoid ~1300ms import time at module load
+
 from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.models import (
     CostMetrics,
@@ -22,7 +24,8 @@
 from eval_protocol.data_loader import DynamicDataLoader
 from eval_protocol.data_loader.models import EvaluationDataLoader
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
-from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
+
+# Note: MCPGymRolloutProcessor is imported lazily at its call sites below to avoid loading litellm
 from eval_protocol.pytest.types import (
     RolloutProcessorConfig,
     ServerMode,
@@ -551,6 +554,9 @@ def add_cost_metrics(row: EvaluationRow) -> None:
 
     # Try to calculate costs, but gracefully handle unknown models
     try:
+        # Lazy import to avoid ~1300ms import time at module load
+        from litellm.cost_calculator import cost_per_token
+
         input_cost, output_cost = cost_per_token(
             model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens
         )
@@ -605,6 +611,9 @@ def build_rollout_processor_config(
 
     completion_params = {"model": model, "temperature": temperature, "max_tokens": max_tokens}
 
+    # Lazy import to avoid loading litellm at module load time
+    from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
+
    if isinstance(rollout_processor, MCPGymRolloutProcessor):
        base_kwargs = {**(rollout_processor_kwargs or {}), "start_server": start_server}
        if server_mode is not None and "server_mode" not in base_kwargs:
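`add_cost_metrics` stacks two defenses: the litellm import itself is deferred, and the call is wrapped so unknown models degrade to missing cost data instead of failures. A condensed sketch of that shape — `estimate_cost` is a hypothetical helper, and the broad `except` is deliberate since the exact exception litellm raises for unknown models isn't pinned down here:

from typing import Optional


def estimate_cost(model_id: str, input_tokens: int, output_tokens: int) -> Optional[float]:
    """Return total cost in USD, or None when pricing can't be resolved."""
    try:
        # Deferred so importing this module stays cheap (~1300ms saved per this PR).
        from litellm.cost_calculator import cost_per_token

        input_cost, output_cost = cost_per_token(
            model=model_id, prompt_tokens=input_tokens, completion_tokens=output_tokens
        )
        return input_cost + output_cost
    except Exception:
        # Unknown model, missing pricing data, or litellm not installed.
        return None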
diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py
index a2244b2a..6511252b 100644
--- a/eval_protocol/pytest/exception_config.py
+++ b/eval_protocol/pytest/exception_config.py
@@ -1,54 +1,87 @@
 """
 Exception handling configuration for rollout processors with backoff retry logic.
+
+This module intentionally avoids importing heavy deps (litellm/requests/httpx)
+at module import time to keep `@evaluation_test` import fast.
 """
 
 import os
 from dataclasses import dataclass, field
-from typing import Callable, Dict, Set, Type, Union
+from typing import Callable, Set, Type, Union
 
 import backoff
-
-import litellm
-import requests
-import httpx
-
 import eval_protocol.exceptions
 
-# Default exceptions that should be retried with backoff
-DEFAULT_RETRYABLE_EXCEPTIONS: Set[Type[Exception]] = {
-    # Standard library exceptions
-    ConnectionError,
-    TimeoutError,
-    OSError,  # Covers network-related OS errors
-    # Requests library exceptions
-    requests.exceptions.ConnectionError,
-    requests.exceptions.Timeout,
-    requests.exceptions.HTTPError,
-    requests.exceptions.RequestException,
-    # HTTPX library exceptions
-    httpx.ConnectError,
-    httpx.TimeoutException,
-    httpx.NetworkError,
-    httpx.RemoteProtocolError,
-    # LiteLLM library exceptions
-    litellm.exceptions.RateLimitError,
-    litellm.exceptions.InternalServerError,
-    litellm.exceptions.Timeout,
-    litellm.exceptions.NotFoundError,
-    litellm.exceptions.ServiceUnavailableError,
-    litellm.exceptions.APIError,
-    litellm.exceptions.BadRequestError,
-    # Eval Protocol exceptions
-    eval_protocol.exceptions.UnknownError,
-    eval_protocol.exceptions.DeadlineExceededError,
-    eval_protocol.exceptions.NotFoundError,
-    eval_protocol.exceptions.PermissionDeniedError,
-    eval_protocol.exceptions.UnavailableError,
-    eval_protocol.exceptions.UnauthenticatedError,
-    eval_protocol.exceptions.ResourceExhaustedError,
-    eval_protocol.exceptions.ResponseQualityError,
-}
+# Cache for the default retryable exceptions (populated on first access)
+_default_retryable_exceptions: Set[Type[Exception]] | None = None
+
+
+def get_default_retryable_exceptions() -> Set[Type[Exception]]:
+    """Compute the default set of retryable exceptions (lazy heavy imports)."""
+    global _default_retryable_exceptions
+    if _default_retryable_exceptions is not None:
+        return _default_retryable_exceptions
+
+    # Lazy imports (these are expensive)
+    import httpx
+    import litellm
+    import requests
+
+    _default_retryable_exceptions = {
+        # Standard library exceptions
+        ConnectionError,
+        TimeoutError,
+        OSError,  # Covers network-related OS errors
+        # Requests library exceptions
+        requests.exceptions.ConnectionError,
+        requests.exceptions.Timeout,
+        requests.exceptions.HTTPError,
+        requests.exceptions.RequestException,
+        # HTTPX library exceptions
+        httpx.ConnectError,
+        httpx.TimeoutException,
+        httpx.NetworkError,
+        httpx.RemoteProtocolError,
+        # LiteLLM library exceptions
+        litellm.exceptions.RateLimitError,
+        litellm.exceptions.InternalServerError,
+        litellm.exceptions.Timeout,
+        litellm.exceptions.NotFoundError,
+        litellm.exceptions.ServiceUnavailableError,
+        litellm.exceptions.APIError,
+        litellm.exceptions.BadRequestError,
+        # Eval Protocol exceptions
+        eval_protocol.exceptions.UnknownError,
+        eval_protocol.exceptions.DeadlineExceededError,
+        eval_protocol.exceptions.NotFoundError,
+        eval_protocol.exceptions.PermissionDeniedError,
+        eval_protocol.exceptions.UnavailableError,
+        eval_protocol.exceptions.UnauthenticatedError,
+        eval_protocol.exceptions.ResourceExhaustedError,
+        eval_protocol.exceptions.ResponseQualityError,
+    }
+
+    return _default_retryable_exceptions
+
+
+class _LazyDefaultRetryableExceptions(Set[Type[Exception]]):
+    """Set-like view that materializes the default exception set on first use."""
+
+    def __iter__(self):
+        return iter(get_default_retryable_exceptions())
+
+    def __len__(self) -> int:
+        return len(get_default_retryable_exceptions())
+
+    def __contains__(self, x: object) -> bool:
+        return x in get_default_retryable_exceptions()
+
+    def copy(self) -> Set[Type[Exception]]:
+        return set(get_default_retryable_exceptions())
+
+
+# Backwards-compatible name: behaves like a set but doesn't import heavy deps until used
+DEFAULT_RETRYABLE_EXCEPTIONS: Set[Type[Exception]] = _LazyDefaultRetryableExceptions()
 
 
 @dataclass
@@ -68,7 +101,8 @@ class BackoffConfig:
     max_tries: int = 3
 
     # Jitter: adds randomness to backoff delays (None = no jitter for predictable timing)
-    jitter: Union[None, Callable] = None
+    # Backoff's jitter expects a function like `lambda value: float`
+    jitter: Union[None, Callable[[float], float]] = None
 
     # Factor for exponential backoff (only used if strategy == 'expo')
     factor: float = 2.0
@@ -81,7 +115,7 @@
 
     def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
         """Get the appropriate backoff decorator based on configuration.
-
+
         Args:
             exceptions: Set of exception types to retry
         """
@@ -123,7 +157,8 @@ class ExceptionHandlerConfig:
     """Configuration for exception handling in rollout processors."""
 
     # Exceptions that should be retried using backoff
-    retryable_exceptions: Set[Type[Exception]] = field(default_factory=lambda: DEFAULT_RETRYABLE_EXCEPTIONS.copy())
+    # Use field with default_factory to lazily get the exceptions
+    retryable_exceptions: Set[Type[Exception]] = field(default_factory=lambda: set(get_default_retryable_exceptions()))
 
     # Backoff configuration
     backoff_config: BackoffConfig = field(default_factory=BackoffConfig)
@@ -141,9 +176,7 @@ def __post_init__(self):
 
     def get_backoff_decorator(self):
         """Get the backoff decorator configured for this exception handler."""
-        return self.backoff_config.get_backoff_decorator(
-            self.retryable_exceptions
-        )
+        return self.backoff_config.get_backoff_decorator(self.retryable_exceptions)
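One subtlety worth a reviewer's note: `_LazyDefaultRetryableExceptions` is itself an (empty) `set` subclass, and only the overridden dunders delegate to the materialized set. Iteration, `len()`, `in`, and `copy()` behave, but CPython fast paths that read the underlying hash table directly (e.g. `set(view)` or `view | other`) will see an empty set — which is why the dataclass default calls `get_default_retryable_exceptions()` instead of converting the view. A standalone reproduction with illustrative names, not the real config module:

class LazySet(set):
    """Set-like view that materializes its contents on first use."""

    def _materialize(self) -> set:
        # Stand-in for get_default_retryable_exceptions(); heavy imports would go here.
        return {ConnectionError, TimeoutError, OSError}

    def __iter__(self):
        return iter(self._materialize())

    def __len__(self):
        return len(self._materialize())

    def __contains__(self, x):
        return x in self._materialize()

    def copy(self):
        return set(self._materialize())


lazy = LazySet()
assert ConnectionError in lazy  # __contains__ delegates
assert len(lazy) == 3           # __len__ delegates
assert tuple(lazy)              # iteration delegates (what backoff needs)
assert lazy.copy() == {ConnectionError, TimeoutError, OSError}
assert set(lazy) == set()       # gotcha: set(other_set) copies the internal table (CPython)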
{CLI_STARTUP_TARGET_SECONDS}s") + + # Use the best time (min) as some CI environments have variable overhead + assert min_time < CLI_STARTUP_TARGET_SECONDS, ( + f"CLI startup time ({min_time:.3f}s) exceeds target ({CLI_STARTUP_TARGET_SECONDS}s)." + ) + + +@pytest.mark.benchmark +@pytest.mark.skipif(SKIP_BENCHMARK, reason=SKIP_REASON) +def test_package_import_time(): + """Test that importing eval_protocol package is fast (lazy loading check).""" + # Use subprocess to get a clean import measurement + code = """ +import time +start = time.perf_counter() +import eval_protocol +elapsed = time.perf_counter() - start +print(f"{elapsed:.6f}") +""" + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + ) + + assert result.returncode == 0, f"Import failed: {result.stderr}" + + import_time = float(result.stdout.strip()) + print(f"\n Package import time: {import_time * 1000:.1f}ms") + + # Package import should be very fast with lazy loading (< 100ms for CI) + assert import_time < 0.1, f"Package import time ({import_time * 1000:.1f}ms) is too slow." + + +@pytest.mark.benchmark +@pytest.mark.skipif(SKIP_BENCHMARK, reason=SKIP_REASON) +def test_evaluation_test_import_time(): + """Test that importing evaluation_test decorator is under the target threshold.""" + code = """ +import sys +import time +start = time.perf_counter() +from eval_protocol import evaluation_test +elapsed = time.perf_counter() - start +litellm_loaded = "litellm" in sys.modules +print(f"{elapsed:.6f}") +print(f"{litellm_loaded}") +""" + times = [] + + for i in range(NUM_RUNS): + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + ) + + assert result.returncode == 0, f"Import failed: {result.stderr}" + + lines = result.stdout.strip().split("\n") + import_time = float(lines[0]) + litellm_loaded = lines[1] == "True" + times.append(import_time) + print(f" Run {i + 1}: {import_time:.3f}s (litellm loaded: {litellm_loaded})") + + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + + print(f"\n Average: {avg_time:.3f}s") + print(f" Min: {min_time:.3f}s") + print(f" Max: {max_time:.3f}s") + print(f" Target: {EVALUATION_TEST_IMPORT_TARGET_SECONDS}s") + + # Use the best time (min) as some CI environments have variable overhead + assert min_time < EVALUATION_TEST_IMPORT_TARGET_SECONDS, ( + f"evaluation_test import time ({min_time:.3f}s) exceeds target ({EVALUATION_TEST_IMPORT_TARGET_SECONDS}s)." + ) + + +if __name__ == "__main__": + # When run directly, always execute (ignore SKIP_BENCHMARK) + print("=== CLI Startup Benchmark ===\n") + + print("Testing CLI startup time...") + times = [] + for i in range(NUM_RUNS): + elapsed = measure_cli_startup_time() + times.append(elapsed) + print(f" Run {i + 1}: {elapsed:.3f}s") + + avg_time = sum(times) / len(times) + min_time = min(times) + + print(f"\n Average: {avg_time:.3f}s") + print(f" Best: {min_time:.3f}s") + print(f" Target: {CLI_STARTUP_TARGET_SECONDS}s") + + if min_time < CLI_STARTUP_TARGET_SECONDS: + print(f"\nāœ“ PASS: CLI startup ({min_time:.3f}s) is under target ({CLI_STARTUP_TARGET_SECONDS}s)") + else: + print(f"\nāœ— FAIL: CLI startup ({min_time:.3f}s) exceeds target ({CLI_STARTUP_TARGET_SECONDS}s)") + sys.exit(1)