diff --git a/assayer/exporter.py b/assayer/exporter.py index a3395ff..59ff353 100644 --- a/assayer/exporter.py +++ b/assayer/exporter.py @@ -4,6 +4,16 @@ from assayer.models import ModelResult +_FIELDNAMES = [ + "model", + "output", + "tokens_input", + "tokens_output", + "latency_seconds", + "cost_usd", + "error", +] + def _to_dict(result: ModelResult) -> dict: return { @@ -23,7 +33,7 @@ def export(results: list[ModelResult], path: str) -> None: if dest.suffix.lower() == ".csv": with dest.open("w", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=list(records[0].keys())) + writer = csv.DictWriter(f, fieldnames=_FIELDNAMES) writer.writeheader() writer.writerows(records) else: diff --git a/tests/test_exporter.py b/tests/test_exporter.py index e82f48b..23dc1f3 100644 --- a/tests/test_exporter.py +++ b/tests/test_exporter.py @@ -6,7 +6,15 @@ from assayer.exporter import export from assayer.models import ModelResult -_EXPECTED_FIELDS = {"model", "output", "tokens_input", "tokens_output", "latency_seconds", "cost_usd", "error"} +_EXPECTED_FIELDS = { + "model", + "output", + "tokens_input", + "tokens_output", + "latency_seconds", + "cost_usd", + "error", +} def _results() -> list[ModelResult]: @@ -93,6 +101,18 @@ def test_export_csv_has_all_headers(tmp_path): assert set(reader.fieldnames) == _EXPECTED_FIELDS +def test_export_csv_empty_results_writes_headers(tmp_path): + path = tmp_path / "results.csv" + export([], str(path)) + + with path.open(encoding="utf-8") as f: + reader = csv.DictReader(f) + rows = list(reader) + + assert set(reader.fieldnames or []) == _EXPECTED_FIELDS + assert rows == [] + + def test_export_csv_case_insensitive_extension(tmp_path): path = tmp_path / "results.CSV" export(_results(), str(path))