Port to Python and libtorch stable ABI #102
Open
woct0rdho wants to merge 3 commits into
I've ported this repo to the Python stable ABI (ABI3) and the libtorch stable ABI. See https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html for a modern guide to torch custom ops.
This means we no longer need to build a different wheel for every Python and PyTorch version. We only need to build different wheels for Windows/Linux and CUDA 12/CUDA 13/ROCm 7. This should help adoption of this package, notably because Unsloth is recommending it for Qwen3.5 training. The same port has already been done in packages like flash-attn-3.
I've built the wheel and run the unit tests on two machines: one with Windows, RTX 3080 (sm86), CUDA 13, torch 2.11, and one with Linux, Strix Halo (gfx1151), ROCm 7, torch 2.12 nightly.
One concern is that the libtorch stable ABI only supports Python >= 3.10 and torch >= 2.10. It's possible to make it support torch 2.9 with some extra effort. Since you've already dropped support for Pascal and Volta, maybe we can also drop support for torch < 2.10. Alternatively, we can keep the old code without the stable ABI in a legacy branch.
A notable change is that I moved the detection of deterministic mode from C++ to Python, because it's not in the stable C++ API. Also, `os.getenv` cannot be traced in `torch.compile`, so I detect it when the package is imported rather than when the function is called. I've updated the corresponding tests.