diff --git a/test/test_functions.py b/test/test_functions.py index 0ac7c51..5f830ea 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -80,10 +80,30 @@ def test_bernoulli_inplace(self): def test_flatten(self): with self.env: - a = torch.randn((2, 3, 4)) + a = torch.randn((2, 3, 4)).to(device="jax") a = a.flatten(0, 1) self.assertEqual(tuple(a.shape), (6, 4)) + # New test case for testing tensor flattening on zero dimension + a = torch.ones((16, 0)).to(device="jax") + a = a.flatten(0, -2) + self.assertEqual(a.shape, torch.Size([16, 0])) + + # Flattening tensors with zero dimension containing in flattening dimension + a = torch.randn((2, 1, 0, 5)).to(device="jax") + a = a.flatten(2, 3) + self.assertEqual(a.shape, torch.Size([2, 1, 0])) + + # Flattening tensors with zero dimension containing in non-flattening dimension + a = torch.randn((2, 0, 1, 5)).to(device="jax") + a = a.flatten(2, 3) + self.assertEqual(a.shape, torch.Size([2, 0, 5])) + + # Flattening tensors with zero dimension containing in flattening dimension and non-flattening dimension + a = torch.randn((2, 0, 0, 5)).to(device="jax") + a = a.flatten(2, 3) + self.assertEqual(a.shape, torch.Size([2, 0, 0])) + def test_rnn(self): model = SeqModel() x = torch.randn((2, 100, 20)) diff --git a/torchax/tensor.py b/torchax/tensor.py index 441ac82..973db82 100644 --- a/torchax/tensor.py +++ b/torchax/tensor.py @@ -15,6 +15,7 @@ import contextlib import itertools import logging +import math import sys import threading from collections.abc import Callable @@ -96,7 +97,10 @@ def ndim(self): def flatten(self, start_dim=0, end_dim=-1): if end_dim == -1: end_dim = self.ndim - new_shape = self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :] + flattened_size = math.prod(self._elem.shape[start_dim : end_dim + 1]) + new_shape = ( + self._elem.shape[:start_dim] + (flattened_size,) + self._elem.shape[end_dim + 1 :] + ) new_elem = jnp.reshape(self._elem, new_shape) return Tensor(new_elem, self._env) # return torch.reshape(self, new_shape)