Implement tensornet warp kernels #384

Merged
stefdoerr merged 2 commits into main from tensornet_warp_ops on Mar 4, 2026

@sef43 sef43 commented Mar 3, 2026

  • Implements tensornet warp kernels copied from materialyzeai/matgl#709 ("Accelerated TensorNet with NVIDIA Warp Kernels"), as originally implemented by @zubatyuk.

  • To work with the warp kernels, the tensornet code has been refactored to use tensors of shape [N,3,3,F] instead of the original [N,F,3,3]. This change requires reshaping the weights of models trained with the previous code. Older checkpoints are currently auto-detected by the presence of the check-errors flag, which was removed in a recent commit (6634490). The loading method can also be set explicitly with a new compatibility_load=True or False flag.

  • If the warp kernels fail to load, the pure torch functions are used instead. These have been refactored to match the call signatures and shapes of the warp kernels.

  • The speedup from the warp kernels is approximately 3x for both inference and training.

  • Note that the neighbor list gets converted from COO format into CSR format. I have tested with static_shapes=True and neighbor list mode 'brute', which seems to work correctly, but there is potential for subtle bugs when static shapes are used.
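
For reference, the COO to CSR conversion mentioned in the last point can be sketched in plain Python as below. This is only an illustration, not the code in this PR: the function name `coo_to_csr` is hypothetical, and the sketch assumes that padded entries produced with static shapes are marked with -1 and must be skipped, which is exactly the kind of detail where subtle bugs could appear.

```python
def coo_to_csr(rows, cols, num_nodes):
    """Convert a COO pair list (rows[k], cols[k]) into CSR (indptr, indices).

    Assumption for this sketch: padding entries are marked with -1
    and are dropped rather than counted as neighbors.
    """
    # Count the valid neighbors of each node.
    counts = [0] * num_nodes
    for r, c in zip(rows, cols):
        if r >= 0 and c >= 0:
            counts[r] += 1

    # Prefix sum: indptr[i]..indptr[i+1] is the slice of neighbors of node i.
    indptr = [0] * (num_nodes + 1)
    for i in range(num_nodes):
        indptr[i + 1] = indptr[i] + counts[i]

    # Scatter column indices into their row slices, keeping input order.
    indices = [0] * indptr[-1]
    cursor = indptr[:-1]  # slicing copies, so indptr is not modified
    for r, c in zip(rows, cols):
        if r < 0 or c < 0:
            continue
        indices[cursor[r]] = c
        cursor[r] += 1
    return indptr, indices


# Three atoms, four real pairs, one padded entry.
indptr, indices = coo_to_csr([0, 0, 1, 2, -1], [1, 2, 0, 0, -1], num_nodes=3)
print(indptr, indices)
```

The neighbors of atom `i` are then `indices[indptr[i]:indptr[i + 1]]`, so a kernel can loop over one contiguous slice per atom instead of scanning the whole pair list.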

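
The layout change from [N,F,3,3] to [N,3,3,F] described in this PR amounts to moving the feature axis last. The sketch below shows that index remapping on a flat buffer in plain Python. The helper names are hypothetical and not taken from the PR, and it does not reproduce how the PR reshapes checkpoint weights; for a torch tensor the equivalent reordering would typically be `x.permute(0, 2, 3, 1)`.

```python
def offset_nf33(n, f, i, j, F):
    """Flat offset of (n, f, i, j) in a contiguous [N, F, 3, 3] buffer."""
    return ((n * F + f) * 3 + i) * 3 + j


def offset_n33f(n, i, j, f, F):
    """Flat offset of (n, i, j, f) in a contiguous [N, 3, 3, F] buffer."""
    return ((n * 3 + i) * 3 + j) * F + f


def to_feature_last(flat, N, F):
    """Reorder a flat [N, F, 3, 3] buffer into [N, 3, 3, F] order."""
    out = [None] * (N * F * 9)
    for n in range(N):
        for f in range(F):
            for i in range(3):
                for j in range(3):
                    out[offset_n33f(n, i, j, f, F)] = flat[offset_nf33(n, f, i, j, F)]
    return out


N, F = 2, 4
old = list(range(N * F * 9))      # stand-in for data in the old layout
new = to_feature_last(old, N, F)  # same values, feature axis last
print(new[:F])                    # the F features of element (n=0, i=0, j=0)
```

In the new layout the F features of each tensor component sit next to each other in memory, whereas in the old layout they were 9 elements apart.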
@sef43 sef43 requested a review from stefdoerr March 3, 2026 14:02
@stefdoerr stefdoerr left a comment

ah I see it was for compatibility load. Looks good

@stefdoerr stefdoerr merged commit 7fdb313 into main Mar 4, 2026
19 of 21 checks passed