From a354e1a13c2e55e0eaeeb0e62ed248d1ea83273f Mon Sep 17 00:00:00 2001 From: Kinshuk Sharma Date: Sat, 16 May 2026 11:40:53 +0530 Subject: [PATCH] fix(webdataset): handle empty shards during distributed streaming During multi-machine distributed training with streaming=True, if the number of tar files is fewer than the number of ranks, some ranks receive an empty list of files. Previously, `_split_generators` hardcoded `tar_paths[0]` for feature inference, causing an IndexError on empty ranks. This commit: 1. Iterates through splits to find the first valid file for feature inference. 2. Safely defaults to an empty dictionary in `_generate_examples` if `features` remains None, preventing subsequent AttributeErrors. --- .../packaged_modules/webdataset/webdataset.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py index 153b2228a24..4e5e73bb01b 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -67,6 +67,11 @@ def _split_generators(self, dl_manager): raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") data_files = dl_manager.download(self.config.data_files) splits = [] + + # Updates: Keep track of the first available file for feature inference + first_tar_path = None + first_tar_iterator = None + for split_name, tar_paths in data_files.items(): tar_iterators = [dl_manager.iter_archive(tar_path) for tar_path in tar_paths] splits.append( @@ -74,9 +79,15 @@ def _split_generators(self, dl_manager): name=split_name, gen_kwargs={"tar_paths": tar_paths, "tar_iterators": tar_iterators} ) ) - if not self.info.features: + # Updates: Save the first valid file we find across all splits + if tar_paths and first_tar_path is None: + first_tar_path = tar_paths[0] + first_tar_iterator = tar_iterators[0] + + # Updates: Only attempt feature inference if we actually have a file + if not self.info.features and first_tar_path is not None: # Get one example to get the feature types - pipeline = self._get_pipeline_from_tar(tar_paths[0], tar_iterators[0]) + pipeline = self._get_pipeline_from_tar(first_tar_path, first_tar_iterator) first_examples = list(islice(pipeline, self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE)) if any(example.keys() != first_examples[0].keys() for example in first_examples): raise ValueError( @@ -109,13 +120,17 @@ def _generate_shards(self, tar_paths, tar_iterators): yield from tar_paths def _generate_examples(self, tar_paths, tar_iterators): + # NEW: Safely default to an empty dictionary if features are None + features = self.info.features or {} + image_field_names = [ - field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Image) + field_name for field_name, feature in features.items() if isinstance(feature, datasets.Image) ] audio_field_names = [ - field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio) + field_name for field_name, feature in features.items() if isinstance(feature, datasets.Audio) ] - all_field_names = list(self.info.features.keys()) + all_field_names = list(features.keys()) + for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)): for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)): for field_name in all_field_names: