Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions keymorph/keypoint_aligners.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def fit(self, x, y, w=None):
out = torch.bmm(out, torch.transpose(x, -2, -1))
else:
out = torch.bmm(x, torch.transpose(x, -2, -1))
inv = torch.inverse(out)
# inv = torch.linalg.pinv(out)
inv = torch.linalg.pinv(out)
if w is not None:
out = torch.bmm(w, torch.transpose(x, -2, -1))
out = torch.bmm(out, inv)
Expand Down Expand Up @@ -396,6 +397,7 @@ def get_flow_field(
# See make_base_grid_5d() in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/AffineGridGenerator.cpp
return transformed_grid.flip(-1)

@torch.compile()
def transform_points(self, theta, ctrl, points):
"""Evaluate the thin-plate-spline (TPS) surface at xy locations arranged in a grid.
The TPS surface is a minimum bend interpolation surface defined by a set of control points.
Expand Down Expand Up @@ -428,8 +430,10 @@ def transform_points(self, theta, ctrl, points):
P[:, :, 1:] = points[:, :, : self.dim]

# U is NxHxWxT
b = torch.bmm(U.transpose(1, 2), weights)
z = torch.bmm(P.view(N, -1, self.dim + 1), affine)
# b = torch.bmm(U.transpose(1, 2), weights)
# z = torch.bmm(P.view(N, -1, self.dim + 1), affine)
b = torch.einsum('btp,btd->bpd', U, weights)
z = torch.einsum('bpd,bda->bpa', P, affine)
return z + b

def get_inverse_transformed_points(self, points):
Expand Down Expand Up @@ -694,8 +698,8 @@ def get_forward_transformed_points(self, points):
# perm_mat = perm_mat[None, [0, 2, 1, 3], :] # 012, 021, 102, 120, 201, 210

# # Calculate the overall transformation matrix from moving to fixed image space
# overall_affine = torch.bmm(rescale_voxel2norm, torch.inverse(moving_affine))
# overall_affine = torch.bmm(overall_affine, torch.inverse(registration_affine))
# overall_affine = torch.bmm(rescale_voxel2norm, torch.linalg.pinv(moving_affine))
# overall_affine = torch.bmm(overall_affine, torch.linalg.pinv(registration_affine))
# overall_affine = torch.bmm(overall_affine, fixed_affine)
# overall_affine = torch.bmm(overall_affine, rescale_norm2voxel)

Expand Down
13 changes: 6 additions & 7 deletions keymorph/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ def __init__(
self.dim = dim
if matrix is not None and inverse_matrix is None:
self.transform_matrix = matrix
self.inverse_transform_matrix = torch.inverse(matrix)
self.inverse_transform_matrix = torch.linalg.pinv(matrix)
elif matrix is None and inverse_matrix is not None:
self.inverse_transform_matrix = inverse_matrix
self.transform_matrix = torch.inverse(inverse_matrix)
self.transform_matrix = torch.linalg.pinv(inverse_matrix)
else:
raise ValueError("Only one of matrix or inverse_matrix should be provided")

def _square(self, matrix):
square = torch.eye(self.dim + 1)[None]
batch_size = matrix.shape[0]
square = torch.eye(self.dim + 1).unsqueeze(0).repeat(batch_size, 1, 1)
square[:, : self.dim, : self.dim + 1] = matrix
return square

Expand All @@ -53,7 +53,7 @@ def affine_grid(self, grid_shape):

moving_voxel_coords = self.get_inverse_transformed_points(grid_flat)

transformed_grid = moving_voxel_coords.reshape(1, *grid_shape[2:], self.dim)
transformed_grid = moving_voxel_coords.reshape(-1, *grid_shape[2:], self.dim)

return transformed_grid

Expand Down Expand Up @@ -109,6 +109,5 @@ def get_inverse_transformed_points(self, points):
# Convert to homogeneous coordinates
ones = torch.ones(batch_size, num_points, 1).to(points.device)
points = torch.cat([points, ones], dim=2)
points = torch.bmm(transform_matrix, points.permute(0, 2, 1)).permute(0, 2, 1)

points = torch.einsum('brc,bpc->bpr', transform_matrix, points)
return points
30 changes: 23 additions & 7 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,32 @@ def test_affine_1(self):
def test_affine_2(self):
"""Test 3d rotation around z-axis for 2d points plus scaling."""

input1 = torch.tensor([[1, 0, 0], [0, -1, 0], [-1, 0, 0], [0, 1, 0]]).float()
input1 = input1.view(1, 4, 3)
input2 = torch.tensor([[0, -1, 0], [-1, 0, 0], [0, 1, 0], [1, 0, 0]]).float()
input2 = input2.view(1, 4, 3)
# Create a simple 3D tetrahedron (4 points minimum for 3D affine)
# Plus one more point to ensure unique solution
input1 = torch.tensor([[1, 0, 0], # Point on x-axis
[0, 1, 0], # Point on y-axis
[0, 0, 1], # Point on z-axis
[0, 0, 0], # Origin
[1, 1, 1]]).float() # Corner point
input1 = input1.view(1, 5, 3)

# Apply a 60-degree (π/3) rotation around z-axis
# This is obvious: rotates x→between x&y, y→between y&-x
r = np.pi / 3 # 60 degrees
cos_r = np.cos(r) # = 0.5
sin_r = np.sin(r) # = √3/2 ≈ 0.866

input2 = torch.tensor([[cos_r, sin_r, 0], # (1,0,0) rotated
[-sin_r, cos_r, 0], # (0,1,0) rotated
[0, 0, 1], # (0,0,1) unchanged by z-rotation
[0, 0, 0], # Origin unchanged
[cos_r - sin_r, sin_r + cos_r, 1]]).float() # (1,1,1) rotated
input2 = input2.view(1, 5, 3)

r = -np.pi / 2
true = torch.tensor(
[
[np.cos(r), -np.sin(r), 0, 0],
[np.sin(r), np.cos(r), 0, 0],
[cos_r, -sin_r, 0, 0],
[sin_r, cos_r, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
Expand Down