Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 75 additions & 38 deletions processor/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Flow field estimation from SOFIMA."""

import dataclasses
import gc
import time
from typing import Any, Sequence

Expand Down Expand Up @@ -594,6 +595,7 @@ def __init__(
)

self._config = config
logging.info('EstimateMissingFlow running with config: %r', config)

def _build_mask(
self,
Expand Down Expand Up @@ -661,6 +663,20 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
out_box = out_box.adjusted_by(end=-offset)
input_ndarray = input_ndarray[:, :, : out_box.size[1], : out_box.size[0]]

# The input flow forms the initial state of the output. We will try
# to fill-in any invalid (NaN) pixels by computing flow against
# earlier sections.
ret = np.zeros([3] + list(out_box.size[::-1]))
ret[:2, ...] = input_ndarray
ret[2, ...] = self._config.delta_z

sel_mask = None
if self._config.selection_mask_configs:
sel_mask = self._build_mask(self._config.selection_mask_configs, out_box)

mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
invalid = np.isnan(input_ndarray[0, ...])

patch_size = self._config.patch_size
curr_image_box = bounding_box.BoundingBox(
start=(
Expand All @@ -671,25 +687,55 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
size=(
(out_box.size[0] - 1) * stride + patch_size,
(out_box.size[1] - 1) * stride + patch_size,
1,
invalid.shape[0],
),
)
curr_image_box = image_volume.clip_box_to_volume(curr_image_box)
assert curr_image_box is not None

# The input flow forms the initial state of the output. We will try
# to fill-in any invalid (NaN) pixels by computing flow against
# earlier sections.
ret = np.zeros([3] + list(out_box.size[::-1]))
ret[:2, ...] = input_ndarray
ret[2, ...] = self._config.delta_z
if self._config.delta_z > 0:
search_deltas = range(
self._config.delta_z + 1, self._config.max_delta_z + 1
)
load_start_z = out_box.start[2] - self._config.max_delta_z
load_end_z = out_box.end[2]
else:
search_deltas = range(
self._config.delta_z - 1, self._config.max_delta_z - 1, -1
)
load_start_z = out_box.start[2]
# max_delta_z is negative.
load_end_z = out_box.end[2] - self._config.max_delta_z

sel_mask = None
if self._config.selection_mask_configs:
sel_mask = self._build_mask(self._config.selection_mask_configs, out_box)
load_box = bounding_box.BoundingBox(
start=(
prev_image_box.start[0],
prev_image_box.start[1],
load_start_z,
),
size=(
prev_image_box.size[0],
prev_image_box.size[1],
load_end_z - load_start_z,
),
)
load_box = image_volume.clip_box_to_volume(load_box)

logging.info('Loading image data: %r', load_box)
full_image_stack = image_volume.asarray[load_box.to_slice4d()][0, ...]
full_mask = None
if self._config.mask_configs:
full_mask = self._build_mask(self._config.mask_configs, load_box)
logging.info('Loaaded mask: %r', full_mask.shape)

# The 'curr' image is a subset of the loaded stack, centered within the
# 'prev' image (which includes the search radius).
curr_rel_start = curr_image_box.start - load_box.start
curr_slice = (
slice(curr_rel_start[1], curr_rel_start[1] + curr_image_box.size[1]),
slice(curr_rel_start[0], curr_rel_start[0] + curr_image_box.size[0]),
)

mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
invalid = np.isnan(input_ndarray[0, ...])
for z in range(0, invalid.shape[0]):
z0 = box.start[2] + z
logging.info('Processing rel_z=%d abs_z=%d', z, z0)
Expand All @@ -698,12 +744,13 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
beam_utils.counter(namespace, 'sections-already-valid').inc()
continue

image_box = curr_image_box.translate([0, 0, z])
curr_z_idx = (out_box.start[2] + z) - load_box.start[2]
assert curr_z_idx >= 0
assert curr_z_idx < full_image_stack.shape[0]

curr_mask = None
if self._config.mask_configs:
curr_mask = self._build_mask(
self._config.mask_configs, image_box
).squeeze()
curr_mask = full_mask[curr_z_idx, ...][curr_slice]
if np.all(curr_mask):
beam_utils.counter(namespace, 'sections-masked').inc()
continue
Expand All @@ -715,37 +762,23 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
if sel_mask is not None:
mask &= sel_mask[z, ...]

curr = image_volume.asarray[image_box.to_slice4d()].squeeze()

delta_z = self._config.delta_z
if delta_z > 0:
rng = range(delta_z + 1, self._config.max_delta_z + 1)
else:
rng = range(delta_z - 1, self._config.max_delta_z - 1, -1)
curr = full_image_stack[curr_z_idx, ...][curr_slice]

for delta_z in rng:
if (
box.start[2] - delta_z < 0
or box.end[2] - delta_z >= image_volume.volume_size[2]
):
for delta_z in search_deltas:
prev_z_idx = curr_z_idx - delta_z
if prev_z_idx < 0 or prev_z_idx >= full_image_stack.shape[0]:
break

t_start = time.time()
prev_box = prev_image_box.translate([0, 0, z - delta_z])
logging.info('Trying delta_z=%d (%r)', delta_z, prev_box)
prev = image_volume.asarray[prev_box.to_slice4d()].squeeze()
logging.info('.. image loaded.')
logging.info('Trying delta_z=%d', delta_z)
prev_mask = None
prev = full_image_stack[prev_z_idx, ...]
t1 = time.time()

if self._config.mask_configs:
prev_mask = self._build_mask(
self._config.mask_configs, prev_box
).squeeze()
prev_mask = full_mask[prev_z_idx, ...]
if np.all(prev_mask):
continue
else:
prev_mask = None
logging.info('.. mask loaded.')

# Limit the number of estimation attempts per voxel. Attempts
# are only counted when voxels in both sections are unmasked.
Expand Down Expand Up @@ -804,4 +837,8 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany:
t5 - t4,
)

del full_image_stack
del full_mask
gc.collect()

return Subvolume(ret, out_box)
174 changes: 174 additions & 0 deletions processor/flow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# coding=utf-8
# Copyright 2026 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from connectomics.common import bounding_box
from connectomics.volume import subvolume
import numpy as np
from sofima.processor import flow


class MockVolume:
  """Minimal in-memory stand-in for an image volume, backed by one array.

  The backing array is laid out as CZYX. Only the members used by the tests
  in this file are provided.
  """

  def __init__(self, data):
    # Backing array, CZYX axis order.
    self._data = data

  def __getitem__(self, key):
    """Indexes directly into the backing array."""
    return self._data[key]

  @property
  def asarray(self):
    """The backing array itself (no copy is made)."""
    return self._data

  @property
  def volume_size(self):
    """Spatial extent as an (x, y, z) tuple, taken from the array shape."""
    dims = self._data.shape
    return (dims[3], dims[2], dims[1])

  def clip_box_to_volume(self, box):
    """Returns the intersection of `box` with the full extent of the volume."""
    full_extent = bounding_box.BoundingBox(
        start=(0, 0, 0), size=self.volume_size
    )
    return box.intersection(full_extent)


class TestEstimateMissingFlow(flow.EstimateMissingFlow):
  """EstimateMissingFlow variant that serves images from an injected volume.

  `_open_volume` is overridden to ignore the requested path and hand back the
  volume given at construction time, so no storage backend is needed.
  NOTE(review): presumably the base class calls `_open_volume` to open
  `config.image_volinfo` -- confirm against flow.py.
  """

  # The class name matches pytest's default `Test*` collection pattern, but
  # this is a helper with a custom constructor, not a test case. Opt out of
  # collection explicitly to avoid a collection warning.
  __test__ = False

  def __init__(self, config, image_vol):
    """Initializes the processor.

    Args:
      config: configuration forwarded unchanged to EstimateMissingFlow.
      image_vol: volume object returned by every `_open_volume` call.
    """
    super().__init__(config)
    self.image_vol = image_vol

  def _open_volume(self, path):
    """Returns the injected volume regardless of `path`."""
    del path  # Unused: the volume is injected, not opened.
    return self.image_vol


class EstimateMissingFlowTest(absltest.TestCase):
  """Exercises EstimateMissingFlow.process against a small in-memory volume."""

  def setUp(self):
    super().setUp()
    # Seeded generator: the synthetic images are identical on every run, so
    # any failure is reproducible instead of depending on global RNG state.
    self._rng = np.random.default_rng(0)

  def _make_config(self, max_delta_z):
    """Returns the config shared by the tests, varying only `max_delta_z`."""
    return flow.EstimateMissingFlow.Config(
        patch_size=16,
        stride=16,
        delta_z=1,
        max_delta_z=max_delta_z,
        max_attempts=1,
        mask_configs=None,
        mask_only_for_patch_selection=False,
        selection_mask_configs=None,
        min_peak_sharpness=0.0,
        min_peak_ratio=0.0,
        max_magnitude=0,
        batch_size=10,  # Must be > 0 for batch processing
        image_volinfo="dummy_path",
        image_cache_bytes=0,
        mask_cache_bytes=0,
        search_radius=16,
    )

  def _random_image(self, shape):
    """Returns reproducible float32 noise of the given shape."""
    return self._rng.random(shape, dtype=np.float32)

  def _process_without_flow(self, config, vol_data, box):
    """Runs the processor on `box` with an all-NaN (missing) input flow."""
    processor = TestEstimateMissingFlow(config, MockVolume(vol_data))
    # No pre-existing flow data.
    input_data = np.full((2, 1, 2, 2), np.nan, dtype=np.float32)
    return processor.process(subvolume.Subvolume(input_data, box))

  def test_process(self):
    # Larger volume to avoid boundary clipping with required context size.
    vol_data = self._random_image((1, 10, 128, 128))

    # Create a synthetic shift between z=3 and z=5.
    dx, dy = 2, 3
    prev_slice = vol_data[0, 3, :, :]
    height, width = prev_slice.shape
    shifted_slice = np.zeros_like(prev_slice)
    shifted_slice[dy:, dx:] = prev_slice[:-dy, :-dx]
    # Fill the border uncovered by the shift with fresh noise.
    shifted_slice[:dy, :] = self._random_image((dy, width))
    shifted_slice[:, :dx] = self._random_image((height, dx))
    vol_data[0, 5, :, :] = shifted_slice

    # Start at 2,2,5 (flow coords) corresponds to 32,32,5 (image coords).
    box = bounding_box.BoundingBox((2, 2, 5), (2, 2, 1))

    result_subvol = self._process_without_flow(
        self._make_config(max_delta_z=2), vol_data, box
    )

    self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2))
    self.assertFalse(
        np.any(np.isnan(result_subvol.data)), "Result contains NaNs"
    )

    np.testing.assert_allclose(
        result_subvol.data[2, ...], 2, err_msg="delta_z incorrect"
    )
    np.testing.assert_allclose(
        result_subvol.data[0, 0, 0, 0],
        -dx,
        atol=0.5,
        err_msg="Flow X incorrect",
    )
    np.testing.assert_allclose(
        result_subvol.data[1, 0, 0, 0],
        -dy,
        atol=0.5,
        err_msg="Flow Y incorrect",
    )

  def test_process_clipped_context(self):
    vol_data = self._random_image((1, 10, 128, 128))
    box = bounding_box.BoundingBox(start=(2, 2, 1), size=(2, 2, 1))

    # Large lookback: max_delta_z reaches well before the start of the volume.
    result_subvol = self._process_without_flow(
        self._make_config(max_delta_z=5), vol_data, box
    )

    self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2))

    # Result should be NaNs because z=1 only has z=0 as valid prev.
    # delta_z=1 (matching z=0) was not calculated (assumed missing).
    # delta_z=2,3,4,5 look at z < 0, which is out of bounds.
    self.assertTrue(
        np.all(np.isnan(result_subvol.data[0, ...])), "Result X should be NaN"
    )
    self.assertTrue(
        np.all(np.isnan(result_subvol.data[1, ...])), "Result Y should be NaN"
    )
    # Channel 2 is initialized to delta_z (1).
    self.assertEqual(result_subvol.data[2, 0, 0, 0], 1)


# Hand control to the absltest runner only when executed as a script.
if __name__ == "__main__":
  absltest.main()