diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 86f68a90d79..4db04d0165a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -121,6 +121,8 @@ jobs: run: pip install --upgrade uv - name: Install dependencies run: uv pip install --system "datasets[tests] @ ." + - name: Install tsfile (py3.14 only) + run: uv pip install --system "tsfile>=2.3.0" - name: Print dependencies run: uv pip list - name: Test with pytest diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index cc6b7195fe2..b1e93647ed8 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -101,6 +101,10 @@ - local: tabular_load title: Load tabular data title: "Tabular" + - sections: + - local: tsfile_load + title: Load TsFile data + title: "Time-series" - sections: - local: share title: Share diff --git a/docs/source/about_dataset_load.mdx b/docs/source/about_dataset_load.mdx index a5ac45077e6..dac34dacc11 100644 --- a/docs/source/about_dataset_load.mdx +++ b/docs/source/about_dataset_load.mdx @@ -14,7 +14,7 @@ A dataset is a directory that contains: The [`load_dataset`] function fetches the requested dataset locally or from the Hugging Face Hub. The Hub is a central repository where all the Hugging Face datasets and models are stored. -If the dataset only contains data files, then [`load_dataset`] automatically infers how to load the data files from their extensions (json, csv, parquet, txt, etc.). +If the dataset only contains data files, then [`load_dataset`] automatically infers how to load the data files from their extensions (json, csv, parquet, tsfile, txt, etc.). Under the hood, 🤗 Datasets will use an appropriate [`DatasetBuilder`] based on the data files format. 
There exist one builder per data file format in 🤗 Datasets: * [`datasets.packaged_modules.text.Text`] for text @@ -23,6 +23,7 @@ Under the hood, 🤗 Datasets will use an appropriate [`DatasetBuilder`] based o * [`datasets.packaged_modules.parquet.Parquet`] for Parquet * [`datasets.packaged_modules.arrow.Arrow`] for Arrow (streaming file format) * [`datasets.packaged_modules.sql.Sql`] for SQL databases +* [`datasets.packaged_modules.tsfile.TsFile`] for TsFile (time-series data) * [`datasets.packaged_modules.imagefolder.ImageFolder`] for image folders * [`datasets.packaged_modules.audiofolder.AudioFolder`] for audio folders diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index d18a33ac071..6c2bcaf7367 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -68,7 +68,7 @@ The `split` parameter can also map a data file to a specific split: ## Local and remote files -Datasets can be loaded from local files stored on your computer and from remote files. The datasets are most likely stored as a `csv`, `json`, `txt` or `parquet` file. The [`load_dataset`] function can load each of these file types. +Datasets can be loaded from local files stored on your computer and from remote files. The datasets are most likely stored as a `csv`, `json`, `txt`, `parquet` or `tsfile` file. The [`load_dataset`] function can load each of these file types. ### CSV @@ -200,6 +200,34 @@ This will return the image caption and the image bytes in a single request. Note that the HDF5 loader assumes that the file has "tabular" structure, i.e. that all datasets in the file have (the same number of) rows on their first dimension. +### TsFile + +[TsFile](https://tsfile.apache.org/) is a columnar file format designed for time-series data, used as the native storage layer of [Apache IoTDB](https://iotdb.apache.org/). It natively represents timestamps, device tags, and measurement fields, and maintains an internal time index that enables efficient time-range pruning. 
+ +Each row in the resulting dataset corresponds to one **device** (identified by its TAG columns); the `time` column and every FIELD column are list columns containing that device's full time series, sorted in ascending time order. + +To load a TsFile: + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset("tsfile", data_files="my_data.tsfile") +``` + +Filter by time range — bounds are pushed down to TsFile's internal time index and accept `int` epochs, `datetime`, `date`, ISO-8601 strings, or `pyarrow` timestamp scalars: + +```py +>>> from datetime import datetime +>>> dataset = load_dataset( +... "tsfile", +... data_files="my_data.tsfile", +... start_time=datetime(2023, 11, 14), +... end_time=datetime(2023, 11, 15), +... ) +``` + +> [!TIP] +> For more details, check out the [how to load TsFile data](tsfile_load) guide. + ### SQL Read database contents with [`~datasets.Dataset.from_sql`] by specifying the URI to connect to your database. You can read both table names and queries: diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx index 4792d1b88f7..4dc1ffa5d52 100644 --- a/docs/source/package_reference/loading_methods.mdx +++ b/docs/source/package_reference/loading_methods.mdx @@ -97,6 +97,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t") [[autodoc]] datasets.packaged_modules.hdf5.HDF5 +### TsFile + +[[autodoc]] datasets.packaged_modules.tsfile.TsFileConfig + +[[autodoc]] datasets.packaged_modules.tsfile.TsFile + ### Pdf [[autodoc]] datasets.packaged_modules.pdffolder.PdfFolderConfig diff --git a/docs/source/tsfile_load.mdx b/docs/source/tsfile_load.mdx new file mode 100644 index 00000000000..6c37eae29fd --- /dev/null +++ b/docs/source/tsfile_load.mdx @@ -0,0 +1,172 @@ +# Load TsFile data + +[TsFile](https://tsfile.apache.org/) is a columnar file format designed for time-series data and used as the native storage layer of [Apache 
IoTDB](https://iotdb.apache.org/). Compared with general-purpose columnar formats such as Parquet, TsFile is aware of the time-series data model (timestamps, devices, and measurements) and maintains an internal time index that enables time-range pruning without scanning entire files. + +This loader is provided as a separate guide because it does not follow the usual one-row-per-record tabular convention: each output row corresponds to one *device*, and per-measurement values are returned as Arrow `list<...>` columns. The mapping is described in detail below. + +## Installation + +The loader depends on the [`tsfile`](https://pypi.org/project/tsfile/) Python package: + +```bash +pip install "tsfile>=2.3.0" +``` + +## Data model and output layout + +The loader follows the TsFile *table model*. Each table column is one of: + +- **TAG** — a string-typed identifier; the tuple of TAG values uniquely identifies a *device* (i.e. a single time-series source). +- **FIELD** — a measurement whose value evolves over time. +- **TIME** — the timestamp column, named `time` by default. + +The loader emits one dataset row per device. Within a row, the `time` column and every FIELD column are Arrow `list<...>` columns containing that device's full time series, sorted in ascending time order. TAG columns appear as scalar `string` columns. + +Concretely, the output schema has the form: + +```text +<tag_1>: string +<tag_2>: string # one column per TAG +... +time: list<timestamp[unit, tz]> +<field_1>: list<dtype> # one column per FIELD +<field_2>: list<dtype> +... +``` + +When the same device appears in multiple input files of a split, its per-file chunks are concatenated and sorted by timestamp before being emitted as a single row. Duplicate timestamps for the same device raise `ValueError`. + +## Basic usage + +Load a single TsFile: + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset("tsfile", data_files="my_data.tsfile") +``` + +Map files to splits explicitly: + +```py +>>> dataset = load_dataset( +... "tsfile", +... 
data_files={"train": "train_data.tsfile", "test": "test_data.tsfile"}, +... ) +``` + +## Example dataset on the Hub + +A ready-to-use example is available at [`tsfile/lotsa_data`](https://huggingface.co/datasets/tsfile/lotsa_data). Because `.tsfile` files are recognized automatically, you can load it by repository id without specifying `data_files`: + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset("tsfile/lotsa_data") +>>> dataset +DatasetDict({ + train: Dataset({ + features: ['timeseries_id', 'time', 'value'], + num_rows: 91 + }) +}) +``` + +Each row is one device. The TAG column `timeseries_id` identifies the device, while `time` and `value` are `list<...>` columns holding that device's full series: + +```py +>>> row = dataset["train"][0] +>>> row["timeseries_id"] +'Bear_assembly_Angel' +>>> len(row["time"]), len(row["value"]) +(8760, 8760) +>>> row["time"][:3] +[datetime.datetime(2017, 1, 1, 0, 0), datetime.datetime(2017, 1, 1, 1, 0), datetime.datetime(2017, 1, 1, 2, 0)] +``` + +## Selecting a table + +A TsFile can contain multiple tables. When `table_name` is omitted, the first table found in the first valid file is used. Lookups are case-insensitive. + +```py +>>> dataset = load_dataset("tsfile", data_files="my_data.tsfile", table_name="sensor_data") +``` + +## Selecting columns + +`columns` restricts the FIELD columns that are read. The TAG columns and the `time` column are always returned because they identify the device and its timeline. Names in `columns` that refer to a TAG or to the `time` column are silently ignored (they are emitted as usual, just once); names that match a field absent from every file become all-null list columns. + +```py +>>> dataset = load_dataset( +... "tsfile", +... data_files="my_data.tsfile", +... columns=["temperature", "humidity"], +... ) +``` + +## Filtering by time range + +`start_time` and `end_time` are inclusive bounds; either may be omitted. 
The bounds are pushed down to TsFile's internal time index, so only the matching data blocks are read from disk. Both bounds accept any of: + +- `int` — raw epoch in `timestamp_unit` (default milliseconds); +- `datetime.datetime` — naive values are interpreted as UTC, tz-aware values are converted to UTC; +- `datetime.date`; +- ISO-8601 `str`, e.g. `"2024-01-01T00:00:00"`; +- `pyarrow.TimestampScalar`. + +```py +>>> from datetime import datetime +>>> dataset = load_dataset( +... "tsfile", +... data_files="my_data.tsfile", +... start_time=datetime(2023, 11, 14), +... end_time="2023-11-15T00:00:00", +... ) +``` + +## Schema evolution across files + +When different files expose different columns — for example a new sensor field is introduced later — the loader takes the union of all FIELD columns and fills missing values with nulls. Numeric FIELD types are promoted following IoTDB's widening rules (`INT32 → INT64 → DOUBLE`, `INT32 → FLOAT → DOUBLE`). + +```py +>>> dataset = load_dataset("tsfile", data_files=["day1.tsfile", "day2.tsfile"]) +``` + +## Handling unreadable files + +By default, an unreadable or non-TsFile input raises an error. Set `on_bad_files` to `"warn"` to log and continue, or `"skip"` to silently drop the file. + +```py +>>> dataset = load_dataset("tsfile", data_files="data/*.tsfile", on_bad_files="skip") +``` + +## Timestamp unit and time zone + +`timestamp_unit` (default `"ms"`, matching IoTDB) controls the resolution of the `time` column and the interpretation of integer time bounds. `timestamp_tz` attaches a time zone to the Arrow timestamp type; `None` (the default) yields a timezone-naive type. + +```py +>>> dataset = load_dataset( +... "tsfile", +... data_files="my_data.tsfile", +... timestamp_unit="us", +... timestamp_tz="UTC", +... 
) +``` + +## Memory and batching + +Two parameters control memory usage: + +- `input_batch_size` (default `65_536`) — maximum number of rows fetched per Arrow batch from `TsFileReader.query_table`. Bounds peak memory while streaming a single device. +- `output_batch_size` (default `32`) — number of devices packed into each Arrow record batch yielded to the writer. Smaller values give more responsive progress reporting; larger values reduce per-batch overhead. + +```py +>>> dataset = load_dataset( +... "tsfile", +... data_files="large_data.tsfile", +... input_batch_size=32_768, +... output_batch_size=128, +... ) +``` + +Peak memory is bounded by the payload of a single device across the split, not by the size of the split as a whole. + +See [`~datasets.packaged_modules.tsfile.TsFileConfig`] for the full list of parameters. diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 7ff455cc0da..0de63b9c785 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -22,6 +22,7 @@ from .pdffolder import pdffolder from .sql import sql from .text import text +from .tsfile import tsfile from .videofolder import videofolder from .webdataset import webdataset from .xml import xml @@ -60,6 +61,7 @@ def _hash_python_lines(lines: list[str]) -> str: "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())), "eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())), "lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())), + "tsfile": (tsfile.__name__, _hash_python_lines(inspect.getsource(tsfile).splitlines())), "iceberg": (iceberg.__name__, _hash_python_lines(inspect.getsource(iceberg).splitlines())), } @@ -96,6 +98,7 @@ def _hash_python_lines(lines: list[str]) -> str: ".h5": ("hdf5", {}), ".eval": ("eval", {}), ".lance": ("lance", {}), + ".tsfile": ("tsfile", {}), } _EXTENSION_TO_MODULE.update({ext: 
("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) diff --git a/src/datasets/packaged_modules/tsfile/__init__.py b/src/datasets/packaged_modules/tsfile/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/tsfile/tsfile.py b/src/datasets/packaged_modules/tsfile/tsfile.py new file mode 100644 index 00000000000..96712b4516e --- /dev/null +++ b/src/datasets/packaged_modules/tsfile/tsfile.py @@ -0,0 +1,773 @@ +"""TsFile (table model) packaged builder β€” per-device wide format. + +Each output row corresponds to a single device (identified by its TAG values). +The ``time`` column and every FIELD column are Arrow ``list<...>`` columns +holding the entire time series for that device. When the same device is +present in multiple TsFiles within a split, its data is merged across files +and the resulting lists are sorted in ascending time order. + +Output schema layout:: + + : string + : string (one column per TAG) + ... + time: list + : list (one column per FIELD) + : list + ... + +Reading model +------------- +Data is fetched **per device** via ``TsFileReader.query_table`` with a +push-down ``tag_filter``. For each split the builder: + +1. Opens every input file once, calls ``get_all_devices`` to enumerate the + ``(tag-tuple) β†’ [files]`` index across all shards. +2. Iterates the index in stable order. For each device, streams Arrow + batches from every contributing file, concatenates and sorts by time, + and emits one wide row. + +Peak memory is bounded by **one device's** total payload across the split, +not by the split's total size. 
+""" + +from __future__ import annotations + +import datetime as _dt +from dataclasses import dataclass +from typing import Any, Literal, Optional + +import numpy as np +import pyarrow as pa + +import datasets +from datasets.builder import Key +from datasets.table import table_cast +from datasets.utils.tqdm import tqdm + + +logger = datasets.utils.logging.get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Type helpers +# --------------------------------------------------------------------------- + + +def _arrow_type(ts_dtype, *, unit: str, tz: Optional[str]) -> pa.DataType: + """Map a tsfile ``TSDataType`` to its Arrow representation.""" + from tsfile.constants import TSDataType + + return { + TSDataType.BOOLEAN: pa.bool_(), + TSDataType.INT32: pa.int32(), + TSDataType.INT64: pa.int64(), + TSDataType.FLOAT: pa.float32(), + TSDataType.DOUBLE: pa.float64(), + TSDataType.TEXT: pa.string(), + TSDataType.STRING: pa.string(), + TSDataType.TIMESTAMP: pa.timestamp(unit, tz=tz), + TSDataType.DATE: pa.date32(), + TSDataType.BLOB: pa.binary(), + }.get(ts_dtype, pa.string()) + + +def _promote_tsdatatype(a, b): + """Return the widest of two ``TSDataType`` values. + + Mirrors IoTDB's ``ALTER COLUMN ... SET DATA TYPE`` rules: + + - ``INT32 β†’ INT64 β†’ DOUBLE`` + - ``INT32 β†’ FLOAT β†’ DOUBLE`` + + ``INT64`` and ``FLOAT`` cannot widen losslessly into either, so their + join is ``DOUBLE``. Non-numeric or otherwise unrelated pairs raise. 
+ """ + if a == b: + return a + + from tsfile.constants import TSDataType + + table = { + (TSDataType.INT32, TSDataType.INT64): TSDataType.INT64, + (TSDataType.INT32, TSDataType.FLOAT): TSDataType.FLOAT, + (TSDataType.INT32, TSDataType.DOUBLE): TSDataType.DOUBLE, + (TSDataType.INT64, TSDataType.FLOAT): TSDataType.DOUBLE, + (TSDataType.INT64, TSDataType.DOUBLE): TSDataType.DOUBLE, + (TSDataType.FLOAT, TSDataType.DOUBLE): TSDataType.DOUBLE, + } + if (a, b) in table: + return table[(a, b)] + if (b, a) in table: + return table[(b, a)] + raise ValueError( + f"Incompatible column types across files: {a.name} vs {b.name}. " + "Only numeric widening (INT32β†’INT64β†’DOUBLE, INT32β†’FLOATβ†’DOUBLE) is supported." + ) + + +def _to_epoch(value: Any, unit: str) -> int: + """Coerce a timestamp boundary to an integer epoch in ``unit``. + + Accepts ``int`` (raw epoch in ``unit``), ``datetime``/``date``, + ISO-8601 ``str``, or any ``pa.Scalar`` of timestamp type. + """ + if isinstance(value, bool): # bool is a subclass of int; reject explicitly + raise TypeError(f"start_time/end_time must be a timestamp, got bool: {value!r}") + if isinstance(value, int): + return value + try: + # Normalize the various input shapes into something pa.scalar() can absorb + # under a `timestamp[unit]` target type. 
+ if isinstance(value, _dt.datetime): + if value.tzinfo is not None: + value = value.astimezone(_dt.timezone.utc).replace(tzinfo=None) + elif isinstance(value, _dt.date): + value = _dt.datetime(value.year, value.month, value.day) + elif isinstance(value, str): + value = _dt.datetime.fromisoformat(value) + return pa.scalar(value, type=pa.timestamp(unit)).value + except (pa.ArrowInvalid, pa.ArrowTypeError, TypeError, ValueError) as e: + raise TypeError( + f"start_time/end_time must be a datetime, date, pa.TimestampScalar, " + f"ISO-8601 str, or int epoch; got {type(value).__name__}: {value!r}" + ) from e + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +@dataclass +class TsFileConfig(datasets.BuilderConfig): + """BuilderConfig for TsFile (table model) β€” per-device wide format. + + Args: + table_name (`str`, *optional*): + Name of the table to read. When unset, the first table found in + the first valid file is used. Lookups are case-insensitive. + columns (`list[str]`, *optional*): + Subset of FIELD columns to keep. TAG columns and the TIME column + are *always* returned (they identify the device / its timeline + and cannot be excluded). Names that refer to TAG or TIME columns, + or to fields absent from every file, resolve quietly: TAGs/TIME + are emitted as usual, and never-seen fields become all-null list + columns. When unset, all FIELDs are returned. + start_time, end_time (`datetime`, `date`, `pa.TimestampScalar`, ISO-8601 `str`, or `int`, *optional*): + Inclusive timestamp range. Either bound may be omitted. + ``datetime`` values are taken in their own tz (UTC if naive); + ``int`` is interpreted as a raw epoch in ``timestamp_unit``. + input_batch_size (`int`, *optional*, defaults to 65_536): + Maximum number of rows fetched per Arrow batch from + ``TsFileReader.query_table``. Controls peak memory while + streaming a single device. 
+ output_batch_size (`int`, *optional*, defaults to 32): + Number of devices (output dataset rows) packed into each Arrow + record batch yielded to the writer. Also the granularity at + which the dataset progress bar advances; smaller values give + more responsive feedback on slow per-device reads, larger ones + reduce per-batch overhead. + features (`Features`, *optional*): + Final Features schema. When provided, the metadata scan over + input files is skipped. + on_bad_files (`Literal["error", "warn", "skip"]`, *optional*, defaults to "error"): + What to do if a file cannot be opened or lacks the requested table. + timestamp_unit (`Literal["s", "ms", "us", "ns"]`, *optional*, defaults to "ms"): + Time unit for the timestamp column. IoTDB defaults to milliseconds. + timestamp_tz (`str`, *optional*): + Time zone for the timestamp column. ``None`` means timezone-naive. + """ + + table_name: Optional[str] = None + columns: Optional[list[str]] = None + start_time: Optional[Any] = None + end_time: Optional[Any] = None + input_batch_size: int = 65_536 + output_batch_size: int = 32 + features: Optional[datasets.Features] = None + on_bad_files: Literal["error", "warn", "skip"] = "error" + timestamp_unit: Literal["s", "ms", "us", "ns"] = "ms" + timestamp_tz: Optional[str] = None + + def __post_init__(self): + super().__post_init__() + if self.input_batch_size is None or self.input_batch_size <= 0: + raise ValueError(f"`input_batch_size` must be a positive integer, got {self.input_batch_size}") + if self.output_batch_size is None or self.output_batch_size <= 0: + raise ValueError(f"`output_batch_size` must be a positive integer, got {self.output_batch_size}") + if self.columns is not None and len(self.columns) == 0: + raise ValueError("`columns` must be a non-empty list when provided.") + if self.timestamp_unit not in ("s", "ms", "us", "ns"): + raise ValueError(f"`timestamp_unit` must be one of 's', 'ms', 'us', 'ns', got {self.timestamp_unit!r}") + if self.on_bad_files not 
in ("error", "warn", "skip"): + raise ValueError(f"`on_bad_files` must be one of 'error', 'warn', 'skip', got {self.on_bad_files!r}") + if self.start_time is not None: + self.start_time = _to_epoch(self.start_time, self.timestamp_unit) + if self.end_time is not None: + self.end_time = _to_epoch(self.end_time, self.timestamp_unit) + + +# --------------------------------------------------------------------------- +# Internal sentinels +# --------------------------------------------------------------------------- + + +class _SkipSplit(Exception): + """Raised internally to abort emitting a split entirely.""" + + +class _MissingTableError(ValueError): + def __init__(self, table: Optional[str], available): + super().__init__(f"Table {table!r} not found in file. Available tables: {available}") + + +_TSFILE_MAGIC = b"TsFile" + + +# --------------------------------------------------------------------------- +# Builder +# --------------------------------------------------------------------------- + + +class TsFile(datasets.ArrowBasedBuilder): + """Per-device wide-format builder for TsFile (table model).""" + + BUILDER_CONFIG_CLASS = TsFileConfig + + # ----- builder hooks ------------------------------------------------ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._table: Optional[str] = None + self._time_col: str = "time" + self._tag_cols: list[str] = [] + self._field_inner: dict[str, pa.DataType] = {} + self._requested_fields: Optional[list[str]] = None # lowercased + + def _info(self): + if ( + self.config.columns is not None + and self.config.features is not None + and not set(self.config.columns).issubset(set(self.config.features)) + ): + raise ValueError( + "Every entry in `columns` must also appear in `features`, but got " + f"columns={self.config.columns} and features={list(self.config.features)}" + ) + return datasets.DatasetInfo(features=self.config.features) + + def _split_generators(self, dl_manager): + if not 
self.config.data_files: + raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") + dl_manager.download_config.extract_on_the_fly = True + data_files = dl_manager.download(self.config.data_files) + + # Lowercase user-facing names to match tsfile's case-insensitive convention. + self._table = self.config.table_name.lower() if self.config.table_name else None + self._requested_fields = [c.lower() for c in self.config.columns] if self.config.columns else None + + all_files = [f for files in data_files.values() for f in files] + scan = self._scan_metadata(all_files) + if scan is None: + raise ValueError( + "Could not infer schema from any of the provided files. " + "Set `features` explicitly or check the input files." + ) + self._table = scan["table"] + self._time_col = scan["time_col"] + self._tag_cols = scan["tag_cols"] + self._field_inner = scan["field_inner"] + + if self.info.features is None: + self.info.features = self._build_features() + + return [ + datasets.SplitGenerator(name=split, gen_kwargs={"files": list(files)}) + for split, files in data_files.items() + ] + + def _generate_shards(self, files): + yield from files + + def _generate_tables(self, files): + target_schema = self.info.features.arrow_schema + try: + yield from self._fold_split(files, target_schema) + except _SkipSplit: + return + + # ----- metadata scan ------------------------------------------------ + + def _scan_metadata(self, files) -> Optional[dict]: + """Walk every file and unify table name, TAG columns, FIELD types.""" + from tsfile.constants import TIME_COLUMN, ColumnCategory + + wanted_table = self._table + wanted_fields = set(self._requested_fields) if self._requested_fields is not None else None + + table: Optional[str] = wanted_table + time_col: Optional[str] = None + tag_cols: list[str] = [] + tag_seen: set[str] = set() + # Per-field widest TSDataType seen so far (we map to Arrow at the end). 
+ field_widest: dict = {} + + for file in files: + try: + with self._open_reader(file) as reader: + schemas = self._schemas_by_lc(reader) + self._require_table_model(file, schemas) + if table is None: + table = next(iter(schemas)) + if table not in schemas: + raise _MissingTableError(table, list(schemas)) + for col in schemas[table].get_columns(): + name = col.get_column_name() + cat = col.get_category() + ts_dtype = col.get_data_type() + if cat == ColumnCategory.TIME: + time_col = name + elif cat == ColumnCategory.TAG: + if name not in tag_seen: + tag_seen.add(name) + tag_cols.append(name) + else: # FIELD + if wanted_fields is not None and name not in wanted_fields: + continue + prev = field_widest.get(name) + field_widest[name] = ts_dtype if prev is None else _promote_tsdatatype(prev, ts_dtype) + except Exception as e: + if self._should_reraise(file, e): + raise + continue + + if table is None: + return None + + unit = self.config.timestamp_unit + tz = self.config.timestamp_tz + + if self._requested_fields is not None: + # Honor user order; silently drop names that turned out to be TAGs + # or the TIME column (TAGs are emitted as their own scalar columns + # and TIME is always emitted as a list column β€” neither may also + # appear as a list-typed field, which would collide on schema name). + reserved = tag_seen | {time_col} if time_col is not None else tag_seen + field_inner: dict[str, pa.DataType] = {} + for name in self._requested_fields: + if name in reserved: + continue + ts_dtype = field_widest.get(name) + if ts_dtype is not None: + field_inner[name] = _arrow_type(ts_dtype, unit=unit, tz=tz) + else: + # Field never appeared in any file β€” keep as a nullable + # float64 list, fully filled with nulls at read time. 
+ field_inner[name] = pa.float64() + else: + field_inner = {n: _arrow_type(d, unit=unit, tz=tz) for n, d in field_widest.items()} + + return { + "table": table, + "time_col": time_col or TIME_COLUMN, + "tag_cols": tag_cols, + "field_inner": field_inner, + } + + def _build_features(self) -> datasets.Features: + unit = self.config.timestamp_unit + tz = self.config.timestamp_tz + fields: list[pa.Field] = [pa.field(t, pa.string()) for t in self._tag_cols] + fields.append(pa.field(self._time_col, pa.list_(pa.timestamp(unit, tz=tz)))) + for name, inner in self._field_inner.items(): + fields.append(pa.field(name, pa.list_(inner))) + return datasets.Features.from_arrow_schema(pa.schema(fields)) + + # ----- per-split folding ------------------------------------------- + + def _fold_split(self, files, target_schema: pa.Schema): + """Stream every device in this split via per-device tag-filter pushdown. + + Open one ``TsFileReader`` per file, build a cross-file device index + keyed by ``(tag-tuple)``, then iterate devices in stable order. For + each device, ``query_table(tag_filter=...)`` reads only that device's + rows from each contributing file, so peak memory is bounded by one + device's payload across the split β€” never the split's total size. + """ + if self._table is None: + raise _SkipSplit + + readers: dict[str, Any] = {} + try: + for file in files: + try: + readers[file] = self._open_reader(file) + except Exception as e: + if self._should_reraise(file, e): + raise + continue + + device_index, file_meta = self._build_device_index(readers) + if not device_index: + return + + yield from self._iter_device_batches(device_index, file_meta, readers, target_schema) + finally: + for reader in readers.values(): + try: + reader.close() + except Exception: + pass + + def _build_device_index(self, readers: dict): + """Walk every open reader and build the cross-file device index. 
+ + Returns ``(device_index, file_meta)``: + + - ``device_index``: list of ``(device_key, [file_path, ...])`` pairs + in stable first-seen order. ``device_key`` is a tuple aligned to + ``self._tag_cols`` (the unified tag-column order). + - ``file_meta``: maps each readable file to its per-file context + (``tag_cols``, ``field_cols``, ``time_col``). + """ + from tsfile.constants import ColumnCategory + + device_to_files: dict[tuple, list[str]] = {} + device_order: list[tuple] = [] + file_meta: dict[str, dict] = {} + # ``self._table`` was lowercased either by user-input normalization in + # ``_split_generators`` or by ``_schemas_by_lc`` during auto-detect. + table_lc = self._table + + files_iter = tqdm( + readers.items(), + total=len(readers), + desc="Indexing TsFile devices", + unit="file", + ) + + for file, reader in files_iter: + try: + schemas = self._schemas_by_lc(reader) + self._require_table_model(file, schemas) + if table_lc not in schemas: + raise _MissingTableError(table_lc, list(schemas)) + schema = schemas[table_lc] + + file_tag_cols: list[str] = [] + file_field_cols: set[str] = set() + time_col = self._time_col + for col in schema.get_columns(): + name = col.get_column_name() + cat = col.get_category() + if cat == ColumnCategory.TIME: + time_col = name + elif cat == ColumnCategory.TAG: + file_tag_cols.append(name) + elif cat == ColumnCategory.FIELD: + file_field_cols.add(name) + + file_meta[file] = { + "tag_cols": file_tag_cols, + "field_cols": file_field_cols, + "time_col": time_col, + } + + for device in reader.get_all_devices(): + if device.table_name is None or device.table_name.lower() != table_lc: + continue + file_tag_values = list(device.segments[1 : 1 + len(file_tag_cols)]) + file_tag_dict = dict(zip(file_tag_cols, file_tag_values)) + unified_key = tuple(file_tag_dict.get(c) for c in self._tag_cols) + if any(v is None for v in unified_key): + raise ValueError( + f"Device in file '{file}' has missing tag values: " + f"{dict(zip(self._tag_cols, 
unified_key))}. " + "Schema-evolution devices with NULL tag values are not " + "supported because tsfile lacks an IS NULL tag filter." + ) + if unified_key not in device_to_files: + device_to_files[unified_key] = [] + device_order.append(unified_key) + device_to_files[unified_key].append(file) + except Exception as e: + if self._should_reraise(file, e): + raise + file_meta.pop(file, None) + continue + + device_index = [(key, device_to_files[key]) for key in device_order] + return device_index, file_meta + + def _iter_device_batches(self, device_index, file_meta, readers, target_schema: pa.Schema): + """Materialize devices in order and emit packed Arrow tables.""" + field_names = list(self._field_inner.keys()) + rows: list[dict] = [] + batch_idx = 0 + for device_key, contributing_files in device_index: + time_chunks: list[np.ndarray] = [] + field_chunks: dict[str, list] = {f: [] for f in field_names} + for file in contributing_files: + if file not in file_meta: + continue + try: + ts_arr, vals = self._read_device_from_file(readers[file], file_meta[file], device_key) + except Exception as e: + if self._should_reraise(file, e): + raise + continue + if len(ts_arr) == 0: + continue + time_chunks.append(ts_arr) + for f in field_names: + field_chunks[f].append(vals.get(f)) # None β†’ all-null contribution + if not time_chunks: + continue # device produced no rows in time range β†’ skip + + row = self._finalize_device(device_key, time_chunks, field_chunks, field_names) + rows.append(row) + if len(rows) >= self.config.output_batch_size: + yield Key(0, batch_idx), self._rows_to_table(rows, target_schema) + rows = [] + batch_idx += 1 + if rows: + yield Key(0, batch_idx), self._rows_to_table(rows, target_schema) + + def _read_device_from_file(self, reader, meta: dict, device_key: tuple) -> tuple: + """Stream one device's rows from one file via ``query_table`` pushdown. + + Returns ``(timestamps, {field_name: values})``. 
The dict only includes + field columns that this file owns *and* that the builder requested; + callers fill missing fields with all-null contributions. + """ + from tsfile import tag_eq + + file_tag_cols: list[str] = meta["tag_cols"] + file_field_cols: set[str] = meta["field_cols"] + time_col: str = meta["time_col"] + + # Build the tag filter only over this file's tag columns. The unified + # device key carries one value per builder tag; map back by name. + unified_to_value = dict(zip(self._tag_cols, device_key)) + tag_filter = None + for c in file_tag_cols: + v = unified_to_value.get(c) + if v is None: + # Caught earlier in _build_device_index; defensive guard. + return np.array([], dtype=np.int64), {} + expr = tag_eq(c, str(v)) + tag_filter = expr if tag_filter is None else tag_filter & expr + + # Project: requested fields ∩ this file's fields. ``query_table`` + # always returns the time column, but it requires at least one + # non-time column β€” fall back to any owned field if the user's + # selection has nothing in this file. 
+ requested = list(self._field_inner.keys()) + fields_to_query = [f for f in requested if f in file_field_cols] + fallback_only = False + if not fields_to_query: + if not file_field_cols: + return np.array([], dtype=np.int64), {} + fields_to_query = [next(iter(file_field_cols))] + fallback_only = True + + kwargs: dict = {"tag_filter": tag_filter, "batch_size": self.config.input_batch_size} + if self.config.start_time is not None: + kwargs["start_time"] = self.config.start_time + if self.config.end_time is not None: + kwargs["end_time"] = self.config.end_time + + ts_parts: list[np.ndarray] = [] + field_parts: dict[str, list] = {f: [] for f in fields_to_query} + with reader.query_table(self._table, fields_to_query, **kwargs) as rs: + while True: + batch = rs.read_arrow_batch() + if batch is None: + break + if batch.num_rows == 0: + continue + ts_parts.append(np.asarray(batch.column(time_col).to_numpy(), dtype=np.int64)) + for f in fields_to_query: + col = batch.column(f) + # tsfile's arrow reader tags TIMESTAMP / DATE field columns + # with a fixed unit (e.g. ``timestamp[ns]``) regardless of + # the value's original write unit. Reinterpret as raw + # int64/int32 ticks so the downstream + # ``pa.array(type=timestamp[])`` treats them as + # ticks in the unit declared by our schema, instead of + # cross-unit casting (which would raise on data loss). 
+ if pa.types.is_timestamp(col.type): + field_parts[f].append(col.cast(pa.int64()).to_numpy(zero_copy_only=False)) + elif pa.types.is_date(col.type): + field_parts[f].append(col.cast(pa.int32()).to_numpy(zero_copy_only=False)) + else: + field_parts[f].append(col.to_numpy(zero_copy_only=False)) + + if not ts_parts: + return np.array([], dtype=np.int64), {} + ts_full = np.concatenate(ts_parts) if len(ts_parts) > 1 else ts_parts[0] + vals_full = {f: (np.concatenate(parts) if len(parts) > 1 else parts[0]) for f, parts in field_parts.items()} + + # Defensive boundary mask: native query paths may emit rows just + # outside the requested window in some chunk-boundary cases. + if self.config.start_time is not None or self.config.end_time is not None: + lo = self.config.start_time if self.config.start_time is not None else np.iinfo(np.int64).min + hi = self.config.end_time if self.config.end_time is not None else np.iinfo(np.int64).max + mask = (ts_full >= lo) & (ts_full <= hi) + if not mask.all(): + ts_full = ts_full[mask] + vals_full = {f: arr[mask] for f, arr in vals_full.items()} + + # Drop the fallback "pick one" column from the user-visible payload. + if fallback_only: + vals_full = {} + return ts_full, vals_full + + def _finalize_device( + self, + device_key: tuple, + time_chunks: list, + field_chunks: dict, + field_names: list[str], + ) -> dict: + """Concatenate per-file chunks, sort by time, and return one row. + + Raises ``ValueError`` if the same timestamp appears more than once + for a device (within or across files) β€” tsfile's per-device timeline + is required to be unique-by-timestamp. 
+ """ + time_arr = np.concatenate(time_chunks) if time_chunks else np.array([], dtype=np.int64) + n_total = len(time_arr) + + if n_total > 0: + sort_idx = np.argsort(time_arr, kind="stable") + time_sorted = time_arr[sort_idx] + if n_total > 1: + dup_mask = time_sorted[1:] == time_sorted[:-1] + if dup_mask.any(): + dup_ts = int(time_sorted[1:][dup_mask][0]) + raise ValueError( + f"Duplicate timestamp {dup_ts} for device " + f"{dict(zip(self._tag_cols, device_key))}. " + "Cross-file or within-file duplicate timestamps are not supported." + ) + else: + sort_idx = None + time_sorted = time_arr + + row: dict = {} + for tag_name, tag_val in zip(self._tag_cols, device_key): + row[tag_name] = None if tag_val is None else str(tag_val) + row[self._time_col] = time_sorted + + for fname in field_names: + chunks = field_chunks.get(fname, []) + materialized: list = [] + for tchunk, fchunk in zip(time_chunks, chunks): + if fchunk is None: + materialized.append(np.full(len(tchunk), None, dtype=object)) + else: + materialized.append(fchunk) + arr = np.concatenate(materialized) if materialized else np.array([], dtype=object) + if sort_idx is not None and len(arr) == n_total: + arr = arr[sort_idx] + row[fname] = arr + return row + + # ----- arrow assembly ---------------------------------------------- + + def _rows_to_table(self, rows: list[dict], target_schema: pa.Schema) -> pa.Table: + """Convert a batch of row dicts into an Arrow table matching ``target_schema``.""" + arrays: list[pa.Array] = [] + for f in target_schema: + values = [r[f.name] for r in rows] + if pa.types.is_list(f.type): + arrays.append(self._build_list_array(values, f.type)) + else: + arrays.append(pa.array(values, type=f.type)) + pa_table = pa.Table.from_arrays(arrays, names=[f.name for f in target_schema]) + return table_cast(pa_table, target_schema) + + @staticmethod + def _build_list_array(values: list, list_type: pa.ListType) -> pa.ListArray: + """Build a ``ListArray`` from a list of per-row 1D arrays / 
sequences.""" + inner_type = list_type.value_type + offsets = [0] + flat_chunks: list = [] + total = 0 + for v in values: + if v is None: + length = 0 + else: + length = len(v) + flat_chunks.append(v) + total += length + offsets.append(total) + + if flat_chunks: + try: + flat = np.concatenate([np.asarray(c) for c in flat_chunks]) + except ValueError: + # Heterogeneous shapes / dtypes β†’ fall back to a Python list. + flat = [] + for c in flat_chunks: + flat.extend(list(c)) + flat_arr = pa.array(flat, type=inner_type, from_pandas=True) + else: + flat_arr = pa.array([], type=inner_type) + + return pa.ListArray.from_arrays(pa.array(offsets, type=pa.int32()), flat_arr) + + # ----- file / error handling --------------------------------------- + + @staticmethod + def _open_reader(file: str): + """Open a file as a ``TsFileReader`` after verifying its magic header. + + The C library's ``TsFileReader`` constructor silently returns an + invalid handle for non-tsfile inputs, and any subsequent call on it + segfaults. The 6-byte ``TsFile`` magic header is checked first to + bail out cleanly. + """ + from tsfile import TsFileReader + + try: + with open(file, "rb") as fh: + header = fh.read(len(_TSFILE_MAGIC)) + except OSError as e: + raise ValueError(f"Cannot open file {file!r}: {e}") from e + if header != _TSFILE_MAGIC: + raise ValueError(f"File {file!r} is not a valid TsFile (bad magic header).") + return TsFileReader(file) + + @staticmethod + def _require_table_model(file: str, schemas) -> None: + if not schemas: + raise ValueError( + f"File {file!r} is a tree-model TsFile, which is not supported. " + "Only table-model TsFiles can be loaded." + ) + + @staticmethod + def _schemas_by_lc(reader) -> dict: + """Return ``get_all_table_schemas()`` keyed by lowercased table name. + + TsFile / IoTDB treat table names case-insensitively, but the Python + binding's ``get_all_table_schemas()`` returns a dict keyed by whatever + casing the file was written with. 
def _should_reraise(self, file: str, exc: BaseException) -> bool:
    """Log a per-file failure according to ``self.config.on_bad_files``.

    Returns True only for the ``"error"`` policy, telling the caller to
    re-raise. ``"warn"`` logs at warning level and any other value logs at
    debug level; both return False so the caller skips the file.
    """
    policy = self.config.on_bad_files
    exc_kind = type(exc).__name__

    if policy == "error":
        logger.error(f"Failed to read file '{file}' with error {exc_kind}: {exc}")
        return True

    skip_note = f"Skipping bad file '{file}'. {exc_kind}: {exc}"
    emit = logger.warning if policy == "warn" else logger.debug
    emit(skip_note)
    return False
def _write_tsfile(path: str, tables: Sequence[tuple[str, Sequence[ColumnSpec], Sequence[Sequence[Row]]]]) -> None:
    """Write every table in ``tables`` to a new TsFile at ``path``.

    Each entry is ``(table_name, columns, tablets)``:

    - ``columns`` lists the schema as ``(name, TSDataType, ColumnCategory)``
      triples; the TIME column is expected to be named ``"time"`` because
      each row's timestamp is read from ``row["time"]``.
    - ``tablets`` is a list of tablets, each a list of row dicts that carry
      ``"time"`` plus a value for every non-TIME column.

    The writer is closed even if registration or writing fails.
    """
    out = TsFileWriter(path)
    try:
        # All schemas are registered before any data is written.
        for name, specs, _unused in tables:
            schema_columns = [ColumnSchema(*spec) for spec in specs]
            out.register_table(TableSchema(name, schema_columns))

        for name, specs, batches in tables:
            # The TIME column is supplied through add_timestamp, so the
            # tablet itself only declares the remaining columns.
            value_names = []
            value_types = []
            for spec_name, spec_type, category in specs:
                if category != ColumnCategory.TIME:
                    value_names.append(spec_name)
                    value_types.append(spec_type)

            for batch in batches:
                tablet = Tablet(value_names, value_types, len(batch))
                tablet.set_table_name(name)
                for row_idx, record in enumerate(batch):
                    tablet.add_timestamp(row_idx, record["time"])
                    for column in value_names:
                        tablet.add_value_by_name(column, row_idx, record[column])
                out.write_table(tablet)
    finally:
        out.close()
10.0 + i, "humidity": 50.0 + i} for i in range(3)] + for dev in ("d1", "d2", "d3") + ] + _write_tsfile(path, [("plant", cols, tablets)]) + + +def _write_evolved(path: str) -> None: + """Same table+device as single_device, plus a new ``voltage`` field.""" + cols = [ + ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME), + ("device", TSDataType.STRING, ColumnCategory.TAG), + ("temperature", TSDataType.DOUBLE, ColumnCategory.FIELD), + ("humidity", TSDataType.DOUBLE, ColumnCategory.FIELD), + ("voltage", TSDataType.DOUBLE, ColumnCategory.FIELD), + ] + rows = [ + { + "time": T_EVOLVED + i * 1000, + "device": "d1", + "temperature": 30.0 + i, + "humidity": 60.0 + i, + "voltage": 220.0 + i, + } + for i in range(3) + ] + _write_tsfile(path, [("mytable", cols, [rows])]) + + +def _write_numeric_field(path: str, ts_type: TSDataType, base_ts: int, value_fn) -> None: + """One device 'd1' with a single numeric ``temperature`` field.""" + cols = [ + ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME), + ("device", TSDataType.STRING, ColumnCategory.TAG), + ("temperature", ts_type, ColumnCategory.FIELD), + ] + rows = [{"time": base_ts + i * 1000, "device": "d1", "temperature": value_fn(i)} for i in range(3)] + _write_tsfile(path, [("mytable", cols, [rows])]) + + +def _write_two_tables(path: str) -> None: + """Two distinct tables in one file: ``table_a`` (registered first) and ``table_b``.""" + a_cols = [ + ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME), + ("device", TSDataType.STRING, ColumnCategory.TAG), + ("a", TSDataType.DOUBLE, ColumnCategory.FIELD), + ] + b_cols = [ + ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME), + ("device", TSDataType.STRING, ColumnCategory.TAG), + ("b", TSDataType.DOUBLE, ColumnCategory.FIELD), + ] + a_rows = [{"time": 1_000 + i, "device": "d1", "a": float(i)} for i in range(2)] + b_rows = [{"time": 2_000 + i, "device": "d1", "b": 100.0 + i} for i in range(2)] + _write_tsfile(path, [("table_a", a_cols, [a_rows]), ("table_b", b_cols, [b_rows])]) + 
+ +def _write_all_types(path: str) -> None: + """Every supported TSDataType represented as a FIELD.""" + cols = [ + ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME), + ("tag", TSDataType.STRING, ColumnCategory.TAG), + ("col_boolean", TSDataType.BOOLEAN, ColumnCategory.FIELD), + ("col_int32", TSDataType.INT32, ColumnCategory.FIELD), + ("col_int64", TSDataType.INT64, ColumnCategory.FIELD), + ("col_float", TSDataType.FLOAT, ColumnCategory.FIELD), + ("col_double", TSDataType.DOUBLE, ColumnCategory.FIELD), + ("col_text", TSDataType.TEXT, ColumnCategory.FIELD), + ("col_string", TSDataType.STRING, ColumnCategory.FIELD), + ("col_timestamp", TSDataType.TIMESTAMP, ColumnCategory.FIELD), + ("col_date", TSDataType.DATE, ColumnCategory.FIELD), + ("col_blob", TSDataType.BLOB, ColumnCategory.FIELD), + ] + rows = [ + { + "time": T0 + i * 1000, + "tag": "d1", + "col_boolean": i % 2 == 0, + "col_int32": 100 + i, + "col_int64": 1_000_000 + i, + "col_float": 1.5 + i, + "col_double": 100.5 + i, + "col_text": f"text_{i}", + "col_string": f"str_{i}", + "col_timestamp": 1_600_000_000_000 + i * 1000, + "col_date": date(2024, 1, 1 + i), + "col_blob": f"blob{i}".encode(), + } + for i in range(3) + ] + _write_tsfile(path, [("alltypes", cols, [rows])]) + + +def _write_large_device(path: str, n_points: int = 200) -> None: + """Single device with many points, used to exercise multi-batch concat.""" + cols = [ + ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME), + ("device", TSDataType.STRING, ColumnCategory.TAG), + ("v", TSDataType.INT64, ColumnCategory.FIELD), + ] + rows = [{"time": T0 + i, "device": "d1", "v": i} for i in range(n_points)] + _write_tsfile(path, [("mytable", cols, [rows])]) + + +def _write_two_devices_subset(path: str, devices: Sequence[str], base_ts: int) -> None: + """A multi-device fixture used to assemble cross-file device sets.""" + cols = [ + ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME), + ("device", TSDataType.STRING, ColumnCategory.TAG), + ("v", 
TSDataType.DOUBLE, ColumnCategory.FIELD), + ] + tablets = [[{"time": base_ts + i * 1000, "device": dev, "v": float(i)} for i in range(3)] for dev in devices] + _write_tsfile(path, [("mytable", cols, tablets)]) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def make_tsfile(tmp_path): + """Factory fixture: ``make_tsfile("name", writer_fn, *args, **kwargs)``.""" + + def _make(name: str, writer_fn, *args, **kwargs) -> str: + p = str(tmp_path / f"{name}.tsfile") + writer_fn(p, *args, **kwargs) + return p + + return _make + + +@pytest.fixture +def tsfile_path(make_tsfile): + return make_tsfile("sample", _write_single_device) + + +@pytest.fixture +def multi_device_tsfile_path(make_tsfile): + return make_tsfile("multi", _write_multi_device) + + +@pytest.fixture +def evolved_tsfile_path(make_tsfile): + return make_tsfile("evolved", _write_evolved) + + +@pytest.fixture +def two_tables_tsfile_path(make_tsfile): + return make_tsfile("two_tables", _write_two_tables) + + +@pytest.fixture +def all_types_tsfile_path(make_tsfile): + return make_tsfile("alltypes", _write_all_types) + + +# --------------------------------------------------------------------------- +# Config-level +# --------------------------------------------------------------------------- + + +def test_config_raises_when_invalid_name(): + with pytest.raises(InvalidConfigName, match="Bad characters"): + TsFileConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files): + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + TsFileConfig(name="name", data_files=data_files) + + +@pytest.mark.parametrize( + "kwargs, match", + [ + ({"input_batch_size": 0}, "input_batch_size"), + ({"output_batch_size": 0}, 
"output_batch_size"), + ({"columns": []}, "non-empty"), + ({"timestamp_unit": "minute"}, "timestamp_unit"), + ({"on_bad_files": "boom"}, "on_bad_files"), + ], +) +def test_config_rejects_invalid_values(kwargs, match): + with pytest.raises(ValueError, match=match): + TsFileConfig(name="x", **kwargs) + + +def test_config_normalizes_time_bounds(): + cfg = TsFileConfig( + name="x", + start_time=pa.scalar(1500, type=pa.timestamp("ms")), + end_time=2000, + ) + assert cfg.start_time == 1500 + assert cfg.end_time == 2000 + + +# --------------------------------------------------------------------------- +# _to_epoch unit tests +# --------------------------------------------------------------------------- + + +def test_to_epoch_int_passthrough(): + assert _to_epoch(1234, "ms") == 1234 + + +def test_to_epoch_naive_datetime(): + assert _to_epoch(datetime(1970, 1, 1, 0, 0, 1), "ms") == 1000 + + +@pytest.mark.parametrize( + "aware", + [ + datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=8))), + "2024-01-01T00:00:00+08:00", + ], + ids=["datetime", "iso_string"], +) +def test_to_epoch_aware_inputs_normalized_to_utc(aware): + # 2024-01-01T00:00:00 in UTC+8 == 2023-12-31T16:00:00 UTC. 
+ naive_utc = datetime(2023, 12, 31, 16, 0, 0) + assert _to_epoch(aware, "ms") == _to_epoch(naive_utc, "ms") + + +def test_to_epoch_date(): + assert _to_epoch(date(1970, 1, 2), "ms") == 86_400_000 + + +def test_to_epoch_iso_string(): + assert _to_epoch("1970-01-01T00:00:01", "ms") == 1000 + + +def test_to_epoch_pa_scalar(): + assert _to_epoch(pa.scalar(1500, type=pa.timestamp("ms")), "ms") == 1500 + + +def test_to_epoch_rejects_bool(): + with pytest.raises(TypeError, match="bool"): + _to_epoch(True, "ms") + + +@pytest.mark.parametrize("value", [object(), b"bytes", "not-a-date"]) +def test_to_epoch_rejects_garbage(value): + with pytest.raises(TypeError, match="must be a"): + _to_epoch(value, "ms") + + +# --------------------------------------------------------------------------- +# End-to-end: single device, full table +# --------------------------------------------------------------------------- + + +def test_load_full_table(tsfile_path): + ds = load_dataset("tsfile", data_files=tsfile_path)["train"] + + # One row per device. TAG = scalar string; time + fields = lists. 
+ assert ds.column_names == ["device", "time", "temperature", "humidity"] + assert len(ds) == 1 + row = ds[0] + assert row["device"] == "d1" + assert len(row["time"]) == 5 + assert row["time"][0] == datetime(2023, 11, 14, 22, 13, 20) + assert row["time"][-1] == datetime(2023, 11, 14, 22, 13, 24) + assert row["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0] + assert row["humidity"] == [50.0, 51.0, 52.0, 53.0, 54.0] + + +def test_load_with_field_subset(tsfile_path): + ds = load_dataset("tsfile", data_files=tsfile_path, columns=["temperature"])["train"] + assert ds.column_names == ["device", "time", "temperature"] + assert ds[0]["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0] + + +def test_columns_are_lowercased(tsfile_path): + ds = load_dataset("tsfile", data_files=tsfile_path, columns=["TEMPERATURE", "Humidity"])["train"] + assert ds.column_names == ["device", "time", "temperature", "humidity"] + + +def test_columns_request_tag_is_silently_ignored(tsfile_path): + """Passing a TAG name in `columns` is a no-op (TAGs are always emitted).""" + ds = load_dataset("tsfile", data_files=tsfile_path, columns=["device", "temperature"])["train"] + + assert ds.column_names == ["device", "time", "temperature"] + assert ds.features["device"].dtype == "string" + assert ds.features["temperature"].feature.dtype == "float64" + assert ds["device"] == ["d1"] + assert ds[0]["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0] + + +def test_columns_request_time_is_silently_ignored(tsfile_path): + """Passing the TIME column name in `columns` is a no-op (TIME is always emitted).""" + ds = load_dataset("tsfile", data_files=tsfile_path, columns=["time", "temperature"])["train"] + + # `time` should appear exactly once, and as the real timestamp list β€” not + # as a duplicate all-null float64 list column. 
+ assert ds.column_names == ["device", "time", "temperature"] + assert ds.features["time"].feature.dtype.startswith("timestamp") + row = ds[0] + assert len(row["time"]) == 5 + assert row["time"][0] == datetime(2023, 11, 14, 22, 13, 20) + assert row["time"][-1] == datetime(2023, 11, 14, 22, 13, 24) + assert row["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0] + + +def test_columns_request_only_time(tsfile_path): + """`columns=["time"]` should still produce TAG + TIME, with no FIELD list columns.""" + ds = load_dataset("tsfile", data_files=tsfile_path, columns=["time"])["train"] + + assert ds.column_names == ["device", "time"] + assert ds.features["time"].feature.dtype.startswith("timestamp") + row = ds[0] + assert row["device"] == "d1" + assert len(row["time"]) == 5 + assert row["time"][0] == datetime(2023, 11, 14, 22, 13, 20) + assert row["time"][-1] == datetime(2023, 11, 14, 22, 13, 24) + + +def test_columns_unknown_field_filled_with_null(tsfile_path): + ds = load_dataset( + "tsfile", + data_files=tsfile_path, + columns=["temperature", "voltage"], # voltage is absent + )["train"] + + assert ds.column_names == ["device", "time", "temperature", "voltage"] + row = ds[0] + assert row["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0] + assert row["voltage"] == [None] * 5 + + +def test_columns_all_unknown_still_returns_time_and_tags(tsfile_path): + ds = load_dataset( + "tsfile", + data_files=tsfile_path, + columns=["nonexistent_a", "nonexistent_b"], + )["train"] + + assert ds.column_names == ["device", "time", "nonexistent_a", "nonexistent_b"] + row = ds[0] + assert row["device"] == "d1" + assert len(row["time"]) == 5 + assert row["nonexistent_a"] == [None] * 5 + assert row["nonexistent_b"] == [None] * 5 + + +# --------------------------------------------------------------------------- +# Time-range filtering +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "start, end", + [ + # pa.scalar from datetime + ( + 
pa.scalar(datetime(2023, 11, 14, 22, 13, 21), type=pa.timestamp("ms")), + pa.scalar(datetime(2023, 11, 14, 22, 13, 23), type=pa.timestamp("ms")), + ), + # pa.scalar from int epoch + ( + pa.scalar(T0 + 1000, type=pa.timestamp("ms")), + pa.scalar(T0 + 3000, type=pa.timestamp("ms")), + ), + # plain int epoch + (T0 + 1000, T0 + 3000), + # datetime + (datetime(2023, 11, 14, 22, 13, 21), datetime(2023, 11, 14, 22, 13, 23)), + # ISO-8601 string + ("2023-11-14T22:13:21", "2023-11-14T22:13:23"), + ], +) +def test_load_with_time_range_inputs(tsfile_path, start, end): + ds = load_dataset("tsfile", data_files=tsfile_path, start_time=start, end_time=end)["train"] + assert len(ds[0]["time"]) == 3 + + +# --------------------------------------------------------------------------- +# Multi-device & cross-file folding +# --------------------------------------------------------------------------- + + +def test_load_multi_device_one_row_per_device(multi_device_tsfile_path): + ds = load_dataset("tsfile", data_files=multi_device_tsfile_path)["train"] + + assert len(ds) == 3 + assert sorted(ds["device"]) == ["d1", "d2", "d3"] + for row in ds: + assert len(row["time"]) == 3 + assert row["temperature"] == [10.0, 11.0, 12.0] + assert row["humidity"] == [50.0, 51.0, 52.0] + + +def test_schema_evolution_merges_same_device(tsfile_path, evolved_tsfile_path): + """Same device d1 in two files β†’ one row, lists merged in time order.""" + ds = load_dataset("tsfile", data_files=[tsfile_path, evolved_tsfile_path])["train"] + + assert "voltage" in ds.column_names + assert len(ds) == 1 + row = ds[0] + assert row["device"] == "d1" + # 5 (old) + 3 (new) points, fully time-ordered. + assert len(row["time"]) == 8 + # Old file lacked `voltage` β†’ null on its 5 points; new file fills the rest. + assert row["voltage"] == [None] * 5 + [220.0, 221.0, 222.0] + # `temperature` is present in both files, contiguous in the merged order. 
+ assert row["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0, 30.0, 31.0, 32.0] + + +def test_multi_file_multi_device_partial_overlap(make_tsfile): + """Two files Γ— two devices each, with one device shared. + + File A: devices {d1, d2}; file B: devices {d2, d3}. The merged dataset + must have 3 rows (one per unique device), and d2 must have *6* points + (3 from each file) sorted by time. + """ + fa = make_tsfile("a", _write_two_devices_subset, devices=["d1", "d2"], base_ts=T0) + fb = make_tsfile("b", _write_two_devices_subset, devices=["d2", "d3"], base_ts=T0 + 100_000) + + ds = load_dataset("tsfile", data_files=[fa, fb])["train"] + by_dev = {row["device"]: row for row in ds} + assert set(by_dev) == {"d1", "d2", "d3"} + assert len(by_dev["d1"]["time"]) == 3 + assert len(by_dev["d3"]["time"]) == 3 + # Shared device gets all 6 points, time-sorted (file A first, then B). + assert len(by_dev["d2"]["time"]) == 6 + assert by_dev["d2"]["v"] == [0.0, 1.0, 2.0, 0.0, 1.0, 2.0] + + +# --------------------------------------------------------------------------- +# Type promotion across files +# --------------------------------------------------------------------------- + + +def test_type_promotion_int32_to_int64(make_tsfile): + int32_path = make_tsfile("narrow", _write_numeric_field, TSDataType.INT32, T_INT32, lambda i: 10 + i) + int64_path = make_tsfile("wide", _write_numeric_field, TSDataType.INT64, T_INT64, lambda i: 1_000_000 + i) + + ds = load_dataset("tsfile", data_files=[int32_path, int64_path])["train"] + assert len(ds) == 1 + assert ds.features["temperature"].feature.dtype == "int64" + # int32 timestamps come earlier (T_INT32 < T_INT64). 
+    assert ds[0]["temperature"] == [10, 11, 12, 1_000_000, 1_000_001, 1_000_002]
+
+
+def test_type_promotion_float_to_double(make_tsfile):
+    """A FLOAT file loaded together with a DOUBLE file yields a float64 `temperature` list."""
+    float_path = make_tsfile("narrow", _write_numeric_field, TSDataType.FLOAT, T_FLOAT, lambda i: 1.5 + i)
+    double_path = make_tsfile("wide", _write_single_device)
+
+    ds = load_dataset("tsfile", data_files=[float_path, double_path])["train"]
+    assert len(ds) == 1
+    assert ds.features["temperature"].feature.dtype == "float64"
+    # double fixture lives at T0..T0+4s; float at T_FLOAT (later).
+    assert ds[0]["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0, 1.5, 2.5, 3.5]
+
+
+def test_type_promotion_int32_to_double(make_tsfile):
+    """An INT32 file loaded together with a DOUBLE file yields a float64 `temperature` list."""
+    int32_path = make_tsfile("int", _write_numeric_field, TSDataType.INT32, T_INT32, lambda i: 10 + i)
+    double_path = make_tsfile("double", _write_single_device)
+
+    ds = load_dataset("tsfile", data_files=[int32_path, double_path])["train"]
+    assert len(ds) == 1
+    # INT32 + DOUBLE → DOUBLE (two-step widening).
+    assert ds.features["temperature"].feature.dtype == "float64"
+    assert ds[0]["temperature"] == [20.0, 21.0, 22.0, 23.0, 24.0, 10.0, 11.0, 12.0]
+
+
+# ---------------------------------------------------------------------------
+# All-types
+# ---------------------------------------------------------------------------
+
+
+def test_load_all_supported_types(all_types_tsfile_path):
+    """Each column of the all-types fixture loads with the expected name, order and Python values."""
+    ds = load_dataset("tsfile", data_files=all_types_tsfile_path)["train"]
+
+    assert len(ds) == 1
+    assert ds.column_names == [
+        "tag",
+        "time",
+        "col_boolean",
+        "col_int32",
+        "col_int64",
+        "col_float",
+        "col_double",
+        "col_text",
+        "col_string",
+        "col_timestamp",
+        "col_date",
+        "col_blob",
+    ]
+    row = ds[0]
+    assert row["tag"] == "d1"
+    assert row["col_boolean"] == [True, False, True]
+    assert row["col_int32"] == [100, 101, 102]
+    assert row["col_int64"] == [1_000_000, 1_000_001, 1_000_002]
+    assert row["col_float"] == [1.5, 2.5, 3.5]
+    assert row["col_double"] == [100.5, 101.5, 102.5]
+    assert row["col_text"] == ["text_0", "text_1", "text_2"]
+    assert row["col_string"] == ["str_0", "str_1", "str_2"]
+    assert row["col_timestamp"][0] == datetime(2020, 9, 13, 12, 26, 40)
+    assert row["col_date"][0] == date(2024, 1, 1)
+    assert row["col_date"][2] == date(2024, 1, 3)
+    assert row["col_blob"][0] == b"blob0"
+    assert row["col_blob"][2] == b"blob2"
+
+
+# ---------------------------------------------------------------------------
+# Multi-table file: explicit `table_name` selection
+# ---------------------------------------------------------------------------
+
+
+def test_default_table_is_first(two_tables_tsfile_path):
+    """Without `table_name`, only the first table's column (`a`) is loaded."""
+    ds = load_dataset("tsfile", data_files=two_tables_tsfile_path)["train"]
+    # `table_a` registered first → default pick.
+    assert "a" in ds.column_names
+    assert "b" not in ds.column_names
+
+
+def test_explicit_table_name(two_tables_tsfile_path):
+    """`table_name="table_b"` loads that table's column (`b`) instead of the default one."""
+    ds = load_dataset("tsfile", data_files=two_tables_tsfile_path, table_name="table_b")["train"]
+    assert "b" in ds.column_names
+    assert "a" not in ds.column_names
+
+
+# ---------------------------------------------------------------------------
+# Streaming (IterableDataset)
+# ---------------------------------------------------------------------------
+
+
+def test_streaming_yields_same_rows(multi_device_tsfile_path):
+    """`streaming=True` returns an `IterableDataset` with one row per device."""
+    ds = load_dataset("tsfile", data_files=multi_device_tsfile_path, streaming=True)["train"]
+    assert isinstance(ds, IterableDataset)
+    rows = list(ds)
+    assert len(rows) == 3
+    # Row order is not asserted, only the set of devices.
+    assert sorted(r["device"] for r in rows) == ["d1", "d2", "d3"]
+    for r in rows:
+        assert r["temperature"] == [10.0, 11.0, 12.0]
+
+
+# ---------------------------------------------------------------------------
+# Timezone
+# ---------------------------------------------------------------------------
+
+
+def test_load_with_timezone(make_tsfile):
+    """`timestamp_tz="UTC"` round-trips: list values come back tz-aware."""
+    path = make_tsfile("tz", _write_single_device)
+    ds = load_dataset("tsfile", data_files=path, timestamp_tz="UTC")["train"]
+    ts = ds[0]["time"][0]
+    assert ts.tzinfo is not None
+    # Same wall-clock as the naive case, attached to UTC.
+    assert ts == datetime(2023, 11, 14, 22, 13, 20, tzinfo=timezone.utc)
+
+
+# ---------------------------------------------------------------------------
+# Large-batch / multi-chunk concat
+# ---------------------------------------------------------------------------
+
+
+def test_large_device_with_small_batch_size(make_tsfile):
+    """Force multiple Arrow batches per device → exercise the concat path."""
+    path = make_tsfile("big", _write_large_device, n_points=200)
+    ds = load_dataset("tsfile", data_files=path, input_batch_size=64)["train"]
+    assert len(ds) == 1
+    row = ds[0]
+    assert len(row["time"]) == 200
+    assert row["v"] == list(range(200))
+
+
+# ---------------------------------------------------------------------------
+# Exception-chain helpers
+# ---------------------------------------------------------------------------
+
+
+def _iter_causes(exc: BaseException):
+    """Yield the exceptions chained beneath *exc*, outermost first.
+
+    Follows ``__cause__`` (explicit ``raise ... from ...``) and falls back to
+    ``__context__`` (implicit chaining) unless it was suppressed with
+    ``from None``; these are the same links the interpreter shows in a traceback.
+    *exc* itself is not yielded.
+    """
+    # Track identities rather than the objects: an exception class may define
+    # `__eq__` without `__hash__`, which would make it unusable in a set.
+    seen = {id(exc)}
+    while True:
+        nxt = exc.__cause__
+        if nxt is None and not exc.__suppress_context__:
+            nxt = exc.__context__
+        # Stop at the end of the chain, or on a repeated link so that a
+        # cyclic chain cannot loop forever.
+        if nxt is None or id(nxt) in seen:
+            return
+        seen.add(id(nxt))
+        exc = nxt
+        yield exc
+
+
+def _assert_message_in_chain(exc: BaseException, needle: str) -> None:
+    """Assert that *needle* occurs in the message of *exc* or of any exception chained beneath it."""
+    chain = [exc, *_iter_causes(exc)]
+    assert any(needle in str(e) for e in chain), f"{needle!r} not found in exception chain: {chain!r}"
+
+
+# ---------------------------------------------------------------------------
+# Duplicate-timestamp detection (cross-file)
+# ---------------------------------------------------------------------------
+
+
+def test_duplicate_timestamp_across_files_raises(make_tsfile):
+    """Same device, same ts in two files → `_finalize_device` must raise."""
+    cols = [
+        ("time", TSDataType.TIMESTAMP, ColumnCategory.TIME),
+        ("device", TSDataType.STRING, ColumnCategory.TAG),
+        ("v", TSDataType.DOUBLE, ColumnCategory.FIELD),
+    ]
+    rows = [{"time": 5_000, "device": "d1", "v": 1.0}]
+
+    a = make_tsfile("dupA", lambda p: _write_tsfile(p, [("mytable", cols, [rows])]))
+    b = make_tsfile("dupB", lambda p: _write_tsfile(p, [("mytable", cols, [rows])]))
+
+    with pytest.raises(Exception) as excinfo:
+        load_dataset("tsfile", data_files=[a, b])
+    # The ValueError is wrapped by `_prepare_split_single` into a
+    # DatasetGenerationError; check the cause chain for the original message.
+    _assert_message_in_chain(excinfo.value, "Duplicate timestamp")
+
+
+# ---------------------------------------------------------------------------
+# Error handling
+# ---------------------------------------------------------------------------
+
+
+def _write_broken_tsfile(tmp_path):
+    """Write `broken.tsfile` with non-TsFile bytes under *tmp_path* and return its path as a `str`."""
+    bad = tmp_path / "broken.tsfile"
+    bad.write_bytes(b"not a real tsfile")
+    return str(bad)
+
+
+def test_on_bad_files_skip(tmp_path, tsfile_path):
+    """`on_bad_files="skip"` still loads the valid file when another file is unreadable."""
+    bad = _write_broken_tsfile(tmp_path)
+
+    ds = load_dataset(
+        "tsfile",
+        data_files=[tsfile_path, bad],
+        on_bad_files="skip",
+    )["train"]
+    assert len(ds) == 1
+    assert len(ds[0]["time"]) == 5
+
+
+def test_on_bad_files_warn(tmp_path, tsfile_path, caplog):
+    """`on_bad_files="warn"` loads the valid file and logs a "Skipping bad file" warning."""
+    bad = _write_broken_tsfile(tmp_path)
+
+    with caplog.at_level(logging.WARNING, logger="datasets.packaged_modules.tsfile.tsfile"):
+        ds = load_dataset(
+            "tsfile",
+            data_files=[tsfile_path, bad],
+            on_bad_files="warn",
+        )["train"]
+    assert len(ds) == 1
+    assert any("Skipping bad file" in rec.message for rec in caplog.records)
+
+
+def test_on_bad_files_default_raises(tmp_path, tsfile_path):
+    """With no `on_bad_files` argument, an unreadable file makes `load_dataset` raise."""
+    bad = _write_broken_tsfile(tmp_path)
+
+    with pytest.raises(Exception) as excinfo:
+        load_dataset("tsfile", data_files=[tsfile_path, bad])
+    _assert_message_in_chain(excinfo.value, "not a valid TsFile")