diff --git a/processor/flow.py b/processor/flow.py index b064d4c..2f5d369 100644 --- a/processor/flow.py +++ b/processor/flow.py @@ -15,6 +15,7 @@ """Flow field estimation from SOFIMA.""" import dataclasses +import gc import time from typing import Any, Sequence @@ -594,6 +595,7 @@ def __init__( ) self._config = config + logging.info('EstimateMissingFlow running with config: %r', config) def _build_mask( self, @@ -661,6 +663,20 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: out_box = out_box.adjusted_by(end=-offset) input_ndarray = input_ndarray[:, :, : out_box.size[1], : out_box.size[0]] + # The input flow forms the initial state of the output. We will try + # to fill-in any invalid (NaN) pixels by computing flow against + # earlier sections. + ret = np.zeros([3] + list(out_box.size[::-1])) + ret[:2, ...] = input_ndarray + ret[2, ...] = self._config.delta_z + + sel_mask = None + if self._config.selection_mask_configs: + sel_mask = self._build_mask(self._config.selection_mask_configs, out_box) + + mfc = flow_field.JAXMaskedXCorrWithStatsCalculator() + invalid = np.isnan(input_ndarray[0, ...]) + patch_size = self._config.patch_size curr_image_box = bounding_box.BoundingBox( start=( @@ -671,25 +687,55 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: size=( (out_box.size[0] - 1) * stride + patch_size, (out_box.size[1] - 1) * stride + patch_size, - 1, + invalid.shape[0], ), ) curr_image_box = image_volume.clip_box_to_volume(curr_image_box) assert curr_image_box is not None - # The input flow forms the initial state of the output. We will try - # to fill-in any invalid (NaN) pixels by computing flow against - # earlier sections. - ret = np.zeros([3] + list(out_box.size[::-1])) - ret[:2, ...] = input_ndarray - ret[2, ...] 
= self._config.delta_z + if self._config.delta_z > 0: + search_deltas = range( + self._config.delta_z + 1, self._config.max_delta_z + 1 + ) + load_start_z = out_box.start[2] - self._config.max_delta_z + load_end_z = out_box.end[2] + else: + search_deltas = range( + self._config.delta_z - 1, self._config.max_delta_z - 1, -1 + ) + load_start_z = out_box.start[2] + # max_delta_z is negative. + load_end_z = out_box.end[2] - self._config.max_delta_z - sel_mask = None - if self._config.selection_mask_configs: - sel_mask = self._build_mask(self._config.selection_mask_configs, out_box) + load_box = bounding_box.BoundingBox( + start=( + prev_image_box.start[0], + prev_image_box.start[1], + load_start_z, + ), + size=( + prev_image_box.size[0], + prev_image_box.size[1], + load_end_z - load_start_z, + ), + ) + load_box = image_volume.clip_box_to_volume(load_box) + + logging.info('Loading image data: %r', load_box) + full_image_stack = image_volume.asarray[load_box.to_slice4d()][0, ...] + full_mask = None + if self._config.mask_configs: + full_mask = self._build_mask(self._config.mask_configs, load_box) + logging.info('Loaaded mask: %r', full_mask.shape) + + # The 'curr' image is a subset of the loaded stack, centered within the + # 'prev' image (which includes the search radius). 
+ curr_rel_start = curr_image_box.start - load_box.start + curr_slice = ( + slice(curr_rel_start[1], curr_rel_start[1] + curr_image_box.size[1]), + slice(curr_rel_start[0], curr_rel_start[0] + curr_image_box.size[0]), + ) - mfc = flow_field.JAXMaskedXCorrWithStatsCalculator() - invalid = np.isnan(input_ndarray[0, ...]) for z in range(0, invalid.shape[0]): z0 = box.start[2] + z logging.info('Processing rel_z=%d abs_z=%d', z, z0) @@ -698,12 +744,13 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: beam_utils.counter(namespace, 'sections-already-valid').inc() continue - image_box = curr_image_box.translate([0, 0, z]) + curr_z_idx = (out_box.start[2] + z) - load_box.start[2] + assert curr_z_idx >= 0 + assert curr_z_idx < full_image_stack.shape[0] + curr_mask = None if self._config.mask_configs: - curr_mask = self._build_mask( - self._config.mask_configs, image_box - ).squeeze() + curr_mask = full_mask[curr_z_idx, ...][curr_slice] if np.all(curr_mask): beam_utils.counter(namespace, 'sections-masked').inc() continue @@ -715,37 +762,23 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: if sel_mask is not None: mask &= sel_mask[z, ...] - curr = image_volume.asarray[image_box.to_slice4d()].squeeze() - - delta_z = self._config.delta_z - if delta_z > 0: - rng = range(delta_z + 1, self._config.max_delta_z + 1) - else: - rng = range(delta_z - 1, self._config.max_delta_z - 1, -1) + curr = full_image_stack[curr_z_idx, ...][curr_slice] - for delta_z in rng: - if ( - box.start[2] - delta_z < 0 - or box.end[2] - delta_z >= image_volume.volume_size[2] - ): + for delta_z in search_deltas: + prev_z_idx = curr_z_idx - delta_z + if prev_z_idx < 0 or prev_z_idx >= full_image_stack.shape[0]: break t_start = time.time() - prev_box = prev_image_box.translate([0, 0, z - delta_z]) - logging.info('Trying delta_z=%d (%r)', delta_z, prev_box) - prev = image_volume.asarray[prev_box.to_slice4d()].squeeze() - logging.info('.. 
image loaded.') + logging.info('Trying delta_z=%d', delta_z) + prev_mask = None + prev = full_image_stack[prev_z_idx, ...] t1 = time.time() if self._config.mask_configs: - prev_mask = self._build_mask( - self._config.mask_configs, prev_box - ).squeeze() + prev_mask = full_mask[prev_z_idx, ...] if np.all(prev_mask): continue - else: - prev_mask = None - logging.info('.. mask loaded.') # Limit the number of estimation attempts per voxel. Attempts # are only counted when voxels in both sections are unmasked. @@ -804,4 +837,8 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: t5 - t4, ) + del full_image_stack + del full_mask + gc.collect() + return Subvolume(ret, out_box) diff --git a/processor/flow_test.py b/processor/flow_test.py new file mode 100644 index 0000000..7a6c499 --- /dev/null +++ b/processor/flow_test.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# Copyright 2026 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

from absl.testing import absltest
from connectomics.common import bounding_box
from connectomics.volume import subvolume
import numpy as np
from sofima.processor import flow

# Fixed seed so the synthetic image content (and hence the estimated flow) is
# the same on every run.
_SEED = 0

# Shape of the synthetic image volume, in CZYX order.
_VOL_SHAPE = (1, 10, 128, 128)


def _make_config(max_delta_z):
  """Builds the processor config shared by the tests in this module.

  Args:
    max_delta_z: value for the `max_delta_z` config field; it is the only
      setting that differs between the tests.

  Returns:
    A `flow.EstimateMissingFlow.Config` instance.
  """
  return flow.EstimateMissingFlow.Config(
      patch_size=16,
      stride=16,
      delta_z=1,
      max_delta_z=max_delta_z,
      max_attempts=1,
      mask_configs=None,
      mask_only_for_patch_selection=False,
      selection_mask_configs=None,
      min_peak_sharpness=0.0,
      min_peak_ratio=0.0,
      max_magnitude=0,
      batch_size=10,  # Must be > 0 for batch processing
      image_volinfo="dummy_path",
      image_cache_bytes=0,
      mask_cache_bytes=0,
      search_radius=16,
  )


class MockVolume:
  """In-memory stand-in for an image volume, backed by a CZYX array."""

  def __init__(self, data):
    self._data = data  # CZYX

  def clip_box_to_volume(self, box):
    """Returns the intersection of `box` with the full extent of the volume."""
    vol_box = bounding_box.BoundingBox(start=(0, 0, 0), size=self.volume_size)
    return box.intersection(vol_box)

  @property
  def asarray(self):
    """The backing array itself (no copy is made)."""
    return self._data

  @property
  def volume_size(self):
    """Volume size as an XYZ tuple, derived from the CZYX array shape."""
    # XYZ
    return (self._data.shape[3], self._data.shape[2], self._data.shape[1])

  def __getitem__(self, key):
    return self._data[key]


class TestEstimateMissingFlow(flow.EstimateMissingFlow):
  """EstimateMissingFlow variant that serves images from an injected volume."""

  # The class name matches pytest's default `Test*` collection pattern, but
  # this is a helper and not a test case; opt out of collection explicitly.
  __test__ = False

  def __init__(self, config, image_vol):
    super().__init__(config)
    self.image_vol = image_vol

  def _open_volume(self, path):
    # Ignores `path` and always hands back the injected volume.
    # NOTE(review): presumably the base class calls this hook to open the
    # image volume -- the caller is not visible from this file.
    return self.image_vol


class EstimateMissingFlowTest(absltest.TestCase):
  """Tests for EstimateMissingFlow.process on small synthetic volumes."""

  def test_process(self):
    """Flow for a fully-invalid input is recovered from an earlier section."""
    config = _make_config(max_delta_z=2)
    rng = np.random.default_rng(_SEED)

    # Larger volume to avoid boundary clipping with required context size
    vol_data = rng.random(_VOL_SHAPE).astype(np.float32)

    # Create a synthetic shift between z=3 and z=5.
    dx, dy = 2, 3
    prev_slice = vol_data[0, 3, :, :]
    shifted_slice = np.zeros_like(prev_slice)
    shifted_slice[dy:, dx:] = prev_slice[:-dy, :-dx]
    # Fill the border uncovered by the shift with fresh random content.
    shifted_slice[:dy, :] = rng.random((dy, 128))
    shifted_slice[:, :dx] = rng.random((128, dx))

    vol_data[0, 5, :, :] = shifted_slice

    mock_vol = MockVolume(vol_data)
    processor = TestEstimateMissingFlow(config, mock_vol)

    # Start at 2,2,5 (flow coords) corresponds to 32,32,5 (image coords).
    box = bounding_box.BoundingBox((2, 2, 5), (2, 2, 1))

    # No pre-existing flow data.
    input_data = np.full((2, 1, 2, 2), np.nan, dtype=np.float32)
    subvol = subvolume.Subvolume(input_data, box)

    result_subvol = processor.process(subvol)

    self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2))
    self.assertFalse(
        np.any(np.isnan(result_subvol.data)), "Result contains NaNs"
    )

    np.testing.assert_allclose(
        result_subvol.data[2, ...], 2, err_msg="delta_z incorrect"
    )
    np.testing.assert_allclose(
        result_subvol.data[0, 0, 0, 0],
        -dx,
        atol=0.5,
        err_msg="Flow X incorrect",
    )
    np.testing.assert_allclose(
        result_subvol.data[1, 0, 0, 0],
        -dy,
        atol=0.5,
        err_msg="Flow Y incorrect",
    )

  def test_process_clipped_context(self):
    """Flow stays NaN when every candidate earlier section is out of bounds."""
    config = _make_config(max_delta_z=5)  # Large lookback
    rng = np.random.default_rng(_SEED)

    vol_data = rng.random(_VOL_SHAPE).astype(np.float32)

    mock_vol = MockVolume(vol_data)
    processor = TestEstimateMissingFlow(config, mock_vol)

    box = bounding_box.BoundingBox(start=(2, 2, 1), size=(2, 2, 1))

    # No pre-existing flow data.
    input_data = np.full((2, 1, 2, 2), np.nan, dtype=np.float32)
    subvol = subvolume.Subvolume(input_data, box)

    result_subvol = processor.process(subvol)

    self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2))

    # Result should be NaNs because z=1 only has z=0 as valid prev.
    # delta_z=1 (matching z=0) was not calculated (assumed missing).
    # delta_z=2,3,4,5 look at z < 0, which is out of bounds.
    self.assertTrue(
        np.all(np.isnan(result_subvol.data[0, ...])), "Result X should be NaN"
    )
    self.assertTrue(
        np.all(np.isnan(result_subvol.data[1, ...])), "Result Y should be NaN"
    )
    # Channel 2 is initialized to delta_z (1).
    self.assertEqual(result_subvol.data[2, 0, 0, 0], 1)


if __name__ == "__main__":
  absltest.main()