From e38ebd737277ba565506d654b0ceecaca0c267e0 Mon Sep 17 00:00:00 2001 From: Abe Levitan Date: Fri, 24 Oct 2025 15:11:46 +0200 Subject: [PATCH] Fix a high priority bug with the masking system which was introduced in the switch to reconstructors classes --- src/cdtools/reconstructors/base.py | 2 +- tests/models/test_fancy_ptycho.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cdtools/reconstructors/base.py b/src/cdtools/reconstructors/base.py index cf3de8c8..16668149 100644 --- a/src/cdtools/reconstructors/base.py +++ b/src/cdtools/reconstructors/base.py @@ -190,7 +190,7 @@ def closure(): sim_patterns = self.model.forward(*inp) # Calculate the loss - if hasattr(self, 'mask'): + if hasattr(self.model, 'mask'): loss = self.model.loss(pats, sim_patterns, mask=self.model.mask) diff --git a/tests/models/test_fancy_ptycho.py b/tests/models/test_fancy_ptycho.py index 8bc4d87e..02893f9f 100644 --- a/tests/models/test_fancy_ptycho.py +++ b/tests/models/test_fancy_ptycho.py @@ -52,6 +52,10 @@ def test_lab_ptycho(lab_ptycho_cxi, reconstruction_device, show_plot): print('\nTesting performance on the standard transmission ptycho dataset') dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(lab_ptycho_cxi) + # Test the masking system + dataset.mask[110:115,65:70] = 0 + dataset.patterns[...,~dataset.mask] = t.max(dataset.patterns) + model = cdtools.models.FancyPtycho.from_dataset( dataset, n_modes=3,