diff --git a/tests/end_to_end/test_all_datasets.py b/tests/end_to_end/test_all_datasets.py new file mode 100644 index 0000000000..269269cdce --- /dev/null +++ b/tests/end_to_end/test_all_datasets.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +End-to-end tests that verify every registered dataset provider can be fetched. + +These tests download real data from HuggingFace and GitHub, are slow, and are +subject to transient network failures. They are intended to run daily in e2e CI, +not on every PR. + +Resiliency: each fetch is retried up to 3 times with exponential backoff to +handle transient HuggingFace / GitHub rate-limiting and network errors. +""" + +import asyncio +import logging +import os + +import pytest +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential + +from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.remote import ( + _HarmBenchMultimodalDataset, + _PromptIntelDataset, + _VLSUMultimodalDataset, +) +from pyrit.models import SeedDataset +from pyrit.setup import IN_MEMORY, initialize_pyrit_async + +logger = logging.getLogger(__name__) + +# Per-test timeout in seconds (5 minutes per dataset) +_TEST_TIMEOUT = 300 + +# Transient error types that warrant a retry +_RETRYABLE_ERRORS = (OSError, ConnectionError, TimeoutError) + +# Providers that download many remote images; each image fetch may fail +# due to rate-limiting, so an empty result is expected in some environments. +_IMAGE_FETCHING_PROVIDERS: set[type] = {_HarmBenchMultimodalDataset, _VLSUMultimodalDataset} + + +def get_dataset_providers(): + """Helper to get all registered providers for parameterization.""" + providers = SeedDatasetProvider.get_all_providers() + return [(name, cls) for name, cls in providers.items()] + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=5, min=5, max=60), + retry=retry_if_exception_type(_RETRYABLE_ERRORS), + reraise=True, +) +async def _fetch_with_retry(provider) -> SeedDataset: + """Fetch a dataset with retry on transient network errors.""" + return await provider.fetch_dataset(cache=False) + + +@pytest.fixture(scope="module", autouse=True) +def _init_memory(): + """Multimodal providers need CentralMemory to save downloaded images.""" + asyncio.run(initialize_pyrit_async(memory_db_type=IN_MEMORY)) + + +class TestAllDatasets: + """Exhaustive test that every registered dataset provider can be fetched.""" + + @pytest.mark.asyncio + @pytest.mark.timeout(_TEST_TIMEOUT) + @pytest.mark.parametrize("name,provider_cls", get_dataset_providers()) + async def test_fetch_dataset(self, name, provider_cls): + """ + Verify that a specific registered dataset can be fetched. + + This test is parameterized to run for each registered provider. + It verifies that: + 1. The dataset can be downloaded/loaded without error + 2. The result is a SeedDataset + 3. The dataset is not empty (has seeds) + + Retries up to 3 times on transient network errors. + """ + # Skip providers that require credentials not available in CI + if provider_cls == _PromptIntelDataset and not os.environ.get("PROMPTINTEL_API_KEY"): + pytest.skip("PROMPTINTEL_API_KEY not set") + + logger.info(f"Testing provider: {name}") + + try: + # Limit examples for slow multimodal providers that fetch many remote images + provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() + + dataset = await _fetch_with_retry(provider) + except Exception as e: + # Multimodal providers silently skip failed image downloads. When ALL + # images fail the resulting empty seed list triggers "SeedDataset cannot + # be empty". That is a transient environment issue, not a code bug. + if provider_cls in _IMAGE_FETCHING_PROVIDERS and "cannot be empty" in str(e): + pytest.skip(f"{name}: all image downloads failed ({e})") + pytest.fail(f"Failed to fetch dataset from {name}: {e}") + + assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" + assert dataset.dataset_name, f"{name} has no dataset_name" + assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" + + for seed in dataset.seeds: + assert seed.value, f"Seed in {name} has no value" + assert seed.dataset_name == dataset.dataset_name, ( + f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" + ) + + logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds") diff --git a/tests/integration/datasets/test_load_default_datasets_integration.py b/tests/integration/datasets/test_load_default_datasets_integration.py new file mode 100644 index 0000000000..8b600d378d --- /dev/null +++ b/tests/integration/datasets/test_load_default_datasets_integration.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Integration test for the LoadDefaultDatasets initializer. + +Runs the full pipeline: discovers scenario default datasets, fetches them +from real remote sources, and stores them in in-memory CentralMemory. +""" + +import logging + +import pytest + +from pyrit.memory import CentralMemory +from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets + +logger = logging.getLogger(__name__) + + +class TestLoadDefaultDatasetsIntegration: + """Integration test that LoadDefaultDatasets loads real datasets into memory.""" + + @pytest.mark.asyncio + async def test_initialize_loads_datasets_into_memory(self): + """ + Verify that LoadDefaultDatasets.initialize_async() successfully fetches + real datasets and stores them in CentralMemory. + """ + initializer = LoadDefaultDatasets() + await initializer.initialize_async() + + memory = CentralMemory.get_memory_instance() + dataset_names = memory.get_seed_dataset_names() + + assert len(dataset_names) > 0, "No datasets were loaded into memory" + logger.info(f"LoadDefaultDatasets loaded {len(dataset_names)} datasets into memory") diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index 0a3d47ecee..854e59dead 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -10,7 +10,7 @@ from pyrit.datasets import SeedDatasetProvider from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader -from pyrit.datasets.seed_datasets.remote import _VLSUMultimodalDataset +from pyrit.datasets.seed_datasets.remote import _SimpleSafetyTestsDataset, _XSTestDataset from pyrit.datasets.seed_datasets.seed_metadata import ( SeedDatasetFilter, ) @@ -19,49 +19,49 @@ logger = logging.getLogger(__name__) -def get_dataset_providers(): - """Helper to get all registered providers for parameterization.""" - providers = SeedDatasetProvider.get_all_providers() - return [(name, cls) for name, cls in providers.items()] +# Smoke-test providers covering the three distinct fetch paths: +# - local YAML (no network) +# - remote URL-based (_fetch_from_url via GitHub) +# - remote HuggingFace (_fetch_from_huggingface) +_all_providers = SeedDatasetProvider.get_all_providers() +_SMOKE_PROVIDERS: list[tuple[str, type]] = [ + ("LocalDataset_access_shell_commands", _all_providers["LocalDataset_access_shell_commands"]), + ("_XSTestDataset", _XSTestDataset), + ("_SimpleSafetyTestsDataset", _SimpleSafetyTestsDataset), +] -class TestSeedDatasetProviderIntegration: - """Integration tests for SeedDatasetProvider.""" +class TestSeedDatasetSmoke: + """Smoke tests for a small representative set of dataset providers. + + The exhaustive test over all providers lives in tests/end_to_end/test_all_datasets.py. + """ @pytest.mark.asyncio - @pytest.mark.parametrize("name,provider_cls", get_dataset_providers()) - async def test_fetch_dataset_integration(self, name, provider_cls): + @pytest.mark.parametrize("name,provider_cls", _SMOKE_PROVIDERS, ids=[p[0] for p in _SMOKE_PROVIDERS]) + async def test_fetch_dataset_smoke(self, name, provider_cls): """ - Integration test to verify that a specific registered dataset can be fetched. + Verify that a representative provider can be fetched successfully. - This test is parameterized to run for each registered provider. - It verifies that: - 1. The dataset can be downloaded/loaded without error - 2. The result is a SeedDataset - 3. The dataset is not empty (has seeds) + Covers one local, one URL-remote, and one HuggingFace-remote provider + to catch regressions in each fetch path without downloading all 58 datasets. """ - logger.info(f"Testing provider: {name}") + logger.info(f"Smoke testing provider: {name}") - try: - # Use max_examples for slow providers that fetch many remote images - provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() - dataset = await provider.fetch_dataset(cache=False) + provider = provider_cls() + dataset = await provider.fetch_dataset(cache=False) - assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" - assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" - assert dataset.dataset_name, f"{name} has no dataset_name" + assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" + assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" + assert dataset.dataset_name, f"{name} has no dataset_name" - # Verify seeds have required fields - for seed in dataset.seeds: - assert seed.value, f"Seed in {name} has no value" - assert seed.dataset_name == dataset.dataset_name, ( - f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" - ) - - logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds") + for seed in dataset.seeds: + assert seed.value, f"Seed in {name} has no value" + assert seed.dataset_name == dataset.dataset_name, ( + f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" + ) - except Exception as e: - pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") + logger.info(f"Smoke test passed for {name} with {len(dataset.seeds)} seeds") class TestRemoteFilteringIntegration: