diff --git a/keymorph/keypoint_aligners.py b/keymorph/keypoint_aligners.py index 6c45fa1..5808aa5 100644 --- a/keymorph/keypoint_aligners.py +++ b/keymorph/keypoint_aligners.py @@ -104,7 +104,8 @@ def fit(self, x, y, w=None): out = torch.bmm(out, torch.transpose(x, -2, -1)) else: out = torch.bmm(x, torch.transpose(x, -2, -1)) - inv = torch.inverse(out) + # inv = torch.linalg.pinv(out) + inv = torch.linalg.pinv(out) if w is not None: out = torch.bmm(w, torch.transpose(x, -2, -1)) out = torch.bmm(out, inv) @@ -396,6 +397,7 @@ def get_flow_field( # See make_base_grid_5d() in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/AffineGridGenerator.cpp return transformed_grid.flip(-1) + @torch.compile() def transform_points(self, theta, ctrl, points): """Evaluate the thin-plate-spline (TPS) surface at xy locations arranged in a grid. The TPS surface is a minimum bend interpolation surface defined by a set of control points. @@ -428,8 +430,10 @@ def transform_points(self, theta, ctrl, points): P[:, :, 1:] = points[:, :, : self.dim] # U is NxHxWxT - b = torch.bmm(U.transpose(1, 2), weights) - z = torch.bmm(P.view(N, -1, self.dim + 1), affine) + # b = torch.bmm(U.transpose(1, 2), weights) + # z = torch.bmm(P.view(N, -1, self.dim + 1), affine) + b = torch.einsum('btp,btd->bpd', U, weights) + z = torch.einsum('bpd,bda->bpa', P, affine) return z + b def get_inverse_transformed_points(self, points): @@ -694,8 +698,8 @@ def get_forward_transformed_points(self, points): # perm_mat = perm_mat[None, [0, 2, 1, 3], :] # 012, 021, 102, 120, 201, 210 # # Calculate the overall transformation matrix from moving to fixed image space -# overall_affine = torch.bmm(rescale_voxel2norm, torch.inverse(moving_affine)) -# overall_affine = torch.bmm(overall_affine, torch.inverse(registration_affine)) +# overall_affine = torch.bmm(rescale_voxel2norm, torch.linalg.pinv(moving_affine)) +# overall_affine = torch.bmm(overall_affine, torch.linalg.pinv(registration_affine)) # overall_affine = torch.bmm(overall_affine, fixed_affine) # overall_affine = torch.bmm(overall_affine, rescale_norm2voxel) diff --git a/keymorph/transformations.py b/keymorph/transformations.py index e54647c..ad45782 100644 --- a/keymorph/transformations.py +++ b/keymorph/transformations.py @@ -22,15 +22,15 @@ def __init__( self.dim = dim if matrix is not None and inverse_matrix is None: self.transform_matrix = matrix - self.inverse_transform_matrix = torch.inverse(matrix) + self.inverse_transform_matrix = torch.linalg.pinv(matrix) elif matrix is None and inverse_matrix is not None: self.inverse_transform_matrix = inverse_matrix - self.transform_matrix = torch.inverse(inverse_matrix) + self.transform_matrix = torch.linalg.pinv(inverse_matrix) else: raise ValueError("Only one of matrix or inverse_matrix should be provided") - def _square(self, matrix): - square = torch.eye(self.dim + 1)[None] + batch_size = matrix.shape[0] + square = torch.eye(self.dim + 1).unsqueeze(0).repeat(batch_size, 1, 1) square[:, : self.dim, : self.dim + 1] = matrix return square @@ -53,7 +53,7 @@ def affine_grid(self, grid_shape): moving_voxel_coords = self.get_inverse_transformed_points(grid_flat) - transformed_grid = moving_voxel_coords.reshape(1, *grid_shape[2:], self.dim) + transformed_grid = moving_voxel_coords.reshape(-1, *grid_shape[2:], self.dim) return transformed_grid @@ -109,6 +109,5 @@ def get_inverse_transformed_points(self, points): # Convert to homogeneous coordinates ones = torch.ones(batch_size, num_points, 1).to(points.device) points = torch.cat([points, ones], dim=2) - points = torch.bmm(transform_matrix, points.permute(0, 2, 1)).permute(0, 2, 1) - + points = torch.einsum('brc,bpc->bpr', transform_matrix, points) return points diff --git a/test/test.py b/test/test.py index 66c0061..775817a 100644 --- a/test/test.py +++ b/test/test.py @@ -459,16 +459,32 @@ def test_affine_1(self): def test_affine_2(self): """Test 3d rotation around z-axis for 2d points plus scaling.""" - input1 = torch.tensor([[1, 0, 0], [0, -1, 0], [-1, 0, 0], [0, 1, 0]]).float() - input1 = input1.view(1, 4, 3) - input2 = torch.tensor([[0, -1, 0], [-1, 0, 0], [0, 1, 0], [1, 0, 0]]).float() - input2 = input2.view(1, 4, 3) + # Create a simple 3D tetrahedron (4 points minimum for 3D affine) + # Plus one more point to ensure unique solution + input1 = torch.tensor([[1, 0, 0], # Point on x-axis + [0, 1, 0], # Point on y-axis + [0, 0, 1], # Point on z-axis + [0, 0, 0], # Origin + [1, 1, 1]]).float() # Corner point + input1 = input1.view(1, 5, 3) + + # Apply a 60-degree (π/3) rotation around z-axis + # This is obvious: rotates x→between x&y, y→between y&-x + r = np.pi / 3 # 60 degrees + cos_r = np.cos(r) # = 0.5 + sin_r = np.sin(r) # = √3/2 ≈ 0.866 + + input2 = torch.tensor([[cos_r, sin_r, 0], # (1,0,0) rotated + [-sin_r, cos_r, 0], # (0,1,0) rotated + [0, 0, 1], # (0,0,1) unchanged by z-rotation + [0, 0, 0], # Origin unchanged + [cos_r - sin_r, sin_r + cos_r, 1]]).float() # (1,1,1) rotated + input2 = input2.view(1, 5, 3) - r = -np.pi / 2 true = torch.tensor( [ - [np.cos(r), -np.sin(r), 0, 0], - [np.sin(r), np.cos(r), 0, 0], + [cos_r, -sin_r, 0, 0], + [sin_r, cos_r, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], ]