Skip to content
114 changes: 114 additions & 0 deletions tests/end_to_end/test_all_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
End-to-end tests that verify every registered dataset provider can be fetched.

These tests download real data from HuggingFace and GitHub, are slow, and are
subject to transient network failures. They are intended to run daily in e2e CI,
not on every PR.

Resiliency: each fetch is retried up to 3 times with exponential backoff to
handle transient HuggingFace / GitHub rate-limiting and network errors.
"""

import asyncio
import logging
import os

import pytest
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from pyrit.datasets import SeedDatasetProvider
from pyrit.datasets.seed_datasets.remote import (
_HarmBenchMultimodalDataset,
_PromptIntelDataset,
_VLSUMultimodalDataset,
)
from pyrit.models import SeedDataset
from pyrit.setup import IN_MEMORY, initialize_pyrit_async

logger = logging.getLogger(__name__)

# Per-test timeout in seconds (5 minutes per dataset)
_TEST_TIMEOUT = 300

# Transient error types that warrant a retry
_RETRYABLE_ERRORS = (OSError, ConnectionError, TimeoutError)

# Providers that download many remote images; each image fetch may fail
# due to rate-limiting, so an empty result is expected in some environments.
_IMAGE_FETCHING_PROVIDERS: set[type] = {_HarmBenchMultimodalDataset, _VLSUMultimodalDataset}


def get_dataset_providers():
    """Return all registered providers as (name, class) pairs for parameterization."""
    # dict.items() already yields (name, cls) tuples; materialize for pytest.
    return list(SeedDatasetProvider.get_all_providers().items())


# Retry policy: up to 3 attempts, exponential backoff starting at 5s and
# capped at 60s, retrying only on the transient error types listed in
# _RETRYABLE_ERRORS. reraise=True surfaces the original exception (not a
# tenacity RetryError) after the final failed attempt.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=5, min=5, max=60),
    retry=retry_if_exception_type(_RETRYABLE_ERRORS),
    reraise=True,
)
async def _fetch_with_retry(provider) -> SeedDataset:
    """Fetch a dataset with retry on transient network errors.

    Args:
        provider: A dataset provider instance exposing ``fetch_dataset``
            (presumably a SeedDatasetProvider subclass — confirm with registry).

    Returns:
        The fetched SeedDataset. Caching is disabled so each run exercises
        the real remote fetch path.
    """
    return await provider.fetch_dataset(cache=False)


@pytest.fixture(scope="module", autouse=True)
def _init_memory():
    """Multimodal providers need CentralMemory to save downloaded images.

    Module-scoped and autouse, so PyRIT is initialized exactly once before any
    test in this module runs.
    """
    # asyncio.run creates and tears down its own event loop for the one-off
    # async initializer call; the tests themselves run on pytest-asyncio's loop.
    asyncio.run(initialize_pyrit_async(memory_db_type=IN_MEMORY))


class TestAllDatasets:
    """Exhaustive test that every registered dataset provider can be fetched."""

    @pytest.mark.asyncio
    @pytest.mark.timeout(_TEST_TIMEOUT)
    @pytest.mark.parametrize("name,provider_cls", get_dataset_providers())
    async def test_fetch_dataset(self, name, provider_cls):
        """
        Verify that a specific registered dataset can be fetched.

        This test is parameterized to run for each registered provider.
        It verifies that:
        1. The dataset can be downloaded/loaded without error
        2. The result is a SeedDataset
        3. The dataset is not empty (has seeds)

        Retries up to 3 times on transient network errors.
        """
        # Skip providers that require credentials not available in CI.
        # `is` (identity) is the correct comparison for class objects.
        if provider_cls is _PromptIntelDataset and not os.environ.get("PROMPTINTEL_API_KEY"):
            pytest.skip("PROMPTINTEL_API_KEY not set")

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Testing provider: %s", name)

        try:
            # Limit examples for slow multimodal providers that fetch many remote images
            provider = provider_cls(max_examples=6) if provider_cls is _VLSUMultimodalDataset else provider_cls()

            dataset = await _fetch_with_retry(provider)
        except Exception as e:
            # Multimodal providers silently skip failed image downloads. When ALL
            # images fail the resulting empty seed list triggers "SeedDataset cannot
            # be empty". That is a transient environment issue, not a code bug.
            if provider_cls in _IMAGE_FETCHING_PROVIDERS and "cannot be empty" in str(e):
                pytest.skip(f"{name}: all image downloads failed ({e})")
            pytest.fail(f"Failed to fetch dataset from {name}: {e}")

        assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset"
        assert dataset.dataset_name, f"{name} has no dataset_name"
        assert len(dataset.seeds) > 0, f"{name} returned an empty dataset"

        # Every seed must carry a value and belong to the dataset it came from.
        for seed in dataset.seeds:
            assert seed.value, f"Seed in {name} has no value"
            assert seed.dataset_name == dataset.dataset_name, (
                f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}"
            )

        logger.info("Successfully verified %s with %d seeds", name, len(dataset.seeds))
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Integration test for the LoadDefaultDatasets initializer.

Runs the full pipeline: discovers scenario default datasets, fetches them
from real remote sources, and stores them in in-memory CentralMemory.
"""

