Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
to ``RankPartitionedDataLoader`` (mirroring how the vision recipe uses
``get_sft_dataset``).
"""

from __future__ import annotations

from typing import Any
Expand Down Expand Up @@ -42,7 +43,6 @@ def get_shuffle_blocks(self):
return self._dataset.get_shuffle_blocks()



class ActionIterableShuffleDataset(IterableDataset):
"""Streaming view of a map-style ``ActionSFTDataset``.

Expand Down Expand Up @@ -70,6 +70,8 @@ def __iter__(self):
import torch

blocks = self._dataset.get_shuffle_blocks()
if not blocks:
raise ValueError("No shuffle blocks found")
wi = get_worker_info()
wid = wi.id if wi is not None else 0
nw = wi.num_workers if wi is not None else 1
Expand Down
67 changes: 67 additions & 0 deletions tests/action_sft_dataset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

"""Hermetic tests for ActionIterableShuffleDataset."""

from __future__ import annotations

import importlib
import sys
import types

import pytest

# Module-wide pytest markers applied to every test in this file.
# NOTE(review): `level` and `gpus` are project-defined markers; presumably
# level 0 is the lightest tier and gpus(0) means no GPU is required — confirm
# against the project's pytest marker configuration.
pytestmark = [pytest.mark.level(0), pytest.mark.gpus(0)]


@pytest.fixture
def action_sft_dataset_module(monkeypatch: pytest.MonkeyPatch):
    """Import ``action_sft_dataset`` against stubbed dependencies and yield it.

    ``torch``, ``torch.utils``, ``torch.utils.data``, the ``datasets`` package,
    ``droid_lerobot_dataset`` and ``transforms`` are replaced in
    ``sys.modules`` with empty stand-in modules so the module under test can be
    imported without those real dependencies being loaded.

    Args:
        monkeypatch: pytest's monkeypatch fixture; every ``sys.modules`` change
            registered through it is undone automatically on teardown.

    Yields:
        The freshly imported ``action_sft_dataset`` module object.
    """
    from pathlib import Path

    module_name = "cosmos_framework.data.vfm.action.datasets.action_sft_dataset"

    fake_torch = types.ModuleType("torch")
    fake_torch_utils = types.ModuleType("torch.utils")
    fake_torch_utils_data = types.ModuleType("torch.utils.data")
    fake_torch_utils_data.Dataset = type("Dataset", (), {})
    fake_torch_utils_data.IterableDataset = type("IterableDataset", (), {})
    fake_torch_utils_data.get_worker_info = lambda: None
    fake_torch.utils = fake_torch_utils
    # Link the submodule as an attribute too, so attribute-style access
    # (``torch.utils.data``) resolves the same way the import statement does.
    fake_torch_utils.data = fake_torch_utils_data

    # The stand-in package needs a ``__path__`` pointing at the real source
    # directory so the import system can locate ``action_sft_dataset.py``.
    # Derive it from this test file's location instead of a machine-specific
    # absolute path; assumes ``tests/`` and ``cosmos_framework/`` are siblings
    # at the repository root.
    repo_root = Path(__file__).resolve().parents[1]
    datasets_dir = repo_root.joinpath("cosmos_framework", "data", "vfm", "action", "datasets")
    fake_datasets_package = types.ModuleType("cosmos_framework.data.vfm.action.datasets")
    fake_datasets_package.__path__ = [str(datasets_dir)]

    fake_droid_dataset = types.ModuleType("cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset")
    fake_droid_dataset.DROIDLeRobotDataset = type("DROIDLeRobotDataset", (), {})

    fake_transforms = types.ModuleType("cosmos_framework.data.vfm.action.transforms")
    fake_transforms.ActionTransformPipeline = type("ActionTransformPipeline", (), {})

    monkeypatch.setitem(sys.modules, "torch", fake_torch)
    monkeypatch.setitem(sys.modules, "torch.utils", fake_torch_utils)
    monkeypatch.setitem(sys.modules, "torch.utils.data", fake_torch_utils_data)
    monkeypatch.setitem(sys.modules, "cosmos_framework.data.vfm.action.datasets", fake_datasets_package)
    monkeypatch.setitem(
        sys.modules, "cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset", fake_droid_dataset
    )
    monkeypatch.setitem(sys.modules, "cosmos_framework.data.vfm.action.transforms", fake_transforms)
    # Force a fresh import against the stubs. Going through monkeypatch (rather
    # than a bare ``pop``) means a previously imported real module is restored
    # on teardown instead of being silently discarded.
    monkeypatch.delitem(sys.modules, module_name, raising=False)

    module = importlib.import_module(module_name)
    yield module

    # Drop the copy that was imported against the stubs. This runs before
    # monkeypatch's own teardown, which then reinstates any prior entry.
    sys.modules.pop(module_name, None)


def test_action_iterable_shuffle_dataset_raises_when_shuffle_blocks_are_empty(action_sft_dataset_module) -> None:
    """Iterating over a source that reports no shuffle blocks raises ``ValueError``."""

    class _EmptySource:
        """Minimal stand-in source: zero length and an empty shuffle-block list."""

        def __len__(self):
            return 0

        def get_shuffle_blocks(self):
            return []

    streaming_cls = action_sft_dataset_module.ActionIterableShuffleDataset
    streaming = streaming_cls(_EmptySource())

    # The error must surface by the time the first item is requested.
    with pytest.raises(ValueError, match="No shuffle blocks found"):
        iterator = iter(streaming)
        next(iterator)