From 69c187c9841def9565a10cfca2a4d2967e8d180f Mon Sep 17 00:00:00 2001 From: Anderson Chiu Date: Wed, 6 May 2026 17:49:32 +0800 Subject: [PATCH] Fix torchax tensor flatten bug on empty tensors Signed-off-by: Anderson Chiu --- test/test_functions.py | 22 +++++++++++++++++++++- torchax/tensor.py | 6 +++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/test/test_functions.py b/test/test_functions.py index 0ac7c51..5f830ea 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -80,10 +80,30 @@ def test_bernoulli_inplace(self): def test_flatten(self): with self.env: - a = torch.randn((2, 3, 4)) + a = torch.randn((2, 3, 4)).to(device="jax") a = a.flatten(0, 1) self.assertEqual(tuple(a.shape), (6, 4)) + # New test case: flattening a tensor with a zero-sized dimension + a = torch.ones((16, 0)).to(device="jax") + a = a.flatten(0, -2) + self.assertEqual(a.shape, torch.Size([16, 0])) + + # Flattening tensors with a zero-sized dimension inside the flattened range + a = torch.randn((2, 1, 0, 5)).to(device="jax") + a = a.flatten(2, 3) + self.assertEqual(a.shape, torch.Size([2, 1, 0])) + + # Flattening tensors with a zero-sized dimension outside the flattened range + a = torch.randn((2, 0, 1, 5)).to(device="jax") + a = a.flatten(2, 3) + self.assertEqual(a.shape, torch.Size([2, 0, 5])) + + # Flattening tensors with zero-sized dimensions both inside and outside the flattened range + a = torch.randn((2, 0, 0, 5)).to(device="jax") + a = a.flatten(2, 3) + self.assertEqual(a.shape, torch.Size([2, 0, 0])) + def test_rnn(self): model = SeqModel() x = torch.randn((2, 100, 20)) diff --git a/torchax/tensor.py b/torchax/tensor.py index 441ac82..973db82 100644 --- a/torchax/tensor.py +++ b/torchax/tensor.py @@ -15,6 +15,7 @@ import contextlib import itertools import logging +import math import sys import threading from collections.abc import Callable @@ -96,7 +97,10 @@ def ndim(self): def flatten(self, start_dim=0, end_dim=-1): if end_dim == -1: 
end_dim = self.ndim - new_shape = self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :] + flattened_size = math.prod(self._elem.shape[start_dim : end_dim + 1]) + new_shape = ( + self._elem.shape[:start_dim] + (flattened_size,) + self._elem.shape[end_dim + 1 :] + ) new_elem = jnp.reshape(self._elem, new_shape) return Tensor(new_elem, self._env) # return torch.reshape(self, new_shape)