Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/.beta.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
- python3.11
- python3.12
- python3.13
- python3.14
fail-fast: false
runs-on: linux-amd64-cpu16
defaults:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/.source.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ jobs:
- python3.11
- python3.12
- python3.13
- python3.14
fail-fast: false
runs-on: linux-amd64-cpu16
defaults:
Expand Down
1 change: 1 addition & 0 deletions .gitlab/pipelines/.beta.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Run E2E Tests:
- python3.11
- python3.12
- python3.13
- python3.14
script:
- |
nix develop '.#"'${PYTHON_BINARY}'"' --command bash -c "
Expand Down
1 change: 1 addition & 0 deletions .gitlab/pipelines/.source.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ multi-storage-client:
- python3.11
- python3.12
- python3.13
- python3.14
script:
- |
nix develop '.#"'${PYTHON_BINARY}'"' --command bash -c "
Expand Down
13 changes: 8 additions & 5 deletions multi-storage-client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13"
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14"
]
dependencies = [
"filelock>=3.20.3,<4",
Expand Down Expand Up @@ -98,7 +99,7 @@ oci = [
"oci>=2.169,<3"
]
parquet = [
"pyarrow>=21"
"pyarrow>=22"
]

# Higher-level libraries.
Expand All @@ -109,7 +110,8 @@ hydra-core = [
"hydra-core>=1.3,<2"
]
numpy = [
"numpy>=2.1,<3"
"numpy>=2.2.6,<3; python_version=='3.10'",
"numpy>=2.3.2,<3; python_version>='3.11'"
]
ray = [
"ray>=2.55,<3"
Expand All @@ -121,8 +123,9 @@ xarray = [
"xarray>=2024.07.0"
]
zarr = [
"zarr>=2.18.3,<3; python_version=='3.10'",
"zarr>=2.18.7,<3; python_version>='3.11'"
"zarr>=2.18.3,<4; python_version=='3.10'",
"zarr>=2.18.7,<4; python_version>='3.11' and python_version<='3.13'",
"zarr>=3.2,<4; python_version>='3.14'"
]

[dependency-groups]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import uuid

import pytest

# https://github.com/pytorch/pytorch/issues/131765
import torch
import torch.distributed.checkpoint as dcp

Expand All @@ -27,7 +29,7 @@
@pytest.fixture
def sample_data(tmp_path):
# Create a small tensor
tensor = torch.tensor([1, 2, 3, 4])
tensor = torch.tensor([1, 2, 3, 4]) # type: ignore[reportPrivateImportUsage]
# Create a filepath
filepath = tmp_path / "test.pt"
# Save the tensor to the file
Expand All @@ -39,29 +41,29 @@ def test_torch_load_with_filepath(sample_data):
filepath, expected_tensor = sample_data

result = msc.torch.load(str(filepath))
assert torch.equal(result, expected_tensor)
assert torch.equal(result, expected_tensor) # type: ignore[reportPrivateImportUsage]


def test_torch_load_with_msc_prefix(sample_data):
filepath, expected_tensor = sample_data

result = msc.torch.load(f"{MSC_PROTOCOL}__filesystem__{filepath}")
assert torch.equal(result, expected_tensor)
assert torch.equal(result, expected_tensor) # type: ignore[reportPrivateImportUsage]


def test_torch_save_with_msc_path(sample_data):
filepath, expected_tensor = sample_data

msc.torch.save(expected_tensor, msc.Path(filepath))
result = torch.load(filepath)
assert torch.equal(result, expected_tensor)
assert torch.equal(result, expected_tensor) # type: ignore[reportPrivateImportUsage]


def test_torch_load_with_msc_path(sample_data):
filepath, expected_tensor = sample_data

result = msc.torch.load(msc.Path(filepath))
assert torch.equal(result, expected_tensor)
assert torch.equal(result, expected_tensor) # type: ignore[reportPrivateImportUsage]


class SimpleModel(torch.nn.Module):
Expand Down Expand Up @@ -125,7 +127,7 @@ def test_filesystem_reader_writer(temp_data_store_type: type[tempdatastore.Tempo

# Compare each parameter tensor
for param_name in original_state_dict:
assert torch.equal(original_state_dict[param_name], loaded_state_dict_params[param_name]), (
assert torch.equal(original_state_dict[param_name], loaded_state_dict_params[param_name]), ( # type: ignore[reportPrivateImportUsage]
f"Parameter {param_name} does not match"
)

Expand Down Expand Up @@ -213,7 +215,7 @@ def test_torch_save_with_attributes(temp_data_store_type: type[tempdatastore.Tem

test_uuid = str(uuid.uuid4())
file_path = f"test-torch-attributes-{test_uuid}.pt"
tensor = torch.tensor([1, 2, 3, 4])
tensor = torch.tensor([1, 2, 3, 4]) # type: ignore[reportPrivateImportUsage]

test_attributes = {
"method": "torch.save",
Expand All @@ -227,7 +229,7 @@ def test_torch_save_with_attributes(temp_data_store_type: type[tempdatastore.Tem

# Verify content was written correctly
result = msc.torch.load(f"{MSC_PROTOCOL}test/{file_path}")
assert torch.equal(result, tensor)
assert torch.equal(result, tensor) # type: ignore[reportPrivateImportUsage]

# Verify attributes for storage providers that support metadata
if hasattr(temp_data_store, "_bucket_name"):
Expand All @@ -244,7 +246,7 @@ def test_torch_save_with_attributes(temp_data_store_type: type[tempdatastore.Tem
msc.torch.save(tensor, msc.Path(f"{MSC_PROTOCOL}test/{file_path2}"), attributes=test_attributes)

result = msc.torch.load(msc.Path(f"{MSC_PROTOCOL}test/{file_path2}"))
assert torch.equal(result, tensor)
assert torch.equal(result, tensor) # type: ignore[reportPrivateImportUsage]

finally:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import os
import sys
import tempfile

import pytest
Expand All @@ -32,8 +33,12 @@ def sample_zarr_data():
root = zarr.open(store_path, mode="w")
assert isinstance(root, zarr.Group)

array1 = root.create_dataset("array1", shape=(100, 100), dtype="int32")
array2 = root.create_dataset("array2", shape=(50, 50), dtype="float64")
if sys.version_info >= (3, 14):
array1 = root.create_array("array1", shape=(100, 100), dtype="int32")
array2 = root.create_array("array2", shape=(50, 50), dtype="float64")
else:
array1 = root.create_dataset("array1", shape=(100, 100), dtype="int32")
array2 = root.create_dataset("array2", shape=(50, 50), dtype="float64")
array1[:] = 1
array2[:] = 2.0

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
multi-storage-client,
python314,
}:
multi-storage-client.devShells.default.override {
pythonInterpreter = python314;
}
Loading
Loading