diff --git a/README.md b/README.md index 229dec8e2..ce88d15a0 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,8 @@ pip3 install --pre torch torchvision torchaudio --index-url https://download.pyt ### Optional Dependencies - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working with will dictate the minimum version of transformers. +- `cuda-tile`: Required when enabling the optional cuTile backend on CUDA. Use this when your environment already provides CUDA Toolkit 13.1 or newer, or an existing tileiras compiler installation. +- `cuda-tile[tileiras]`: Required when enabling the optional cuTile backend with the tileiras compiler installed directly into your Python environment. > **Note:** > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton). @@ -168,10 +170,26 @@ pip install -e . # Setup Development Dependencies pip install -e ".[dev]" +# Setup cuTile Dependencies +pip install -e ".[cutile]" + +# Or install cuTile with the optional tileiras compiler +pip install -e ".[cutile-tileiras]" + # NOTE -> For AMD users only pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/ ``` +### Enable cuTile Backend + +cuTile is an optional CUDA-only DSL implementation. After installing the `cutile` or `cutile-tileiras` extra, enable it explicitly: + +```bash +LIGER_KERNEL_IMPL=cutile python your_script.py +``` + +`LIGER_KERNEL_IMPL` selects an opt-in implementation registered with Liger (currently `cutile`). Selecting one on an unsupported device, or without the required dependencies installed, raises an error. 
+ ## Getting Started @@ -290,7 +308,7 @@ loss.backward() | **Kernel** | **API** | |---------------------------------|-------------------------------------------------------------| | RMSNorm | `liger_kernel.transformers.LigerRMSNorm` | -| Modulated RMSNorm | `liger_kernel.transformers.LigerModulatedRMSNorm` | +| Modulated RMSNorm | `liger_kernel.transformers.LigerModulatedRMSNorm` | | LayerNorm | `liger_kernel.transformers.LigerLayerNorm` | | RoPE | `liger_kernel.transformers.liger_rotary_pos_emb` | | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` | diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 366ee7db5..7b9b0c1b0 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -2187,3 +2187,35 @@ fused_moe,huggingface,backward,memory,MB,E,num_experts,16,2072.1728515625,2072.1 fused_moe,huggingface,backward,memory,MB,E,num_experts,32,2737.08349609375,2737.08349609375,2737.08349609375,"{""sweep_dim"": ""E"", ""T"": 8192, ""E"": null, ""H"": 2048, ""intermediate_dim"": 768, ""K"": 8, ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2026-04-02 23:59:56,0.7.0 fused_moe,huggingface,backward,memory,MB,E,num_experts,64,4078.97021484375,4078.97021484375,4078.97021484375,"{""sweep_dim"": ""E"", ""T"": 8192, ""E"": null, ""H"": 2048, ""intermediate_dim"": 768, ""K"": 8, ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2026-04-02 23:59:56,0.7.0 fused_moe,huggingface,backward,memory,MB,E,num_experts,128,6763.82275390625,6763.82275390625,6763.82275390625,"{""sweep_dim"": ""E"", ""T"": 8192, ""E"": null, ""H"": 2048, ""intermediate_dim"": 768, ""K"": 8, ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2026-04-02 23:59:56,0.7.0 +jsd,torch,full,speed,ms,BT,total tokens,1024,5.921823978424072,5.921823978424072,5.921823978424072,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:13:37,0.8.0 +jsd,torch,full,speed,ms,BT,total 
tokens,2048,12.200063705444336,12.200063705444336,12.200063705444336,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:13:37,0.8.0 +jsd,torch,full,speed,ms,BT,total tokens,4096,24.145984649658203,24.145984649658203,24.145984649658203,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:13:37,0.8.0 +jsd,torch,full,speed,ms,BT,total tokens,8192,50.45283126831055,50.45283126831055,50.45283126831055,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:13:37,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,1024,6.0959038734436035,6.0959038734436035,6.0959038734436035,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:28,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,2048,10.940447807312012,10.940447807312012,10.940447807312012,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:28,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,4096,21.781631469726562,21.781631469726562,21.781631469726562,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:28,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,8192,44.07699203491211,44.07699203491211,44.07699203491211,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:28,0.8.0 +jsd,torch,forward,speed,ms,BT,total tokens,1024,2.2900800704956055,2.2883904933929444,2.2906303882598875,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,torch,forward,speed,ms,BT,total tokens,2048,4.97105598449707,4.9135422706604,5.02856969833374,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,torch,forward,speed,ms,BT,total tokens,4096,9.907423973083496,9.907423973083496,9.907423973083496,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,torch,forward,speed,ms,BT,total 
tokens,8192,20.02751922607422,20.02751922607422,20.02751922607422,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,1024,5.783552169799805,5.783552169799805,5.783552169799805,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,2048,9.110560417175293,9.110560417175293,9.110560417175293,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,4096,18.322431564331055,18.322431564331055,18.322431564331055,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,8192,37.44358444213867,37.44358444213867,37.44358444213867,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:29,0.8.0 +jsd,torch,backward,speed,ms,BT,total tokens,1024,3.7858558893203735,3.7852798938751224,3.786431884765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,torch,backward,speed,ms,BT,total tokens,2048,7.665791988372803,7.665791988372803,7.665791988372803,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,torch,backward,speed,ms,BT,total tokens,4096,15.20956802368164,15.20956802368164,15.20956802368164,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,torch,backward,speed,ms,BT,total tokens,8192,30.310592651367188,30.310592651367188,30.310592651367188,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,liger,backward,speed,ms,BT,total tokens,1024,1.0158560276031494,1.004588794708252,1.0225855827331543,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 
+jsd,liger,backward,speed,ms,BT,total tokens,2048,1.8555200099945068,1.8544960021972656,1.8571839809417723,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,liger,backward,speed,ms,BT,total tokens,4096,3.7145920991897583,3.7130560874938965,3.71612811088562,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,liger,backward,speed,ms,BT,total tokens,8192,7.243807792663574,7.243807792663574,7.243807792663574,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,torch,full,memory,MB,BT,total tokens,1024,6526.0009765625,6526.0009765625,6526.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,torch,full,memory,MB,BT,total tokens,2048,13026.0009765625,13026.0009765625,13026.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,torch,full,memory,MB,BT,total tokens,4096,26052.0,26052.0,26052.0,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,torch,full,memory,MB,BT,total tokens,8192,52104.0,52104.0,52104.0,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:30,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,1024,3514.0009765625,3514.0009765625,3514.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,2048,7014.0009765625,7014.0009765625,7014.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,4096,14028.0009765625,14028.0009765625,14028.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,8192,28056.0,28056.0,28056.0,"{""vocab_size"": 128256, 
""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 diff --git a/benchmark/data/all_benchmark_data_cutile.csv b/benchmark/data/all_benchmark_data_cutile.csv new file mode 100644 index 000000000..deba35977 --- /dev/null +++ b/benchmark/data/all_benchmark_data_cutile.csv @@ -0,0 +1,33 @@ +kernel_name,kernel_provider,kernel_operation_mode,metric_name,metric_unit,x_name,x_label,x_value,y_value_50,y_value_20,y_value_80,extra_benchmark_config_str,gpu_name,timestamp,liger_version +jsd,torch,full,speed,ms,BT,total tokens,1024,5.9279680252075195,5.9279680252075195,5.9279680252075195,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:12,0.8.0 +jsd,torch,full,speed,ms,BT,total tokens,2048,12.093536376953125,12.093536376953125,12.093536376953125,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:12,0.8.0 +jsd,torch,full,speed,ms,BT,total tokens,4096,24.353023529052734,24.353023529052734,24.353023529052734,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:12,0.8.0 +jsd,torch,full,speed,ms,BT,total tokens,8192,51.63132858276367,51.63132858276367,51.63132858276367,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:12,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,1024,1.5985119938850403,1.5944639444351196,1.6005439758300781,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:15,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,2048,3.0249600410461426,3.024307155609131,3.0514752864837646,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:15,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,4096,6.043647766113281,6.043647766113281,6.043647766113281,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:15,0.8.0 +jsd,liger,full,speed,ms,BT,total tokens,8192,12.18057632446289,12.18057632446289,12.18057632446289,"{""vocab_size"": 
128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:15,0.8.0 +jsd,torch,forward,speed,ms,BT,total tokens,1024,2.2989439964294434,2.2989439964294434,2.298969554901123,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,torch,forward,speed,ms,BT,total tokens,2048,4.600415945053101,4.598918342590332,4.60191354751587,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,torch,forward,speed,ms,BT,total tokens,4096,9.270400047302246,9.270400047302246,9.270400047302246,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,torch,forward,speed,ms,BT,total tokens,8192,19.314847946166992,19.314847946166992,19.314847946166992,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,1024,0.9553920030593872,0.9492863893508912,0.9575616240501403,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,2048,1.4541120529174805,1.4528576374053954,1.4553215980529786,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,4096,2.5651841163635254,2.5584064960479735,2.5675840854644774,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,liger,forward,speed,ms,BT,total tokens,8192,5.1241278648376465,5.1241278648376465,5.1241278648376465,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,torch,backward,speed,ms,BT,total tokens,1024,3.8217118978500366,3.8216639041900637,3.82175989151001,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,torch,backward,speed,ms,BT,total 
tokens,2048,7.542975902557373,7.542975902557373,7.542975902557373,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,torch,backward,speed,ms,BT,total tokens,4096,15.150239944458008,15.150239944458008,15.150239944458008,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,torch,backward,speed,ms,BT,total tokens,8192,30.65158462524414,30.65158462524414,30.65158462524414,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:16,0.8.0 +jsd,liger,backward,speed,ms,BT,total tokens,1024,1.018943965435028,1.0006976008415223,1.0215808391571044,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,liger,backward,speed,ms,BT,total tokens,2048,1.8514400124549866,1.8510143756866455,1.8518656492233276,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,liger,backward,speed,ms,BT,total tokens,4096,3.6808160543441772,3.680499267578125,3.6811328411102293,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,liger,backward,speed,ms,BT,total tokens,8192,7.2151360511779785,7.2151360511779785,7.2151360511779785,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,torch,full,memory,MB,BT,total tokens,1024,6526.0009765625,6526.0009765625,6526.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,torch,full,memory,MB,BT,total tokens,2048,13026.0009765625,13026.0009765625,13026.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,torch,full,memory,MB,BT,total tokens,4096,26052.0,26052.0,26052.0,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,torch,full,memory,MB,BT,total 
tokens,8192,52104.0,52104.0,52104.0,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:17,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,1024,3514.0009765625,3514.0009765625,3514.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:18,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,2048,7014.0009765625,7014.0009765625,7014.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:18,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,4096,14028.0009765625,14028.0009765625,14028.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:18,0.8.0 +jsd,liger,full,memory,MB,BT,total tokens,8192,28056.0,28056.0,28056.0,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:16:18,0.8.0 diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py index 16e6000d3..8459cab84 100644 --- a/benchmark/scripts/utils.py +++ b/benchmark/scripts/utils.py @@ -452,7 +452,9 @@ def run_benchmarks( print_benchmark_data(benchmark_data_list) - update_benchmark_data_csv(benchmark_data_list=benchmark_data_list, overwrite=overwrite) + impl_name = os.environ.get("LIGER_KERNEL_IMPL", "").strip().lower() + file_name = "all_benchmark_data.csv" if impl_name == "" else f"all_benchmark_data_{impl_name}.csv" + update_benchmark_data_csv(benchmark_data_list=benchmark_data_list, filename=file_name, overwrite=overwrite) def parse_benchmark_script_args(): diff --git a/setup.py b/setup.py index 23b6eb2f1..8fb93d3dd 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,16 @@ def get_default_dependencies(): def get_optional_dependencies(): """Get optional dependency groups.""" + cutile_deps = [ + "cuda-tile", + ] + cutile_tileiras_deps = [ + "cuda-tile[tileiras]", + ] + return { + "cutile": cutile_deps, + "cutile-tileiras": cutile_tileiras_deps, "dev": [ "transformers>=4.52.0", "matplotlib>=3.7.2", @@ -48,7 +57,7 @@ def 
get_optional_dependencies(): "mkdocs-material", "torchvision>=0.20", "prek>=0.2.28", - ] + ], } diff --git a/src/liger_kernel/ops/__init__.py b/src/liger_kernel/ops/__init__.py index 66aad2a8c..b37384ec3 100644 --- a/src/liger_kernel/ops/__init__.py +++ b/src/liger_kernel/ops/__init__.py @@ -1,12 +1,12 @@ """ -Liger-Kernel operators with automatic vendor-specific replacement. +Liger-Kernel operators with automatic implementation-specific replacement. This module provides two ways to import operators: 1. Import from this package (recommended for Function classes): from liger_kernel.ops import LigerGELUMulFunction - This automatically uses vendor-specific implementation if available. + This automatically uses the active implementation if any is selected. 2. Import from submodules (for kernel functions or specific access): from liger_kernel.ops.geglu import geglu_forward, geglu_backward @@ -15,10 +15,12 @@ The replacement mechanism: 1. Default implementations are imported from individual modules (e.g., geglu.py) -2. On module load, device is detected via infer_device() -3. If running on a supported vendor device (npu, xpu, etc.), the default - implementations are replaced with vendor-specific ones -4. All subsequent imports from this package get the replaced versions +2. On module load, device is detected via infer_device() and the env var + LIGER_KERNEL_IMPL is read +3. select_impl() picks an active implementation (auto-applied for the device, + or explicitly requested via env var) +4. If one is selected, its operators replace/extend the symbols here +5. All subsequent imports from this package get the replaced versions Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...) are NOT affected by the replacement mechanism. @@ -27,7 +29,7 @@ # ============================================================================= # Import default implementations # Both Function classes and kernel functions are imported here. 
-# All of these can be replaced by vendor-specific implementations. +# All of these can be replaced by backend-specific implementations. # ============================================================================= from liger_kernel.ops.attn_res import LigerAttnResFunction # noqa: F401 @@ -94,61 +96,94 @@ from liger_kernel.ops.tvd import LigerTVDLossFunction # noqa: F401 # NOTE: __all__ is intentionally NOT defined. -# - Import from this package (liger_kernel.ops) -> subject to vendor replacement +# - Import from this package (liger_kernel.ops) -> subject to backend replacement # - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation # ============================================================================= -# Vendor-specific replacement logic +# Implementation discovery + dispatch # ============================================================================= -def _replace_with_vendor_ops(): +def _discover_impls(): """ - Replace/add vendor-specific operator implementations. + Trigger self-registration of all implementations. + + Two sources of implementations: + - Hardware backends in ``backends/_/`` (loaded by + ``backends/__init__.py``'s own auto-import loop). + - DSL alternatives at the top level of ``ops/`` (e.g., ``cutile/``). + Each DSL subpackage's ``__init__.py`` calls ``register_impl()`` + when imported. + """ + import importlib + import pkgutil + + # Hardware backends self-register when `backends` is imported. + importlib.import_module("liger_kernel.ops.backends") + + # DSL alternatives — non-private subpackages of `ops/`, minus reserved + # directories that aren't implementation containers. 
+ reserved = {"backends", "experimental"} + for _, modname, ispkg in pkgutil.iter_modules(__path__): + if ispkg and not modname.startswith("_") and modname not in reserved: + importlib.import_module(f"{__name__}.{modname}") + + +def _replace_with_impl_ops(): + """ + Replace/add implementation-specific operators on top of the defaults. This function is called automatically on module load. It: - 1. Detects the current device (cuda, npu, xpu, etc.) - 2. Looks up the vendor for that device via VENDOR_REGISTRY - 3. Loads and applies vendor-specific implementations + 1. Detects the current device (cuda, npu, xpu, etc.). + 2. Selects the active implementation via ``select_impl()``, honoring any + explicit ``LIGER_KERNEL_IMPL`` override. + 3. Loads and applies the implementation's operators. - Vendor implementations should be placed in: - liger_kernel/ops/backends/_/ops/ + Implementations live either at: + liger_kernel/ops//ops/ (DSL alternatives) + liger_kernel/ops/backends/_/ops/ (hardware backends) - If the vendor module defines __all__, only those symbols are exported. - Otherwise, all public symbols (not starting with _) are auto-discovered. + If the implementation module defines ``__all__``, only those symbols are + exported. Otherwise, all public symbols (not starting with ``_``) are + auto-discovered. - Note: Vendor can both override existing ops AND add new vendor-specific ops. + Note: Implementations can both override existing ops AND add new ones. 
""" - from liger_kernel.ops.backends import get_vendor_for_device + import os + + from liger_kernel.ops.backends import LIGER_KERNEL_IMPL_ENV + from liger_kernel.ops.backends import select_impl from liger_kernel.utils import infer_device device = infer_device() - - # Look up vendor info for this device - vendor_info = get_vendor_for_device(device) - if vendor_info is None: + explicit = os.environ.get(LIGER_KERNEL_IMPL_ENV, "").strip().lower() or None + impl_info = select_impl(device, explicit=explicit) + if impl_info is None: return try: import importlib - vendor_ops = importlib.import_module(vendor_info.module_path) - - # Get names to export: use __all__ if defined, otherwise auto-discover - names_to_export = getattr(vendor_ops, "__all__", None) + impl_ops = importlib.import_module(impl_info.module_path) + # Get names to export: use __all__ if defined, otherwise auto-discover. + names_to_export = getattr(impl_ops, "__all__", None) if names_to_export is None: - # Auto-discover: find all public symbols (classes and functions) - names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")] + names_to_export = [name for name in dir(impl_ops) if not name.startswith("_")] - # Replace or add to this module's globals + # Replace or add to this module's globals. for name in names_to_export: - globals()[name] = getattr(vendor_ops, name) + globals()[name] = getattr(impl_ops, name) except ImportError: - # Vendor module not available, use default implementations - pass + # An auto-selected implementation that fails to import (e.g., missing + # optional vendor SDK) silently falls back to defaults. An explicitly + # requested implementation, however, must succeed — re-raise so the + # user sees the underlying error. 
+ if explicit: + raise -_replace_with_vendor_ops() +_discover_impls() +_replace_with_impl_ops() diff --git a/src/liger_kernel/ops/backends/README.md b/src/liger_kernel/ops/backends/README.md index d4067157b..6d06ce7f3 100644 --- a/src/liger_kernel/ops/backends/README.md +++ b/src/liger_kernel/ops/backends/README.md @@ -1,70 +1,93 @@ -# Adding a New Vendor Backend +# Adding a New Hardware Backend -This directory contains vendor-specific operator implementations that automatically replace the default (CUDA) implementations when running on the corresponding device. +This directory holds **alternative hardware backends** — operator implementations for devices other than the default (CUDA). Examples: Ascend NPU, future ROCm, future XPU. + +DSL alternatives for the *default* hardware (CUDA) — cuTile, future CUTLASS / CuteDSL / TileLang — live at the top level of `src/liger_kernel/ops/` (peers of this `backends/` directory), not inside it. The contract for registering them is the same; only the on-disk location differs. ## Concepts -- **Vendor**: Chip manufacturer (e.g., `ascend`, `intel`, `nvidia`) -- **Device**: Device type (e.g., `npu`, `xpu`, `cuda`) -- **VendorInfo**: Defines the mapping between vendor and device +An **implementation** is a named alternative kernel set. Each implementation declares: + +- **`name`** — identifier (e.g., `ascend`, `cutile`). Users select it via `LIGER_KERNEL_IMPL=<name>`. +- **`devices`** — every device the implementation supports. +- **`default_devices`** — the subset where it is auto-applied at import time. On supported devices not in this set, the implementation is opt-in only (requires `LIGER_KERNEL_IMPL=<name>`). Empty means opt-in only on every supported device. +- **`module_path`** — the Python module path where the kernels live (e.g., `liger_kernel.ops.cutile.ops`, `liger_kernel.ops.backends._ascend.ops`). + +Two flavors fall out of the data: + +- **Auto-applied** — `default_devices` includes the current device. 
Replaces defaults automatically (e.g., Ascend on NPU). +- **Opt-in** — only selected when the user sets `LIGER_KERNEL_IMPL=<name>` (e.g., cuTile on CUDA). -## Directory Structure +## Directory layout (full tree) ``` -backends/ -├── README.md -├── __init__.py -├── registry.py # VendorInfo, register_vendor(), VENDOR_REGISTRY -├── _ascend/ # Ascend (Huawei) vendor - supports NPU -│ ├── __init__.py # Registers VendorInfo for NPU +src/liger_kernel/ops/ +├── jsd.py, rms_norm.py, ... # default Triton-on-CUDA — the canonical kernels +├── cutile/ # opt-in DSL on CUDA +│ ├── __init__.py # register_impl(ImplInfo(name="cutile", devices=("cuda",), module_path=...)) │ └── ops/ -│ ├── __init__.py # Exports vendor-specific implementations -│ └── geglu.py # NPU-specific GEGLU implementation -└── _<vendor>/ # Your new vendor backend - └── ... +│ ├── __init__.py +│ └── jsd.py +├── backends/ # alternative hardware backends (this directory) +│ ├── README.md +│ ├── __init__.py # auto-imports _<vendor>/ subpackages +│ ├── registry.py # ImplInfo, register_impl(), select_impl(), IMPL_REGISTRY +│ └── _ascend/ # Ascend NPU — auto-applied on NPU +│ ├── __init__.py # register_impl(ImplInfo(name="ascend", ...)) +│ └── ops/ +│ ├── __init__.py +│ └── jsd.py, ... +└── __init__.py # imports defaults, runs _replace_with_impl_ops() ``` -## How It Works +## How dispatch works -1. When `liger_kernel.ops.backends` is imported, it imports all vendor packages (e.g., `_ascend`) -2. Each vendor's `__init__.py` calls `register_vendor()` to register itself -3. When `liger_kernel.ops` is imported, `_replace_with_vendor_ops()` is called -4. It detects the current device via `infer_device()` and looks up the vendor -5. Vendor implementations replace/add to the `liger_kernel.ops` namespace +1. `liger_kernel.ops` is imported. Default top-level kernels (`ops/jsd.py`, etc.) load first. +2. `_discover_impls()` runs: + - Imports `liger_kernel.ops.backends`, which auto-imports each `_<vendor>/` subpackage. 
Each subpackage's `__init__.py` calls `register_impl()`. + - Iterates the top-level non-private subpackages of `ops/` (e.g., `cutile/`), excluding reserved dirs (`backends`, `experimental`), and imports each. Same self-registration pattern. +3. `_replace_with_impl_ops()` runs: + - Detects the current device via `infer_device()`. + - Reads `LIGER_KERNEL_IMPL` from the environment. + - Calls `select_impl(device, explicit=<name>)`: + - If the env var is set, the named implementation is looked up and validated (device must be in `devices`). + - If unset, the first registered implementation listing the current device in its `default_devices` is returned; otherwise no replacement happens. + - If an implementation was selected, its operators replace/extend the `liger_kernel.ops` namespace. -## Adding a New Vendor +If an auto-selected implementation fails to import (e.g., the vendor SDK isn't installed), the dispatcher silently falls back to defaults. An explicitly-requested implementation that fails to import re-raises so the user sees the underlying error. -### Step 1: Create Directory Structure +## Adding a new hardware backend (lives in `backends/_<vendor>/`) + +### Step 1: Create the directory ```bash -mkdir -p backends/_<vendor>/ops -touch backends/_<vendor>/__init__.py -touch backends/_<vendor>/ops/__init__.py +mkdir -p src/liger_kernel/ops/backends/_<vendor>/ops +touch src/liger_kernel/ops/backends/_<vendor>/__init__.py +touch src/liger_kernel/ops/backends/_<vendor>/ops/__init__.py ``` -### Step 2: Register Your Vendor +### Step 2: Register the implementation -In `backends/_<vendor>/__init__.py`, register your vendor: +In `backends/_<vendor>/__init__.py`: ```python -""" -<Vendor> backend for Liger-Kernel. 
-""" - -from liger_kernel.ops.backends.registry import VendorInfo, register_vendor - -register_vendor( - VendorInfo( - vendor="<vendor>", - device="<device>", - ) -) +"""<Vendor> hardware backend for Liger-Kernel.""" + +from liger_kernel.ops.backends.registry import ImplInfo +from liger_kernel.ops.backends.registry import register_impl + +# Auto-applied on the listed devices: +register_impl(ImplInfo( + name="<vendor>", + devices=("<device>",), + default_devices=("<device>",), + module_path=f"{__name__}.ops", +)) ``` +### Step 3: Ensure device detection works -### Step 3: Ensure Device Detection Works - -Make sure `infer_device()` in `liger_kernel/utils.py` can detect your device: +Make sure `infer_device()` in `liger_kernel/utils.py` recognizes the device. Example: ```python def infer_device(): @@ -72,57 +95,41 @@ def infer_device(): return "cuda" if is_npu_available(): return "npu" - # Add your device detection here if is_<device>_available(): return "<device>" return "cpu" ``` -### Step 4: Implement Vendor-Specific Operators +### Step 4: Implement the operators -Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`: +Create operator files in `backends/_<vendor>/ops/`. For example, `geglu.py`: ```python import torch class LigerGELUMulFunction(torch.autograd.Function): - """ - Vendor-specific LigerGELUMulFunction implementation. - """ + """Backend-specific LigerGELUMulFunction.""" + @staticmethod def forward(ctx, a, b): - # Your vendor-specific forward implementation ... @staticmethod def backward(ctx, dc): - # Your vendor-specific backward implementation ... - -# Optional: vendor-specific kernel functions -def geglu_forward_vendor(a, b): - ... - -def geglu_backward_vendor(a, b, dc): - ... 
``` -### Step 5: Export in `ops/__init__.py` +### Step 5: Export from `ops/__init__.py` -In `backends/_<vendor>/ops/__init__.py`, export your implementations: +In `backends/_<vendor>/ops/__init__.py`: ```python -""" -<Vendor>-specific operator implementations. -""" +"""<Vendor>-specific operator implementations.""" -from . 
import ( - LigerGELUMulFunction, - geglu_forward_vendor as geglu_forward, # Rename to match default API - geglu_backward_vendor as geglu_backward, -) +from .geglu import LigerGELUMulFunction +from .geglu import geglu_backward +from .geglu import geglu_forward -# Explicitly declare what to export (recommended) __all__ = [ "LigerGELUMulFunction", "geglu_forward", @@ -130,22 +137,74 @@ __all__ = [ ] ``` -## Key Points +## Adding a new DSL implementation (lives at top level of `ops/`) -### Incremental Override +The pattern is the same — only the on-disk location and the `module_path` differ: -You **don't need to implement all operators**. Only implement the ones that require vendor-specific adaptations. Unimplemented operators will automatically fall back to the default (CUDA) implementation. +``` +src/liger_kernel/ops/<name>/ +├── __init__.py # register_impl(...) +└── ops/ + ├── __init__.py # exports symbols + └── jsd.py, ... # kernel files +``` -### Vendor-Specific Additions +```python +# ops/<name>/__init__.py -Vendors can also **add new operators** that don't exist in the default implementation. These will be exported to `liger_kernel.ops` namespace for users to import. +from liger_kernel.ops.backends.registry import ImplInfo +from liger_kernel.ops.backends.registry import register_impl -### Naming Convention +# Opt-in only (no `default_devices`): +register_impl(ImplInfo( + name="<name>", + devices=("cuda",), + module_path=f"{__name__}.ops", # liger_kernel.ops.<name>.ops +)) +``` + +## Key points + +### Incremental override + +You **don't need to implement all operators**. Only implement the ones that need a different version. Unimplemented operators fall back to the defaults. + +### Adding new operators + +An implementation can also **add new operators** that don't exist in the defaults. They are exported to `liger_kernel.ops` for users to import. 
+ +### Naming convention + +- Use the **same class/function names** as the defaults when overriding — this lets user code stay unchanged. 
+- Use `as` imports to rename if your internal naming differs. + +### Multi-device implementations + +An implementation can support multiple devices by listing them all in `devices`. It can be the default on a subset (or none) of them. + +```python +# Supports CUDA and XPU; default on neither (opt-in everywhere): +register_impl(ImplInfo( + name="inductor", + devices=("cuda", "xpu"), + module_path="liger_kernel.ops.inductor.ops", +)) + +# Supports CUDA and XPU; auto-applied on XPU only: +register_impl(ImplInfo( + name="example", + devices=("cuda", "xpu"), + default_devices=("xpu",), + module_path="liger_kernel.ops.example.ops", +)) +``` -- Use the **same class/function names** as the default implementations for overrides -- This allows seamless replacement without changing user code -- Use `as` imports to rename if your internal naming differs +## Examples in this repo -## Example: Ascend NPU Backend +- `backends/_ascend/` — auto-applied hardware backend (Ascend NPU). +- `../cutile/` — opt-in DSL implementation on CUDA. Enable with: + ```bash + LIGER_KERNEL_IMPL=cutile python your_script.py + ``` -See `_ascend/` directory for a complete example of the Ascend NPU backend implementation. +`select_impl()` validates the request: an unknown implementation name, or a current device that isn't in the implementation's `devices`, raises a `RuntimeError`. If the selected module then fails to import, the error is re-raised, so the user gets a clear error instead of a silent fallback. 
diff --git a/src/liger_kernel/ops/backends/__init__.py b/src/liger_kernel/ops/backends/__init__.py index ad7779c48..f9d0fea38 100644 --- a/src/liger_kernel/ops/backends/__init__.py +++ b/src/liger_kernel/ops/backends/__init__.py @@ -1,13 +1,15 @@ import importlib import pkgutil -from liger_kernel.ops.backends.registry import VENDOR_REGISTRY # noqa: F401 -from liger_kernel.ops.backends.registry import VendorInfo # noqa: F401 -from liger_kernel.ops.backends.registry import get_vendor_for_device # noqa: F401 -from liger_kernel.ops.backends.registry import register_vendor # noqa: F401 +from liger_kernel.ops.backends.registry import IMPL_REGISTRY # noqa: F401 +from liger_kernel.ops.backends.registry import LIGER_KERNEL_IMPL_ENV # noqa: F401 +from liger_kernel.ops.backends.registry import ImplInfo # noqa: F401 +from liger_kernel.ops.backends.registry import register_impl # noqa: F401 +from liger_kernel.ops.backends.registry import select_impl # noqa: F401 -# Auto-import all _ subpackages to trigger registration -# Each vendor's __init__.py calls register_vendor() when imported +# Auto-import all _ subpackages to trigger registration of +# alternative-hardware backends (e.g., _ascend/). Each one calls register_impl() +# in its __init__.py. 
for _, modname, ispkg in pkgutil.iter_modules(__path__): if ispkg and modname.startswith("_"): importlib.import_module(f"{__name__}.{modname}") diff --git a/src/liger_kernel/ops/backends/_ascend/__init__.py b/src/liger_kernel/ops/backends/_ascend/__init__.py index a07e7ab09..f4ad594b5 100644 --- a/src/liger_kernel/ops/backends/_ascend/__init__.py +++ b/src/liger_kernel/ops/backends/_ascend/__init__.py @@ -1,5 +1,14 @@ -from liger_kernel.ops.backends.registry import VendorInfo -from liger_kernel.ops.backends.registry import register_vendor +from liger_kernel.ops.backends.registry import ImplInfo +from liger_kernel.ops.backends.registry import register_impl -# Register Ascend vendor for NPU device -register_vendor(VendorInfo(vendor="ascend", device="npu")) +# Ascend NPU backend — default on NPU devices. +# Future: when tilelang-ascend lands, this can be renamed to "ascend-triton" +# and a second register_impl(ImplInfo(name="ascend-tilelang", ...)) added. +register_impl( + ImplInfo( + name="ascend", + devices=("npu",), + default_devices=("npu",), + module_path=f"{__name__}.ops", # liger_kernel.ops.backends._ascend.ops + ) +) diff --git a/src/liger_kernel/ops/backends/registry.py b/src/liger_kernel/ops/backends/registry.py index 5fe3613c8..16e65b1bd 100644 --- a/src/liger_kernel/ops/backends/registry.py +++ b/src/liger_kernel/ops/backends/registry.py @@ -1,61 +1,108 @@ """ -Vendor registry for Liger-Kernel multi-backend support. - -This module defines VendorInfo and the registry for vendor registration. -Each vendor registers itself by calling register_vendor() in its __init__.py. +Implementation registry for Liger-Kernel. + +An "implementation" here is a named alternative kernel set. It may correspond +to a different hardware device (e.g., Ascend on NPU, in ``backends/_ascend/``) +or a different DSL on the same device (e.g., cuTile on CUDA, in ``ops/cutile/``). +It may support one or more devices. 
+ +Each implementation declares: + - the set of devices it supports + - the subset of those devices on which it is the *default* (auto-applied on + import). On any other supported device the implementation is opt-in only + and must be requested explicitly via the LIGER_KERNEL_IMPL environment + variable. + - the Python module path where its operators live. + +Each implementation registers itself by calling register_impl() in its +__init__.py. """ from dataclasses import dataclass +from dataclasses import field from typing import Optional +from typing import Tuple -# Dynamically get backends package path to avoid hardcoding -_BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0] # "liger_kernel.ops.backends" +# Environment variable users set to explicitly select an opt-in implementation. +LIGER_KERNEL_IMPL_ENV = "LIGER_KERNEL_IMPL" -@dataclass -class VendorInfo: +@dataclass(frozen=True) +class ImplInfo: """ - Information about a chip vendor and its supported device. + Information about a kernel implementation. Attributes: - vendor: Vendor name (e.g., "ascend", "intel", "nvidia") - device: Device type this vendor supports (e.g., "npu", "xpu") + name: Implementation identifier (e.g., "ascend", "cutile"). Also the + value users pass via ``LIGER_KERNEL_IMPL=``. + devices: Tuple of device types this implementation supports + (e.g., ``("npu",)``, ``("cuda",)``, ``("cuda", "xpu")``). + default_devices: Subset of ``devices`` on which this implementation + is automatically applied at import time. On supported devices not + listed here, it is opt-in only via ``LIGER_KERNEL_IMPL``. Empty + tuple (the default) means opt-in only on every supported device. + module_path: Python module path where the operator implementations + live (e.g., ``"liger_kernel.ops.cutile.ops"``). Required. """ - vendor: str - device: str + name: str + devices: Tuple[str, ...] + default_devices: Tuple[str, ...] 
= field(default_factory=tuple) + module_path: str = "" - @property - def module_path(self) -> str: - """Auto-generated module path based on vendor name.""" - return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops" + def __post_init__(self): + if not self.devices: + raise ValueError(f"Implementation {self.name!r} must declare at least one supported device.") + if not self.module_path: + raise ValueError(f"Implementation {self.name!r} must declare a module_path.") + extra = set(self.default_devices) - set(self.devices) + if extra: + raise ValueError( + f"Implementation {self.name!r}: default_devices {sorted(extra)} not in devices {list(self.devices)}." + ) -# Registry mapping device types to their vendor info -# Vendors register themselves via register_vendor() -VENDOR_REGISTRY: dict[str, VendorInfo] = {} +# Registry mapping implementation names to their info. +IMPL_REGISTRY: dict[str, ImplInfo] = {} -def register_vendor(vendor_info: VendorInfo) -> None: - """ - Register a vendor's info in the global registry. +def register_impl(info: ImplInfo) -> None: + """Register an implementation's info in the global registry.""" + IMPL_REGISTRY[info.name] = info - This should be called in each vendor's __init__.py to register itself. - Args: - vendor_info: VendorInfo instance to register +def select_impl(device: str, explicit: Optional[str] = None) -> Optional[ImplInfo]: """ - VENDOR_REGISTRY[vendor_info.device] = vendor_info - - -def get_vendor_for_device(device: str) -> Optional[VendorInfo]: - """ - Get the VendorInfo for a given device type. + Select the implementation for the current device. Args: - device: Device type (e.g., "npu", "xpu") + device: Device type from ``infer_device()`` (e.g., "cuda", "npu"). + explicit: If set, force selection of this named implementation. The + supported devices are validated against the runtime. Returns: - VendorInfo if found, None otherwise + ``ImplInfo`` if an implementation should replace the defaults, + ``None`` to keep defaults. 
+ + Raises: + RuntimeError: If ``explicit`` names an unknown implementation or one + incompatible with the current device. """ - return VENDOR_REGISTRY.get(device) + if explicit: + info = IMPL_REGISTRY.get(explicit) + if info is None: + known = ", ".join(sorted(IMPL_REGISTRY)) or "" + raise RuntimeError(f"Unknown {LIGER_KERNEL_IMPL_ENV}={explicit!r}. Registered implementations: {known}.") + if device not in info.devices: + supported = ", ".join(info.devices) + raise RuntimeError( + f"{LIGER_KERNEL_IMPL_ENV}={info.name!r} supports devices ({supported}), " + f"but the current device is {device!r}." + ) + return info + + # Auto-select: pick an implementation that lists the current device in its defaults. + for info in IMPL_REGISTRY.values(): + if device in info.default_devices: + return info + return None diff --git a/src/liger_kernel/ops/cutile/__init__.py b/src/liger_kernel/ops/cutile/__init__.py new file mode 100644 index 000000000..031f257c1 --- /dev/null +++ b/src/liger_kernel/ops/cutile/__init__.py @@ -0,0 +1,18 @@ +""" +cuTile backend for Liger-Kernel. + +cuTile is an optional CUDA-only DSL. It is opt-in only — users select it +explicitly via ``LIGER_KERNEL_IMPL=cutile``. It is not auto-applied on +any device (note the empty ``default_devices`` on the registration below). +""" + +from liger_kernel.ops.backends.registry import ImplInfo +from liger_kernel.ops.backends.registry import register_impl + +register_impl( + ImplInfo( + name="cutile", + devices=("cuda",), + module_path=f"{__name__}.ops", # liger_kernel.ops.cutile.ops + ) +) diff --git a/src/liger_kernel/ops/cutile/ops/__init__.py b/src/liger_kernel/ops/cutile/ops/__init__.py new file mode 100644 index 000000000..a3dafe684 --- /dev/null +++ b/src/liger_kernel/ops/cutile/ops/__init__.py @@ -0,0 +1,23 @@ +""" +cuTile-specific operator implementations. +""" + +try: + import cuda.tile as ct # noqa: F401 +except ImportError as exc: + raise ImportError( + "cuTile backend requires cuda-tile. 
Install it with `pip install cuda-tile` " + "or `pip install 'cuda-tile[tileiras]'` to include the optional tileiras compiler. " + "When installing Liger-Kernel, use `pip install 'liger-kernel[cutile]'` " + "or `pip install 'liger-kernel[cutile-tileiras]'`." + ) from exc + +from liger_kernel.ops.cutile.ops.jsd import LigerJSDFunction +from liger_kernel.ops.cutile.ops.jsd import jsd_backward +from liger_kernel.ops.cutile.ops.jsd import jsd_forward + +__all__ = [ + "LigerJSDFunction", + "jsd_backward", + "jsd_forward", +] diff --git a/src/liger_kernel/ops/cutile/ops/jsd.py b/src/liger_kernel/ops/cutile/ops/jsd.py new file mode 100644 index 000000000..bacb02bfe --- /dev/null +++ b/src/liger_kernel/ops/cutile/ops/jsd.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + +import math + +from typing import Optional + +import cuda.tile as ct +import torch + +from liger_kernel.ops.cutile.ops.utils import _next_power_of_2 +from liger_kernel.ops.utils import ensure_contiguous + +ConstFloat = ct.Constant[float] +ConstInt = ct.Constant[int] +JSD_BLOCK_SIZE = 4096 + + +@ct.kernel(occupancy=ct.ByTarget(sm_100=4)) +def jsd_kernel_ct( + x, # (BT, V) log Q (student) + y, # (BT, V) log P (teacher) + loss, # (BT, V) float32 loss accumulator + dx, # (BT, V) gradient output + label, # (BT,) label tensor, or dummy tensor when HAS_LABEL=0 + beta: ConstFloat, + inv_n_non_ignore: ConstFloat, + ignore_index: ConstInt, + n_cols: ConstInt, + BLOCK_SIZE: ConstInt, + HAS_LABEL: ConstInt, +): + """ + cuTile kernel for generalized Jensen-Shannon Divergence. 
+ """ + row_idx = ct.bid(0) + + if HAS_LABEL: + lbl = ct.load(label, row_idx, shape=()) + if lbl == ignore_index: + num_chunks_early = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + for ci in range(num_chunks_early): + col_indices = ct.arange(BLOCK_SIZE, dtype=ct.int32) + ci * BLOCK_SIZE + ct.scatter(dx, (row_idx, col_indices), ct.full((BLOCK_SIZE,), 0.0, dtype=dx.dtype), check_bounds=True) + return + + num_chunks = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + for chunk_idx in range(num_chunks): + col_indices = ct.arange(BLOCK_SIZE, dtype=ct.int32) + chunk_idx * BLOCK_SIZE + + x_tile = ct.gather(x, (row_idx, col_indices), check_bounds=True, padding_value=-math.inf) + y_tile = ct.gather(y, (row_idx, col_indices), check_bounds=True, padding_value=-math.inf) + + x_f32 = ct.astype(x_tile, ct.float32) + y_f32 = ct.astype(y_tile, ct.float32) + + loss_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + dx_tile = ct.full((BLOCK_SIZE,), 0.0, dtype=ct.float32) + + if beta == 0.0: + y_max = ct.max(y_f32, 0, keepdims=True) + y_prob = ct.exp(y_f32 - y_max) * ct.exp(y_max) + loss_tile = y_prob * (y_f32 - x_f32) + dx_tile = -y_prob + elif beta == 1.0: + x_max = ct.max(x_f32, 0, keepdims=True) + x_prob = ct.exp(x_f32 - x_max) * ct.exp(x_max) + loss_tile = x_prob * (x_f32 - y_f32) + dx_tile = loss_tile + x_prob + else: + x_max = ct.max(x_f32, 0, keepdims=True) + y_max = ct.max(y_f32, 0, keepdims=True) + max_val = ct.maximum(x_max, y_max) + exp_max = ct.exp(max_val) + q_prob = ct.exp(x_f32 - max_val) * exp_max + p_prob = ct.exp(y_f32 - max_val) * exp_max + beta_p = beta * p_prob + one_minus_beta_q = (1.0 - beta) * q_prob + m_prob = beta_p + one_minus_beta_q + log_m = ct.log(m_prob) + loss_tile = beta_p * y_f32 + one_minus_beta_q * x_f32 - m_prob * log_m + dx_tile = one_minus_beta_q * (x_f32 - log_m) + + loss_tile = loss_tile * inv_n_non_ignore + dx_tile = dx_tile * inv_n_non_ignore + + ct.scatter(loss, (row_idx, col_indices), loss_tile, check_bounds=True) + ct.scatter(dx, (row_idx, 
col_indices), ct.astype(dx_tile, dx.dtype), check_bounds=True) + + +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): + num_rows, vocab_size = _input.shape + BLOCK_SIZE = min(JSD_BLOCK_SIZE, _next_power_of_2(vocab_size)) + + loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) + dx = torch.empty_like(_input) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = num_rows + + if n_non_ignore == 0: + return torch.tensor(0.0, device=_input.device, dtype=_input.dtype), torch.zeros_like(_input) + + inv_n_non_ignore = 1.0 / n_non_ignore + label_tensor = shift_labels if has_label else torch.empty(1, device=_input.device, dtype=torch.int64) + + ct.launch( + torch.cuda.current_stream(), + (num_rows, 1, 1), + jsd_kernel_ct, + ( + _input, + target, + loss, + dx, + label_tensor, + float(beta), + float(inv_n_non_ignore), + int(ignore_index), + int(vocab_size), + int(BLOCK_SIZE), + int(has_label), + ), + ) + + return torch.sum(loss).to(_input.dtype), dx + + +def jsd_backward(dx, grad_output): + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return dx + return grad_output * dx + + +class LigerJSDFunction(torch.autograd.Function): + r""" + cuTile autograd wrapper for the generalized Jensen-Shannon Divergence loss. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + shift_labels: Optional[torch.Tensor], + beta: float, + ignore_index: int, + ) -> torch.Tensor: + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (_input.shape[0],), ( + f"shift_labels must have shape (BT,). 
Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, dx = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label) + ctx.save_for_backward(dx) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor): + (dx,) = ctx.saved_tensors + dx = jsd_backward(dx, grad_output) + return (dx, None, None, None, None) diff --git a/src/liger_kernel/ops/cutile/ops/utils.py b/src/liger_kernel/ops/cutile/ops/utils.py new file mode 100644 index 000000000..5bc958f22 --- /dev/null +++ b/src/liger_kernel/ops/cutile/ops/utils.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: MIT + + +def _next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n.""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/test/transformers/test_cutile_backend.py b/test/transformers/test_cutile_backend.py new file mode 100644 index 000000000..762d25ff4 --- /dev/null +++ b/test/transformers/test_cutile_backend.py @@ -0,0 +1,49 @@ +import os +import subprocess +import sys +import textwrap + +from pathlib import Path + +import pytest +import torch + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuTile backend requires CUDA") +@pytest.mark.skipif( + os.environ.get("LIGER_KERNEL_IMPL", "").strip().lower() != "cutile", + reason="cuTile backend selection test requires LIGER_KERNEL_IMPL=cutile", +) +def test_liger_kernel_impl_cutile_selects_cutile_jsd_function(): + repo_root = Path(__file__).resolve().parents[2] + pythonpath = os.pathsep.join( + [ + str(repo_root / "src"), + str(repo_root), + os.environ.get("PYTHONPATH", ""), + ] + ) + env = { + **os.environ, + "LIGER_KERNEL_IMPL": "cutile", + "PYTHONPATH": pythonpath, + } + script = textwrap.dedent( + """ + from liger_kernel.transformers.jsd 
import LigerJSDFunction + + module_name = LigerJSDFunction.__module__ + expected_prefix = "liger_kernel.ops.cutile." + if not module_name.startswith(expected_prefix): + raise AssertionError( + f"Expected cuTile LigerJSDFunction from {expected_prefix}, got {module_name}" + ) + """ + ) + + subprocess.run( + [sys.executable, "-c", script], + check=True, + env=env, + cwd=repo_root, + )