diff --git a/pyrit/common/csv_helper.py b/pyrit/common/csv_helper.py index 48a9b9dd7e..2fb831149e 100644 --- a/pyrit/common/csv_helper.py +++ b/pyrit/common/csv_helper.py @@ -24,6 +24,9 @@ def write_csv(file: IO[Any], examples: list[dict[str, str]]) -> None: file: A file-like object opened for writing CSV data. examples (List[Dict[str, str]]): List of dictionaries to write as CSV rows. """ + if not examples: + return + writer = csv.DictWriter(file, fieldnames=examples[0].keys()) writer.writeheader() writer.writerows(examples) diff --git a/tests/unit/common/test_csv_helper.py b/tests/unit/common/test_csv_helper.py new file mode 100644 index 0000000000..0c1c5931a2 --- /dev/null +++ b/tests/unit/common/test_csv_helper.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io + +from pyrit.common.csv_helper import write_csv + + +def test_write_csv_empty_examples_writes_nothing(): + file = io.StringIO() + write_csv(file, []) + assert file.getvalue() == "" + + +def test_write_csv_writes_header_and_rows(): + file = io.StringIO() + write_csv(file, [{"name": "alice", "role": "admin"}]) + lines = file.getvalue().strip().splitlines() + assert lines[0] == "name,role" + assert lines[1] == "alice,admin" diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py index d0052a4c78..d9a2c8acfa 100644 --- a/tests/unit/datasets/test_remote_dataset_loader.py +++ b/tests/unit/datasets/test_remote_dataset_loader.py @@ -72,3 +72,13 @@ def test_write_cache_creates_directories(self, tmp_path): loader._write_cache(cache_file=cache_file, examples=data, file_type="json") assert cache_file.exists() + + def test_write_cache_csv_allows_empty_examples(self, tmp_path): + loader = ConcreteRemoteLoader() + cache_file = tmp_path / "empty.csv" + + loader._write_cache(cache_file=cache_file, examples=[], file_type="csv") + + assert cache_file.exists() + assert cache_file.read_text(encoding="utf-8") == "" + assert loader._read_cache(cache_file=cache_file, file_type="csv") == []