Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,11 +1564,15 @@ def load_dataset(
Verification mode determining the checks to run on the downloaded/processed dataset information (checksums/size/splits/...).

<Added version="2.9.1"/>
keep_in_memory (`bool`, defaults to `None`):
keep_in_memory (`bool`, defaults to `None`):
Whether to copy the dataset in-memory. If `None`, the dataset
will not be copied in-memory unless explicitly enabled by setting `datasets.config.IN_MEMORY_MAX_SIZE` to
nonzero. See more details in the [improve performance](../cache#improve-performance) section.
revision ([`Version`] or `str`, *optional*):

.. warning::
Setting ``keep_in_memory=True`` in combination with PyTorch DataLoader multiprocessing (``num_workers > 0``) causes severe memory bloat. Python's copy-on-write semantics will duplicate the fully materialized dataset across all worker processes. For multi-GPU training on large datasets, rely on the default memory-mapping (``keep_in_memory=False``) or utilize IterableDatasets via ``streaming=True``.

revision ([`Version`] or `str`, *optional*):
Version of the dataset to load.
As datasets have their own git repository on the Datasets Hub, the default version "main" corresponds to their "main" branch.
You can specify a different version than the default "main" by using a commit SHA or a git tag of the dataset repository.
Expand Down
23 changes: 22 additions & 1 deletion src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ class JsonConfig(datasets.BuilderConfig):
newlines_in_values: Optional[bool] = None
on_mixed_types: Optional[Literal["use_json"]] = "use_json"
parse_agent_traces: bool = True
return_file_name: bool = False

def __post_init__(self):
super().__post_init__()
if self.return_file_name and self.features is not None and "file_name" not in self.features:
self.features = self.features.copy()
self.features["file_name"] = datasets.Value("string")


class Json(datasets.ArrowBasedBuilder):
Expand All @@ -70,7 +74,12 @@ def _info(self):
)
if self.config.newlines_in_values is not None:
raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported")
return datasets.DatasetInfo(features=self.config.features)
features = self.config.features
if features is not None and self.config.return_file_name and "file_name" not in features:
features = features.copy()
features["file_name"] = datasets.Value("string")
self.config.features = features
return datasets.DatasetInfo(features=features)

def _split_generators(self, dl_manager):
"""We handle string, list and dicts in datafiles"""
Expand Down Expand Up @@ -159,6 +168,9 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu
if df.columns.tolist() == [0]:
df.columns = list(self.config.features) if self.config.features else ["text"]
pa_table = pa.Table.from_pandas(df, preserve_index=False)
if self.config.return_file_name:
file_name_col = pa.array([file] * len(pa_table), type=pa.string())
pa_table = pa_table.append_column("file_name", file_name_col)
yield Key(shard_idx, 0), self._cast_table(pa_table)

# If the files are agent traces (one row = one file)
Expand All @@ -181,6 +193,9 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu
"file_path": [file_path],
}
)
if self.config.return_file_name:
file_name_col = pa.array([file] * len(pa_table), type=pa.string())
pa_table = pa_table.append_column("file_name", file_name_col)
yield Key(shard_idx, 0), self._cast_table(pa_table)

# If the file has one json object per line
Expand Down Expand Up @@ -292,8 +307,14 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu
raise ValueError(
f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
) from None
if self.config.return_file_name:
file_name_col = pa.array([file] * len(pa_table), type=pa.string())
pa_table = pa_table.append_column("file_name", file_name_col)
yield Key(shard_idx, 0), self._cast_table(pa_table)
break
if self.config.return_file_name:
file_name_col = pa.array([file] * len(pa_table), type=pa.string())
pa_table = pa_table.append_column("file_name", file_name_col)
yield (
Key(shard_idx, batch_idx),
self._cast_table(pa_table, json_field_paths=json_field_paths),
Expand Down