From d94bf622d5f9225edd00a41f94b942e2095d8959 Mon Sep 17 00:00:00 2001 From: eperot Date: Tue, 21 Oct 2025 17:08:32 +0200 Subject: [PATCH 1/5] clean --- keymorph/keypoint_aligners.py | 7 +++++-- keymorph/transformations.py | 7 +++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/keymorph/keypoint_aligners.py b/keymorph/keypoint_aligners.py index 6c45fa1..8cf2307 100644 --- a/keymorph/keypoint_aligners.py +++ b/keymorph/keypoint_aligners.py @@ -396,6 +396,7 @@ def get_flow_field( # See make_base_grid_5d() in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/AffineGridGenerator.cpp return transformed_grid.flip(-1) + @torch.compile() def transform_points(self, theta, ctrl, points): """Evaluate the thin-plate-spline (TPS) surface at xy locations arranged in a grid. The TPS surface is a minimum bend interpolation surface defined by a set of control points. @@ -428,8 +429,10 @@ def transform_points(self, theta, ctrl, points): P[:, :, 1:] = points[:, :, : self.dim] # U is NxHxWxT - b = torch.bmm(U.transpose(1, 2), weights) - z = torch.bmm(P.view(N, -1, self.dim + 1), affine) + # b = torch.bmm(U.transpose(1, 2), weights) + # z = torch.bmm(P.view(N, -1, self.dim + 1), affine) + b = torch.einsum('btp,btd->bpd', U, weights) + z = torch.einsum('bpd,bda->bpa', P, affine) return z + b def get_inverse_transformed_points(self, points): diff --git a/keymorph/transformations.py b/keymorph/transformations.py index e54647c..06484f3 100644 --- a/keymorph/transformations.py +++ b/keymorph/transformations.py @@ -28,9 +28,9 @@ def __init__( self.transform_matrix = torch.inverse(inverse_matrix) else: raise ValueError("Only one of matrix or inverse_matrix should be provided") - def _square(self, matrix): - square = torch.eye(self.dim + 1)[None] + batch_size = matrix.shape[0] + square = torch.eye(self.dim + 1).unsqueeze(0).repeat(batch_size, 1, 1) square[:, : self.dim, : self.dim + 1] = matrix return square @@ -109,6 +109,5 @@ def get_inverse_transformed_points(self, points): # Convert to homogeneous coordinates ones = torch.ones(batch_size, num_points, 1).to(points.device) points = torch.cat([points, ones], dim=2) - points = torch.bmm(transform_matrix, points.permute(0, 2, 1)).permute(0, 2, 1) - + points = torch.einsum('brc,bpc->bpr', transform_matrix, points) return points From 8330e06848e870dc9fd4e85587ba7d45fca5e346 Mon Sep 17 00:00:00 2001 From: eperot Date: Wed, 22 Oct 2025 11:06:35 +0200 Subject: [PATCH 2/5] forgot to push this --- keymorph/transformations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keymorph/transformations.py b/keymorph/transformations.py index 06484f3..d7b1c3a 100644 --- a/keymorph/transformations.py +++ b/keymorph/transformations.py @@ -53,7 +53,7 @@ def affine_grid(self, grid_shape): moving_voxel_coords = self.get_inverse_transformed_points(grid_flat) - transformed_grid = moving_voxel_coords.reshape(1, *grid_shape[2:], self.dim) + transformed_grid = moving_voxel_coords.reshape(-1, *grid_shape[2:], self.dim) return transformed_grid From 96292dab0ca18b98eaa63c9a52998f598812c32a Mon Sep 17 00:00:00 2001 From: eperot Date: Thu, 23 Oct 2025 15:52:33 +0200 Subject: [PATCH 3/5] use pinv, more stable than inv --- keymorph/keypoint_aligners.py | 7 ++++--- keymorph/transformations.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/keymorph/keypoint_aligners.py b/keymorph/keypoint_aligners.py index 8cf2307..5808aa5 100644 --- a/keymorph/keypoint_aligners.py +++ b/keymorph/keypoint_aligners.py @@ -104,7 +104,8 @@ def fit(self, x, y, w=None): out = torch.bmm(out, torch.transpose(x, -2, -1)) else: out = torch.bmm(x, torch.transpose(x, -2, -1)) - inv = torch.inverse(out) + # inv = torch.linalg.pinv(out) + inv = torch.linalg.pinv(out) if w is not None: out = torch.bmm(w, torch.transpose(x, -2, -1)) out = torch.bmm(out, inv) @@ -697,8 +698,8 @@ def get_forward_transformed_points(self, points): # perm_mat = perm_mat[None, [0, 2, 1, 3], :] # 012, 021, 102, 120, 201, 210 # # Calculate the overall transformation matrix from moving to fixed image space -# overall_affine = torch.bmm(rescale_voxel2norm, torch.inverse(moving_affine)) -# overall_affine = torch.bmm(overall_affine, torch.inverse(registration_affine)) +# overall_affine = torch.bmm(rescale_voxel2norm, torch.linalg.pinv(moving_affine)) +# overall_affine = torch.bmm(overall_affine, torch.linalg.pinv(registration_affine)) # overall_affine = torch.bmm(overall_affine, fixed_affine) # overall_affine = torch.bmm(overall_affine, rescale_norm2voxel) diff --git a/keymorph/transformations.py b/keymorph/transformations.py index d7b1c3a..634a2d7 100644 --- a/keymorph/transformations.py +++ b/keymorph/transformations.py @@ -22,10 +22,10 @@ def __init__( self.dim = dim if matrix is not None and inverse_matrix is None: self.transform_matrix = matrix - self.inverse_transform_matrix = torch.inverse(matrix) + self.inverse_transform_matrix = torch.linalg.pinb(matrix) elif matrix is None and inverse_matrix is not None: self.inverse_transform_matrix = inverse_matrix - self.transform_matrix = torch.inverse(inverse_matrix) + self.transform_matrix = torch.linalg.pinv(inverse_matrix) else: raise ValueError("Only one of matrix or inverse_matrix should be provided") def _square(self, matrix): From 1b869f7a24f8bc51fed4ec17d2022eefd3ecc0c7 Mon Sep 17 00:00:00 2001 From: eperot Date: Sun, 26 Oct 2025 14:03:55 +0100 Subject: [PATCH 4/5] typo pinv --- keymorph/transformations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keymorph/transformations.py b/keymorph/transformations.py index 634a2d7..ad45782 100644 --- a/keymorph/transformations.py +++ b/keymorph/transformations.py @@ -22,7 +22,7 @@ def __init__( self.dim = dim if matrix is not None and inverse_matrix is None: self.transform_matrix = matrix - self.inverse_transform_matrix = torch.linalg.pinb(matrix) + self.inverse_transform_matrix = torch.linalg.pinv(matrix) elif matrix is None and inverse_matrix is not None: self.inverse_transform_matrix = inverse_matrix self.transform_matrix = torch.linalg.pinv(inverse_matrix) From c62c09eaf95b2905845c9585096a946e249c51bb Mon Sep 17 00:00:00 2001 From: eperot Date: Thu, 13 Nov 2025 09:51:37 +0100 Subject: [PATCH 5/5] added obvious and working unit test --- test/test.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/test/test.py b/test/test.py index 66c0061..775817a 100644 --- a/test/test.py +++ b/test/test.py @@ -459,16 +459,32 @@ def test_affine_1(self): def test_affine_2(self): """Test 3d rotation around z-axis for 2d points plus scaling.""" - input1 = torch.tensor([[1, 0, 0], [0, -1, 0], [-1, 0, 0], [0, 1, 0]]).float() - input1 = input1.view(1, 4, 3) - input2 = torch.tensor([[0, -1, 0], [-1, 0, 0], [0, 1, 0], [1, 0, 0]]).float() - input2 = input2.view(1, 4, 3) + # Create a simple 3D tetrahedron (4 points minimum for 3D affine) + # Plus one more point to ensure unique solution + input1 = torch.tensor([[1, 0, 0], # Point on x-axis + [0, 1, 0], # Point on y-axis + [0, 0, 1], # Point on z-axis + [0, 0, 0], # Origin + [1, 1, 1]]).float() # Corner point + input1 = input1.view(1, 5, 3) + + # Apply a 60-degree (π/3) rotation around z-axis + # This is obvious: rotates x→between x&y, y→between y&-x + r = np.pi / 3 # 60 degrees + cos_r = np.cos(r) # = 0.5 + sin_r = np.sin(r) # = √3/2 ≈ 0.866 + + input2 = torch.tensor([[cos_r, sin_r, 0], # (1,0,0) rotated + [-sin_r, cos_r, 0], # (0,1,0) rotated + [0, 0, 1], # (0,0,1) unchanged by z-rotation + [0, 0, 0], # Origin unchanged + [cos_r - sin_r, sin_r + cos_r, 1]]).float() # (1,1,1) rotated + input2 = input2.view(1, 5, 3) - r = -np.pi / 2 true = torch.tensor( [ - [np.cos(r), -np.sin(r), 0, 0], - [np.sin(r), np.cos(r), 0, 0], + [cos_r, -sin_r, 0, 0], + [sin_r, cos_r, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], ]