From 584152ef76754d32738c3e639078167c81acd9ac Mon Sep 17 00:00:00 2001 From: lhparker1 Date: Tue, 17 Jun 2025 14:50:32 -0400 Subject: [PATCH 1/2] massive update to subsampler --- aion/codecs/modules/subsampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aion/codecs/modules/subsampler.py b/aion/codecs/modules/subsampler.py index af66f7a..ae9b615 100644 --- a/aion/codecs/modules/subsampler.py +++ b/aion/codecs/modules/subsampler.py @@ -32,7 +32,7 @@ def _subsample_in(self, x, labels: Bool[torch.Tensor, " b c"]): # Normalize label_sizes = labels.sum(dim=1, keepdim=True) - scales = ((self.dim_in / label_sizes) ** 0.5).squeeze() + scales = ((self.dim_in / label_sizes) ** 0.5).squeeze(-1) # Apply linear layer return scales[:, None, None, None] * F.linear(x, self.weight, self.bias) From 0eb131b994029c992bcea9190e930459843da2d3 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Tue, 17 Jun 2025 21:49:45 +0200 Subject: [PATCH 2/2] adding new image codec --- tests/codecs/test_image_codec.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/codecs/test_image_codec.py b/tests/codecs/test_image_codec.py index f60d2d1..76fdb36 100644 --- a/tests/codecs/test_image_codec.py +++ b/tests/codecs/test_image_codec.py @@ -81,3 +81,26 @@ def test_hf_previous_predictions(data_dir): rtol=1e-3, atol=1e-4, ) + + +def test_batch_size_one(): + """Test ImageCodec with batch_size=1 to ensure subsampler works correctly.""" + codec = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image) + + # Test with batch_size=1 + batch_size = 1 + flux_tensor = torch.randn(batch_size, 4, 96, 96) + input_image_obj = Image( + flux=flux_tensor, + bands=["DES-G", "DES-R", "DES-I", "DES-Z"], + ) + + # This should not raise an error (previously failed due to squeeze() issue) + with torch.no_grad(): + encoded = codec.encode(input_image_obj) + decoded_image_obj = codec.decode( + encoded, bands=["DES-G", "DES-R", "DES-I", "DES-Z"] + ) + + assert isinstance(decoded_image_obj, Image) + assert decoded_image_obj.flux.shape == flux_tensor.shape