Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/datasets/packaged_modules/webdataset/webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,27 @@ def _split_generators(self, dl_manager):
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
data_files = dl_manager.download(self.config.data_files)
splits = []

# Updates: Keep track of the first available file for feature inference
first_tar_path = None
first_tar_iterator = None

for split_name, tar_paths in data_files.items():
tar_iterators = [dl_manager.iter_archive(tar_path) for tar_path in tar_paths]
splits.append(
datasets.SplitGenerator(
name=split_name, gen_kwargs={"tar_paths": tar_paths, "tar_iterators": tar_iterators}
)
)
if not self.info.features:
# Updates: Save the first valid file we find across all splits
if tar_paths and first_tar_path is None:
first_tar_path = tar_paths[0]
first_tar_iterator = tar_iterators[0]

# Updates: Only attempt feature inference if we actually have a file
if not self.info.features and first_tar_path is not None:
# Get one example to get the feature types
pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0])
pipeline = self._get_pipeline_from_tar(first_tar_path, first_tar_iterator)
first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE))
if any(example.keys() != first_examples[0].keys() for example in first_examples):
raise ValueError(
Expand Down Expand Up @@ -109,13 +120,17 @@ def _generate_shards(self, tar_paths, tar_iterators):
yield from tar_paths

def _generate_examples(self, tar_paths, tar_iterators):
# NEW: Safely default to an empty dictionary if features are None
features = self.info.features or {}

image_field_names = [
field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Image)
field_name for field_name, feature in features.items() if isinstance(feature, datasets.Image)
]
audio_field_names = [
field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio)
field_name for field_name, feature in features.items() if isinstance(feature, datasets.Audio)
]
all_field_names = list(self.info.features.keys())
all_field_names = list(features.keys())

for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)):
for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)):
for field_name in all_field_names:
Expand Down
Loading