diff --git a/docs/source/about_dataset_features.mdx b/docs/source/about_dataset_features.mdx index 3df1c5773cf..968ae4aa285 100644 --- a/docs/source/about_dataset_features.mdx +++ b/docs/source/about_dataset_features.mdx @@ -134,6 +134,33 @@ And in this case the numpy arrays are encoded into PNG (or TIFF if the pixels va For multi-channels arrays like RGB or RGBA, only uint8 is supported. If you use a larger precision, you get a warning and the array is downcasted to uint8. For gray-scale images you can use the integer or float precision you want as long as it is compatible with `Pillow`. A warning is shown if your image integer or float precision is too high, and in this case the array is downcated: an int64 array is downcasted to int32, and a float64 array is downcasted to float32. +## Mesh feature + +Mesh datasets have a column with type [`Mesh`], which loads 3D mesh files with `trimesh`. + +When you load a mesh dataset and call the mesh column, the [`Mesh`] feature automatically decodes the mesh file: + +```py +>>> from datasets import Dataset, Features, Mesh + +>>> dataset = Dataset.from_dict({"mesh": ["path/to/model.glb"]}, features=Features({"mesh": Mesh()})) +>>> dataset[0]["mesh"] + +``` + +Depending on the file content, `trimesh` may return a `trimesh.Trimesh` object or a `trimesh.Scene` object. GLB files commonly decode to scenes, while STL and PLY files commonly decode to meshes. + +With `decode=False`, the [`Mesh`] type gives you the path or bytes of the mesh file without decoding it with `trimesh`: + +```py +>>> dataset = dataset.cast_column("mesh", Mesh(decode=False)) +>>> dataset[0]["mesh"] +{'bytes': None, + 'path': 'path/to/model.glb'} +``` + +For embedded bytes, the stored `path` is used to infer the mesh file type. + ## Json feature Datasets are based on Arrow which is a columnar format, and therefore they expect every example to have the same type and subtypes, and dictionaries to have the same keys and values types. diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index 968364ff176..5d3ce6588b5 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -277,6 +277,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable [[autodoc]] datasets.Video +### Mesh + +[[autodoc]] datasets.Mesh + ### Json [[autodoc]] datasets.Json diff --git a/setup.py b/setup.py index 4a6b4b13787..fd7ecc31508 100644 --- a/setup.py +++ b/setup.py @@ -145,6 +145,10 @@ "Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced ] +MESH_REQUIRE = [ + "trimesh>=4.10.0", +] + BENCHMARKS_REQUIRE = [ "tensorflow==2.12.0", "torch==2.0.1", @@ -188,6 +192,7 @@ "Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced "torchcodec>=0.7.0; python_version < '3.14'", # minium version to get windows support, torchcodec doesn't have wheels for 3.14 yet "nibabel>=5.3.1", + "trimesh>=4.10.0", ] NUMPY2_INCOMPATIBLE_LIBRARIES = [ @@ -214,6 +219,7 @@ EXTRAS_REQUIRE = { "audio": AUDIO_REQUIRE, "vision": VISION_REQUIRE, + "mesh": MESH_REQUIRE, "tensorflow": [ "tensorflow>=2.6.0", ], diff --git a/src/datasets/config.py b/src/datasets/config.py index 5f5212d9102..fc16dbe713a 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -140,6 +140,7 @@ TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None PDFPLUMBER_AVAILABLE = importlib.util.find_spec("pdfplumber") is not None NIBABEL_AVAILABLE = importlib.util.find_spec("nibabel") is not None +TRIMESH_AVAILABLE = importlib.util.find_spec("trimesh") is not None # Optional compression tools RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index d2c780f788c..163aaa0e2c5 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -12,6 +12,7 @@ "Sequence", "Value", "Image", + "Mesh", "Translation", "TranslationVariableLanguages", "Video", @@ -21,6 +22,7 @@ from .audio import Audio from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, Json, LargeList, List, Sequence, Value from .image import Image +from .mesh import Mesh from .nifti import Nifti from .pdf import Pdf from .translation import Translation, TranslationVariableLanguages diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 0a67d9c31d5..b377a3acdd0 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -43,6 +43,7 @@ from ..utils.py_utils import asdict, first_non_null_value, zip_dict from .audio import Audio from .image import Image, encode_pil_image +from .mesh import Mesh from .nifti import Nifti, encode_nibabel_image from .pdf import Pdf, encode_pdfplumber_pdf from .translation import Translation, TranslationVariableLanguages @@ -1361,6 +1362,7 @@ def __repr__(self): Array5D, Audio, Image, + Mesh, Video, Pdf, Nifti, @@ -1522,6 +1524,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[dict[str, Uni Array5D.__name__: Array5D, Audio.__name__: Audio, Image.__name__: Image, + Mesh.__name__: Mesh, Video.__name__: Video, Pdf.__name__: Pdf, Nifti.__name__: Nifti, diff --git a/src/datasets/features/mesh.py b/src/datasets/features/mesh.py new file mode 100644 index 00000000000..8f674febbe3 --- /dev/null +++ b/src/datasets/features/mesh.py @@ -0,0 +1,271 @@ +import os +from dataclasses import dataclass, field +from io import BytesIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union + +import pyarrow as pa + +from .. import config +from ..download.download_config import DownloadConfig +from ..table import array_cast +from ..utils.file_utils import is_local_path, xopen +from ..utils.py_utils import string_to_dict + + +if TYPE_CHECKING: + import trimesh + + from .features import FeatureType + + +@dataclass +class Mesh: + """Mesh [`Feature`] to read 3D mesh data from a file. + + Input: The Mesh feature accepts as input: + - A `str`: Absolute path to the mesh file (i.e. random access is allowed). + - A `pathlib.Path`: path to the mesh file (i.e. random access is allowed). + - A `dict` with the keys: + + - `path`: String with relative path of the mesh file to the archive file. + - `bytes`: Bytes of the mesh file. + + This is useful for parquet or webdataset files which embed mesh files. + + - A `trimesh.Trimesh` or `trimesh.Scene`: 3D mesh or scene object. + + Output: The Mesh feature outputs data as `trimesh.Trimesh` or `trimesh.Scene` objects. + + Args: + decode (`bool`, defaults to `True`): + Whether to decode the mesh data. If `False`, + returns the underlying dictionary in the format `{"path": mesh_path, "bytes": mesh_bytes}`. + Mesh decoding uses `trimesh` and supports `.glb`, `.ply`, and `.stl` files. + """ + + decode: bool = True + id: Optional[str] = field(default=None, repr=False) + # Automatically constructed + dtype: ClassVar[str] = "trimesh.Trimesh" + pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()}) + _type: str = field(default="Mesh", init=False, repr=False) + + def __call__(self): + return self.pa_type + + def encode_example(self, value: Union[str, bytes, bytearray, dict, "trimesh.Trimesh", "trimesh.Scene"]) -> dict: + """Encode example into a format for Arrow. + + Args: + value (`str`, `bytes`, `dict`, `trimesh.Trimesh`, or `trimesh.Scene`): + Data passed as input to Mesh feature. + + Returns: + `dict` with "path" and "bytes" fields + """ + if config.TRIMESH_AVAILABLE: + import trimesh + else: + trimesh = None + + if isinstance(value, str): + return {"path": value, "bytes": None} + elif isinstance(value, Path): + return {"path": str(value.absolute()), "bytes": None} + elif isinstance(value, (bytes, bytearray)): + return {"path": None, "bytes": value} + elif trimesh is not None and isinstance(value, (trimesh.Trimesh, trimesh.Scene)): + return encode_trimesh_mesh(value) + elif isinstance(value, dict) and value.get("path") is not None and os.path.isfile(value["path"]): + # we set "bytes": None to not duplicate the data if they're already available locally + return {"bytes": None, "path": value.get("path")} + elif isinstance(value, dict) and (value.get("bytes") is not None or value.get("path") is not None): + # store the mesh bytes, and path is used to infer the mesh format using the file extension + return {"bytes": value.get("bytes"), "path": value.get("path")} + else: + raise ValueError( + f"A mesh sample should have one of 'path' or 'bytes' but they are missing or None in {value}." + ) + + def decode_example(self, value: dict, token_per_repo_id=None) -> Union["trimesh.Trimesh", "trimesh.Scene"]: + """Decode example mesh file. + + Args: + value (`dict`): + A dictionary with keys: + + - `path`: String with absolute or relative mesh file path. + - `bytes`: The bytes of the mesh file. + token_per_repo_id (`dict`, *optional*): + To access and decode + mesh files from private repositories on the Hub, you can pass + a dictionary repo_id (`str`) -> token (`bool` or `str`). + + Returns: + `trimesh.Trimesh` or `trimesh.Scene` + """ + if not self.decode: + raise RuntimeError("Decoding is disabled for this feature. Please use Mesh(decode=True) instead.") + + if config.TRIMESH_AVAILABLE: + import trimesh + else: + raise ImportError("To support decoding meshes, please install 'trimesh'.") + + if token_per_repo_id is None: + token_per_repo_id = {} + + path, bytes_ = value["path"], value["bytes"] + if bytes_ is None: + if path is None: + raise ValueError(f"A mesh should have one of 'path' or 'bytes' but both are None in {value}.") + if is_local_path(path): + file_type = _infer_mesh_file_type(path) + if file_type is None: + raise ValueError("A mesh path should have a .glb, .ply, or .stl extension.") + return trimesh.load(path, file_type=file_type) + source_url = path.split("::")[-1] + pattern = ( + config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL + ) + source_url_fields = string_to_dict(source_url, pattern) + token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None + download_config = DownloadConfig(token=token) + with xopen(path, "rb", download_config=download_config) as f: + bytes_ = f.read() + + file_type = _infer_mesh_file_type(path) + if file_type is None: + raise ValueError( + "Decoding mesh bytes requires a 'path' value with a .glb, .ply, or .stl extension " + "to infer the mesh file type." + ) + return trimesh.load(BytesIO(bytes_), file_type=file_type) + + def flatten(self) -> Union["FeatureType", dict[str, "FeatureType"]]: + """If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary.""" + from .features import Value + + return ( + self + if self.decode + else { + "bytes": Value("binary"), + "path": Value("string"), + } + ) + + def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.StructArray: + """Cast an Arrow array to the Mesh arrow storage type. + The Arrow types that can be converted to the Mesh pyarrow storage type are: + + - `pa.string()` - it must contain the "path" data + - `pa.large_string()` - it must contain the "path" data (will be cast to string if possible) + - `pa.binary()` - it must contain the mesh bytes + - `pa.struct({"bytes": pa.binary()})` + - `pa.struct({"path": pa.string()})` + - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter + + Args: + storage (`Union[pa.StringArray, pa.StructArray]`): + PyArrow array to cast. + + Returns: + `pa.StructArray`: Array in the Mesh arrow storage type, that is + `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. + """ + if pa.types.is_large_string(storage.type): + try: + storage = storage.cast(pa.string()) + except pa.ArrowInvalid as e: + raise ValueError( + f"Failed to cast large_string to string for Mesh feature. " + f"This can happen if string values exceed 2GB. " + f"Original error: {e}" + ) from e + if pa.types.is_string(storage.type): + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_large_binary(storage.type): + storage = array_cast( + storage, pa.binary() + ) # this can fail in case of big meshes, paths should be used instead + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_binary(storage.type): + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null()) + + return array_cast(storage, self.pa_type) + + def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray: + """Embed mesh files into the Arrow array. + + Args: + storage (`pa.StructArray`): + PyArrow array to embed. + + Returns: + `pa.StructArray`: Array in the Mesh arrow storage type, that is + `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. + """ + if token_per_repo_id is None: + token_per_repo_id = {} + + def path_to_bytes(path): + if path is None: + return None + source_url = path.split("::")[-1] + pattern = ( + config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL + ) + source_url_fields = string_to_dict(source_url, pattern) + token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None + download_config = DownloadConfig(token=token) + with xopen(path, "rb", download_config=download_config) as f: + return f.read() + + bytes_array = pa.array( + [ + (path_to_bytes(x["path"]) if x["bytes"] is None else x["bytes"]) if x is not None else None + for x in storage.to_pylist() + ], + type=pa.binary(), + ) + path_array = pa.array( + [os.path.basename(path) if path is not None else None for path in storage.field("path").to_pylist()], + type=pa.string(), + ) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()) + return array_cast(storage, self.pa_type) + + +def _infer_mesh_file_type(path: Optional[str]) -> Optional[str]: + supported_file_types = {"glb", "ply", "stl"} + if path is None: + return None + path_without_archive = path.split("::", 1)[0] + path_without_query = path_without_archive.split("?", 1)[0] + extension = os.path.splitext(path_without_query)[1].lower().lstrip(".") + return extension if extension in supported_file_types else None + + +def encode_trimesh_mesh(mesh: Union["trimesh.Trimesh", "trimesh.Scene"]) -> dict[str, Optional[bytes | str]]: + """Encode a trimesh mesh or scene object into GLB bytes.""" + metadata = getattr(mesh, "metadata", None) or {} + path = metadata.get("file_path") or metadata.get("file_name") if isinstance(metadata, dict) else None + if path is not None and os.path.isfile(path): + return {"path": path, "bytes": None} + bytes_ = mesh.export(file_type="glb") + return {"path": os.path.basename(path) if path else "mesh.glb", "bytes": bytes_} diff --git a/src/datasets/load.py b/src/datasets/load.py index 560bcad3a44..01b71f8ebdb 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1232,7 +1232,7 @@ def load_dataset_builder( You can find the list of datasets on the [Hub](https://huggingface.co/datasets) or with [`huggingface_hub.list_datasets`]. A dataset is a directory that contains some data files in generic formats (JSON, CSV, Parquet, etc.) and possibly - in a generic structure (Webdataset, ImageFolder, AudioFolder, VideoFolder, etc.) + in a generic structure (Webdataset, ImageFolder, AudioFolder, VideoFolder, MeshFolder, etc.) Args: @@ -1252,7 +1252,7 @@ def load_dataset_builder( e.g. `'./path/to/directory/with/my/csv/data'`. - if `path` is the name of a dataset builder and `data_files` or `data_dir` is specified - (available builders are "json", "csv", "parquet", "arrow", "text", "xml", "webdataset", "imagefolder", "audiofolder", "videofolder") + (available builders are "json", "csv", "parquet", "arrow", "text", "xml", "webdataset", "imagefolder", "audiofolder", "videofolder", "meshfolder") -> load the dataset builder from the files in `data_files` or `data_dir` e.g. `'parquet'`. @@ -1489,13 +1489,13 @@ def load_dataset( You can find the list of datasets on the [Hub](https://huggingface.co/datasets) or with [`huggingface_hub.list_datasets`]. A dataset is a directory that contains some data files in generic formats (JSON, CSV, Parquet, etc.) and possibly - in a generic structure (Webdataset, ImageFolder, AudioFolder, VideoFolder, etc.) + in a generic structure (Webdataset, ImageFolder, AudioFolder, VideoFolder, MeshFolder, etc.) This function does the following under the hood: 1. Load a dataset builder: - * Find the most common data format in the dataset and pick its associated builder (JSON, CSV, Parquet, Webdataset, ImageFolder, AudioFolder, etc.) + * Find the most common data format in the dataset and pick its associated builder (JSON, CSV, Parquet, Webdataset, ImageFolder, AudioFolder, MeshFolder, etc.) * Find which file goes into which split (e.g. train/test) based on file and directory names or on the YAML configuration * It is also possible to specify `data_files` manually, and which dataset builder to use (e.g. "parquet"). @@ -1533,7 +1533,7 @@ def load_dataset( e.g. `'./path/to/directory/with/my/csv/data'`. - if `path` is the name of a dataset builder and `data_files` or `data_dir` is specified - (available builders are "json", "csv", "parquet", "arrow", "text", "xml", "webdataset", "imagefolder", "audiofolder", "videofolder") + (available builders are "json", "csv", "parquet", "arrow", "text", "xml", "webdataset", "imagefolder", "audiofolder", "videofolder", "meshfolder") -> load the dataset from the files in `data_files` or `data_dir` e.g. `'parquet'`. diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index f0ebcb79693..5024d1fa419 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -13,6 +13,7 @@ from .imagefolder import imagefolder from .json import json from .lance import lance +from .meshfolder import meshfolder from .niftifolder import niftifolder from .pandas import pandas from .parquet import parquet @@ -48,6 +49,7 @@ def _hash_python_lines(lines: list[str]) -> str: "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())), "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())), + "meshfolder": (meshfolder.__name__, _hash_python_lines(inspect.getsource(meshfolder).splitlines())), "pdffolder": (pdffolder.__name__, _hash_python_lines(inspect.getsource(pdffolder).splitlines())), "niftifolder": (niftifolder.__name__, _hash_python_lines(inspect.getsource(niftifolder).splitlines())), "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), @@ -95,6 +97,8 @@ def _hash_python_lines(lines: list[str]) -> str: _EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext: ("meshfolder", {}) for ext in meshfolder.MeshFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext.upper(): ("meshfolder", {}) for ext in meshfolder.MeshFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("niftifolder", {}) for ext in niftifolder.NiftiFolder.EXTENSIONS}) @@ -115,6 +119,7 @@ def _hash_python_lines(lines: list[str]) -> str: _MODULE_TO_METADATA_FILE_NAMES["imagefolder"] = imagefolder.ImageFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["audiofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["videofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES +_MODULE_TO_METADATA_FILE_NAMES["meshfolder"] = meshfolder.MeshFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["pdffolder"] = imagefolder.ImageFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["niftifolder"] = imagefolder.ImageFolder.METADATA_FILENAMES diff --git a/src/datasets/packaged_modules/meshfolder/__init__.py b/src/datasets/packaged_modules/meshfolder/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/meshfolder/meshfolder.py b/src/datasets/packaged_modules/meshfolder/meshfolder.py new file mode 100644 index 00000000000..293c66f62e1 --- /dev/null +++ b/src/datasets/packaged_modules/meshfolder/meshfolder.py @@ -0,0 +1,31 @@ +import datasets + +from ..folder_based_builder import folder_based_builder + + +logger = datasets.utils.logging.get_logger(__name__) + + +class MeshFolderConfig(folder_based_builder.FolderBasedBuilderConfig): + """BuilderConfig for MeshFolder.""" + + drop_labels: bool = None + drop_metadata: bool = None + + def __post_init__(self): + super().__post_init__() + + +class MeshFolder(folder_based_builder.FolderBasedBuilder): + BASE_FEATURE = datasets.Mesh + BASE_COLUMN_NAME = "mesh" + BUILDER_CONFIG_CLASS = MeshFolderConfig + EXTENSIONS: list[str] # definition at the bottom of the script + + +MESH_EXTENSIONS = [ + ".glb", + ".ply", + ".stl", +] +MeshFolder.EXTENSIONS = MESH_EXTENSIONS diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py index 153b2228a24..4c8c21b17cd 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -22,6 +22,7 @@ class WebDataset(datasets.GeneratorBasedBuilder): IMAGE_EXTENSIONS: list[str] # definition at the bottom of the script AUDIO_EXTENSIONS: list[str] # definition at the bottom of the script VIDEO_EXTENSIONS: list[str] # definition at the bottom of the script + MESH_EXTENSIONS: list[str] # definition at the bottom of the script DECODERS: dict[str, Callable[[Any], Any]] # definition at the bottom of the script NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5 @@ -101,6 +102,9 @@ def _split_generators(self, dl_manager): # Set Video types if extension in self.VIDEO_EXTENSIONS: features[field_name] = datasets.Video() + # Set Mesh types + if extension in self.MESH_EXTENSIONS: + features[field_name] = datasets.Mesh() self.info.features = features return splits @@ -115,13 +119,19 @@ def _generate_examples(self, tar_paths, tar_iterators): audio_field_names = [ field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Audio) ] + video_field_names = [ + field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Video) + ] + mesh_field_names = [ + field_name for field_name, feature in self.info.features.items() if isinstance(feature, datasets.Mesh) + ] all_field_names = list(self.info.features.keys()) for tar_idx, (tar_path, tar_iterator) in enumerate(zip(tar_paths, tar_iterators)): for example_idx, example in enumerate(self._get_pipeline_from_tar(tar_path, tar_iterator)): for field_name in all_field_names: if field_name not in example: example[field_name] = None - for field_name in image_field_names + audio_field_names: + for field_name in image_field_names + audio_field_names + video_field_names + mesh_field_names: if example[field_name] is not None: example[field_name] = { "path": example["__key__"] + "." + field_name, @@ -278,6 +288,14 @@ def base_plus_ext(path): WebDataset.VIDEO_EXTENSIONS = VIDEO_EXTENSIONS +MESH_EXTENSIONS = [ + "glb", + "ply", + "stl", +] +WebDataset.MESH_EXTENSIONS = MESH_EXTENSIONS + + def text_loads(data: bytes): return data.decode("utf-8") diff --git a/tests/features/data/test_mesh_glb.glb b/tests/features/data/test_mesh_glb.glb new file mode 100644 index 00000000000..dbd83ca8fa5 Binary files /dev/null and b/tests/features/data/test_mesh_glb.glb differ diff --git a/tests/features/data/test_mesh_ply.ply b/tests/features/data/test_mesh_ply.ply new file mode 100644 index 00000000000..b40cbf7cdfa Binary files /dev/null and b/tests/features/data/test_mesh_ply.ply differ diff --git a/tests/features/data/test_mesh_stl.stl b/tests/features/data/test_mesh_stl.stl new file mode 100644 index 00000000000..bca4f39e90f Binary files /dev/null and b/tests/features/data/test_mesh_stl.stl differ diff --git a/tests/features/test_mesh.py b/tests/features/test_mesh.py new file mode 100644 index 00000000000..0303be77002 --- /dev/null +++ b/tests/features/test_mesh.py @@ -0,0 +1,152 @@ +import os +from pathlib import Path + +import pyarrow as pa +import pytest + +from datasets import Column, Dataset, Features, Mesh, Sequence, concatenate_datasets +from datasets.features.features import require_decoding + +from ..utils import require_trimesh + + +def test_mesh_instantiation(): + mesh = Mesh() + assert mesh.id is None + assert mesh.pa_type == pa.struct({"bytes": pa.binary(), "path": pa.string()}) + assert mesh._type == "Mesh" + + +def test_mesh_feature_type_to_arrow(): + features = Features({"mesh": Mesh()}) + assert features.arrow_schema == pa.schema({"mesh": Mesh().pa_type}) + features = Features({"struct_containing_a_mesh": {"mesh": Mesh()}}) + assert features.arrow_schema == pa.schema({"struct_containing_a_mesh": pa.struct({"mesh": Mesh().pa_type})}) + features = Features({"sequence_of_meshes": Sequence(Mesh())}) + assert features.arrow_schema == pa.schema({"sequence_of_meshes": pa.list_(Mesh().pa_type)}) + + +@pytest.mark.parametrize( + "build_example", + [ + lambda mesh_path: mesh_path, + lambda mesh_path: Path(mesh_path), + lambda mesh_path: open(mesh_path, "rb").read(), + lambda mesh_path: {"path": mesh_path}, + lambda mesh_path: {"path": mesh_path, "bytes": None}, + lambda mesh_path: {"path": mesh_path, "bytes": open(mesh_path, "rb").read()}, + lambda mesh_path: {"path": None, "bytes": open(mesh_path, "rb").read()}, + lambda mesh_path: {"bytes": open(mesh_path, "rb").read()}, + ], +) +def test_mesh_feature_encode_example(mesh_file, build_example): + mesh = Mesh() + encoded_example = mesh.encode_example(build_example(mesh_file)) + assert isinstance(encoded_example, dict) + assert encoded_example.keys() == {"bytes", "path"} + assert encoded_example["bytes"] is not None or encoded_example["path"] is not None + + +@require_trimesh +def test_mesh_decode_example(mesh_file): + import trimesh + + mesh = Mesh() + with open(mesh_file, "rb") as f: + mesh_bytes = f.read() + + decoded_example = mesh.decode_example({"path": mesh_file, "bytes": None}) + assert isinstance(decoded_example, (trimesh.Trimesh, trimesh.Scene)) + + decoded_example = mesh.decode_example({"path": mesh_file, "bytes": mesh_bytes}) + assert isinstance(decoded_example, (trimesh.Trimesh, trimesh.Scene)) + + with pytest.raises(ValueError, match="requires a 'path' value"): + mesh.decode_example({"path": None, "bytes": mesh_bytes}) + + with pytest.raises(RuntimeError): + Mesh(decode=False).decode_example({"path": mesh_file, "bytes": None}) + + +@require_trimesh +def test_dataset_with_mesh_feature(mesh_file): + import trimesh + + data = {"mesh": [mesh_file]} + features = Features({"mesh": Mesh()}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"mesh"} + assert isinstance(item["mesh"], (trimesh.Trimesh, trimesh.Scene)) + + batch = dset[:1] + assert len(batch) == 1 + assert batch.keys() == {"mesh"} + assert isinstance(batch["mesh"], list) + assert isinstance(batch["mesh"][0], (trimesh.Trimesh, trimesh.Scene)) + + column = dset["mesh"] + assert len(column) == 1 + assert isinstance(column, Column) + assert isinstance(column[0], (trimesh.Trimesh, trimesh.Scene)) + + +def test_dataset_with_mesh_feature_decode_false(mesh_file): + data = {"mesh": [mesh_file]} + features = Features({"mesh": Mesh(decode=False)}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"mesh"} + assert isinstance(item["mesh"], dict) + assert item["mesh"]["path"] == mesh_file + + +@require_trimesh +def test_dataset_cast_to_mesh_features(mesh_file): + import trimesh + + data = {"mesh": [mesh_file]} + dset = Dataset.from_dict(data) + dset = dset.cast(Features({"mesh": Mesh()})) + item = dset[0] + assert isinstance(item["mesh"], (trimesh.Trimesh, trimesh.Scene)) + + +def test_dataset_concatenate_mesh_features(mesh_file): + data1 = {"mesh": [mesh_file]} + dset1 = Dataset.from_dict(data1, features=Features({"mesh": Mesh(decode=False)})) + with open(mesh_file, "rb") as f: + data2 = {"mesh": [{"bytes": f.read()}]} + dset2 = Dataset.from_dict(data2, features=Features({"mesh": Mesh(decode=False)})) + concatenated_dataset = concatenate_datasets([dset1, dset2]) + assert len(concatenated_dataset) == 2 + assert concatenated_dataset[0]["mesh"]["path"] == dset1[0]["mesh"]["path"] + assert concatenated_dataset[1]["mesh"]["bytes"] == dset2[0]["mesh"]["bytes"] + + +@require_trimesh +def test_mesh_feature_encode_trimesh_object(): + import trimesh + + mesh = trimesh.creation.box() + encoded_example = Mesh().encode_example(mesh) + assert encoded_example.keys() == {"bytes", "path"} + assert encoded_example["path"] == "mesh.glb" + assert encoded_example["bytes"] is not None + decoded_example = Mesh().decode_example(encoded_example) + assert isinstance(decoded_example, trimesh.Scene) + + +def test_require_decoding(): + assert require_decoding(Mesh()) + + +def test_mesh_embed_storage(mesh_file): + features = Features({"mesh": Mesh()}) + + with open(mesh_file, "rb") as f: + content = f.read() + + # Test bytes are embedded + storage = features["mesh"].embed_storage(pa.array([{"path": mesh_file, "bytes": None}])) + assert storage.to_pylist() == [{"path": os.path.basename(mesh_file), "bytes": content}] diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index a784f98d93d..a3dee4b5cd2 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -648,3 +648,18 @@ def data_dir_with_hidden_files(tmp_path_factory): f.write("bar\n" * 10) return data_dir + + +@pytest.fixture(scope="session") +def mesh_file(): + return os.path.join("tests", "features", "data", "test_mesh_glb.glb") + + +@pytest.fixture(scope="session") +def mesh_file_ply(): + return os.path.join("tests", "features", "data", "test_mesh_ply.ply") + + +@pytest.fixture(scope="session") +def mesh_file_stl(): + return os.path.join("tests", "features", "data", "test_mesh_stl.stl") diff --git a/tests/packaged_modules/test_meshfolder.py b/tests/packaged_modules/test_meshfolder.py new file mode 100644 index 00000000000..b2e60faeec8 --- /dev/null +++ b/tests/packaged_modules/test_meshfolder.py @@ -0,0 +1,86 @@ +import shutil +import textwrap + +import pytest + +from datasets import ClassLabel, Features, Mesh +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, get_data_patterns +from datasets.packaged_modules.meshfolder.meshfolder import MeshFolder, MeshFolderConfig + + +@pytest.fixture +def cache_dir(tmp_path): + return str(tmp_path / "meshfolder_cache_dir") + + +@pytest.fixture +def data_files_with_labels_no_metadata(tmp_path, mesh_file): + data_dir = tmp_path / "data_files_with_labels_no_metadata" + data_dir.mkdir(parents=True, exist_ok=True) + subdir_class_0 = data_dir / "chair" + subdir_class_0.mkdir(parents=True, exist_ok=True) + subdir_class_1 = data_dir / "table" + subdir_class_1.mkdir(parents=True, exist_ok=True) + + mesh_filename = subdir_class_0 / "mesh_chair.glb" + shutil.copyfile(mesh_file, mesh_filename) + mesh_filename2 = subdir_class_1 / "mesh_table.glb" + shutil.copyfile(mesh_file, mesh_filename2) + + data_files_with_labels_no_metadata = DataFilesDict.from_patterns( + get_data_patterns(str(data_dir)), data_dir.as_posix() + ) + + return data_files_with_labels_no_metadata + + +@pytest.fixture +def mesh_file_with_metadata(tmp_path, mesh_file): + mesh_filename = tmp_path / "mesh_file.glb" + shutil.copyfile(mesh_file, mesh_filename) + mesh_metadata_filename = tmp_path / "metadata.jsonl" + mesh_metadata = textwrap.dedent( + """\ + {"file_name": "mesh_file.glb", "text": "Mesh description"} + """ + ) + with open(mesh_metadata_filename, "w", encoding="utf-8") as f: + f.write(mesh_metadata) + return str(mesh_filename), str(mesh_metadata_filename) + + +def test_meshfolder_config_and_extensions(): + # Verify extensions + assert MeshFolder.EXTENSIONS == [".glb", ".ply", ".stl"] + assert MeshFolder.BASE_FEATURE == Mesh + assert MeshFolder.BASE_COLUMN_NAME == "mesh" + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = MeshFolderConfig(name="name-with-*-invalid-character") + + +def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir): + # there are no metadata.jsonl files in this test case + meshfolder = MeshFolder(data_files=data_files_with_labels_no_metadata, cache_dir=cache_dir, drop_labels=False) + meshfolder.download_and_prepare() + assert meshfolder.info.features == Features({"mesh": Mesh(), "label": ClassLabel(names=["chair", "table"])}) + dataset = list(meshfolder.as_dataset()["train"]) + label_feature = meshfolder.info.features["label"] + + assert dataset[0]["label"] == label_feature._str2int["chair"] + assert dataset[1]["label"] == label_feature._str2int["table"] + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_data_files_with_metadata_and_single_split(streaming, cache_dir, mesh_file_with_metadata): + mesh_file, mesh_metadata_file = mesh_file_with_metadata + meshfolder = MeshFolder(data_files={"train": [mesh_file, mesh_metadata_file]}, cache_dir=cache_dir) + meshfolder.download_and_prepare() + dataset = meshfolder.as_streaming_dataset()["train"] if streaming else meshfolder.as_dataset()["train"] + + item = next(iter(dataset)) if streaming else dataset[0] + assert "mesh" in item + assert item["text"] == "Mesh description" diff --git a/tests/packaged_modules/test_webdataset.py b/tests/packaged_modules/test_webdataset.py index 33b63d91a86..7eef6db566a 100644 --- a/tests/packaged_modules/test_webdataset.py +++ b/tests/packaged_modules/test_webdataset.py @@ -295,7 +295,7 @@ def test_video_webdataset(video_wds_file): assert len(examples) == 3 assert isinstance(examples[0]["json"], dict) assert isinstance(examples[0]["json"]["caption"], str) - assert isinstance(examples[0]["mov"], bytes) + assert isinstance(examples[0]["mov"], dict) def test_webdataset_errors_on_bad_file(bad_wds_file): diff --git a/tests/utils.py b/tests/utils.py index 88bff466297..48de31a5959 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -249,6 +249,18 @@ def require_nibabel(test_case): return test_case +def require_trimesh(test_case): + """ + Decorator marking a test that requires trimesh. + + These tests are skipped when trimesh isn't installed. + + """ + if not config.TRIMESH_AVAILABLE: + test_case = unittest.skip("test requires trimesh")(test_case) + return test_case + + def require_transformers(test_case): """ Decorator marking a test that requires transformers.