From 3db0da1520cde3198720570f5f21c78c2322af68 Mon Sep 17 00:00:00 2001
From: biefan <70761325+biefan@users.noreply.github.com>
Date: Tue, 17 Mar 2026 03:07:34 +0000
Subject: [PATCH] Normalize remote dataset file types from URLs

---
 .../remote/remote_dataset_loader.py           | 15 ++++++++-
 .../datasets/test_remote_dataset_loader.py    | 32 +++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py
index 5cd9212846..7622175ef6 100644
--- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py
+++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py
@@ -10,6 +10,7 @@
 from collections.abc import Callable
 from pathlib import Path
 from typing import Any, Literal, Optional, TextIO, cast
+from urllib.parse import urlparse
 
 import requests
 from datasets import DownloadMode, disable_progress_bars, load_dataset
@@ -76,6 +77,18 @@ def _validate_file_type(self, file_type: str) -> None:
             valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
             raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.")
 
+    def _get_file_type(self, *, source: str) -> str:
+        """
+        Infer the source file type from a URL or local path.
+
+        Query strings and fragments are ignored for URLs, and the result is
+        normalized to lowercase so `.JSON` and `.json` are treated identically.
+        """
+        parsed = urlparse(source)
+        source_path = parsed.path if parsed.scheme else source
+        suffix = Path(source_path).suffix
+        return suffix.lstrip(".").lower()
+
     def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]:
         """
         Read data from cache.
@@ -188,7 +201,7 @@ def _fetch_from_url(
             ...     source_type='public_url'
             ... )
         """
-        file_type = source.split(".")[-1]
+        file_type = self._get_file_type(source=source)
         if file_type not in FILE_TYPE_HANDLERS:
             valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
             raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.")
diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py
index d0052a4c78..a1325e3052 100644
--- a/tests/unit/datasets/test_remote_dataset_loader.py
+++ b/tests/unit/datasets/test_remote_dataset_loader.py
@@ -72,3 +72,35 @@ def test_write_cache_creates_directories(self, tmp_path):
         loader._write_cache(cache_file=cache_file, examples=data, file_type="json")
 
         assert cache_file.exists()
+
+    @patch.object(_RemoteDatasetLoader, "_fetch_from_public_url", return_value=[{"key": "value"}])
+    def test_fetch_from_url_supports_query_string_file_type(self, mock_fetch_from_public_url):
+        loader = ConcreteRemoteLoader()
+
+        result = loader._fetch_from_url(
+            source="https://example.com/data.json?download=1",
+            source_type="public_url",
+            cache=False,
+        )
+
+        assert result == [{"key": "value"}]
+        mock_fetch_from_public_url.assert_called_once_with(
+            source="https://example.com/data.json?download=1",
+            file_type="json",
+        )
+
+    @patch.object(_RemoteDatasetLoader, "_fetch_from_public_url", return_value=[{"key": "value"}])
+    def test_fetch_from_url_supports_uppercase_file_type(self, mock_fetch_from_public_url):
+        loader = ConcreteRemoteLoader()
+
+        result = loader._fetch_from_url(
+            source="https://example.com/data.JSON",
+            source_type="public_url",
+            cache=False,
+        )
+
+        assert result == [{"key": "value"}]
+        mock_fetch_from_public_url.assert_called_once_with(
+            source="https://example.com/data.JSON",
+            file_type="json",
+        )