import logging

import pytest

from pyrit.memory import CentralMemory
from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets

logger = logging.getLogger(__name__)


class TestLoadDefaultDatasetsIntegration:
    """Integration test that LoadDefaultDatasets loads real datasets into memory."""

    @pytest.mark.asyncio
    async def test_initialize_loads_datasets_into_memory(self):
        """
        Verify that LoadDefaultDatasets.initialize_async() successfully fetches
        real datasets and stores them in CentralMemory.
        """
        # Run the full initializer pipeline against real remote sources.
        await LoadDefaultDatasets().initialize_async()

        # The initializer persists fetched datasets to CentralMemory; read back
        # the stored dataset names to confirm something actually landed.
        dataset_names = CentralMemory.get_memory_instance().get_seed_dataset_names()

        assert len(dataset_names) > 0, "No datasets were loaded into memory"
        logger.info(f"LoadDefaultDatasets loaded {len(dataset_names)} datasets into memory")
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from pyrit.datasets import SeedDatasetProvider
from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader
from pyrit.datasets.seed_datasets.remote import _VLSUMultimodalDataset
from pyrit.datasets.seed_datasets.remote import _SimpleSafetyTestsDataset, _XSTestDataset
from pyrit.datasets.seed_datasets.seed_metadata import (
SeedDatasetFilter,
)
Expand All @@ -19,49 +19,49 @@
logger = logging.getLogger(__name__)


def get_dataset_providers():
"""Helper to get all registered providers for parameterization."""
providers = SeedDatasetProvider.get_all_providers()
return [(name, cls) for name, cls in providers.items()]
# Smoke-test providers covering the three distinct fetch paths:
# - local YAML (no network)
# - remote URL-based (_fetch_from_url via GitHub)
# - remote HuggingFace (_fetch_from_huggingface)
_all_providers = SeedDatasetProvider.get_all_providers()
_SMOKE_PROVIDERS: list[tuple[str, type]] = [
("LocalDataset_access_shell_commands", _all_providers["LocalDataset_access_shell_commands"]),
("_XSTestDataset", _XSTestDataset),
("_SimpleSafetyTestsDataset", _SimpleSafetyTestsDataset),
]


class TestSeedDatasetProviderIntegration:
"""Integration tests for SeedDatasetProvider."""
class TestSeedDatasetSmoke:
"""Smoke tests for a small representative set of dataset providers.

The exhaustive test over all providers lives in tests/end_to_end/test_all_datasets.py.
"""

@pytest.mark.asyncio
@pytest.mark.parametrize("name,provider_cls", get_dataset_providers())
async def test_fetch_dataset_integration(self, name, provider_cls):
@pytest.mark.parametrize("name,provider_cls", _SMOKE_PROVIDERS, ids=[p[0] for p in _SMOKE_PROVIDERS])
async def test_fetch_dataset_smoke(self, name, provider_cls):
"""
Integration test to verify that a specific registered dataset can be fetched.
Verify that a representative provider can be fetched successfully.

This test is parameterized to run for each registered provider.
It verifies that:
1. The dataset can be downloaded/loaded without error
2. The result is a SeedDataset
3. The dataset is not empty (has seeds)
Covers one local, one URL-remote, and one HuggingFace-remote provider
to catch regressions in each fetch path without downloading all 58 datasets.
"""
logger.info(f"Testing provider: {name}")
logger.info(f"Smoke testing provider: {name}")

try:
# Use max_examples for slow providers that fetch many remote images
provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls()
dataset = await provider.fetch_dataset(cache=False)
provider = provider_cls()
dataset = await provider.fetch_dataset(cache=False)

assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset"
assert len(dataset.seeds) > 0, f"{name} returned an empty dataset"
assert dataset.dataset_name, f"{name} has no dataset_name"
assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset"
assert len(dataset.seeds) > 0, f"{name} returned an empty dataset"
assert dataset.dataset_name, f"{name} has no dataset_name"

# Verify seeds have required fields
for seed in dataset.seeds:
assert seed.value, f"Seed in {name} has no value"
assert seed.dataset_name == dataset.dataset_name, (
f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}"
)

logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds")
for seed in dataset.seeds:
assert seed.value, f"Seed in {name} has no value"
assert seed.dataset_name == dataset.dataset_name, (
f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}"
)

except Exception as e:
pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}")
logger.info(f"Smoke test passed for {name} with {len(dataset.seeds)} seeds")


class TestRemoteFilteringIntegration:
Expand Down
Loading