
Bugs in Triton operator? #9

@XintianHan


Hi, thanks for the nice Triton implementation. I may have found a bug in the Triton operator: it does not seem to support a head dimension of 192, although head dimensions of 128 and 256 work.
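One plausible explanation, which I have not confirmed against the kernel source: Triton block shapes, such as the ranges built with `tl.arange`, must be powers of two. If the kernel uses the head dimension directly as a block size, it would accept 128 and 256 but not 192. The small check below uses hypothetical helper names (`is_power_of_two`, `next_power_of_two`) and only illustrates that pattern. It does not inspect the kernel itself.

```python
# Hypothesis only: the kernel may use the head dim as a Triton block size,
# and Triton block shapes (e.g. tl.arange ranges) must be powers of two.

def is_power_of_two(n: int) -> bool:
    """Return True if n is a positive power of two."""
    return n > 0 and (n & (n - 1)) == 0


def next_power_of_two(n: int) -> int:
    """Return the smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


for d in (128, 192, 256):
    if is_power_of_two(d):
        print(f"head dim {d}: power of two")
    else:
        print(f"head dim {d}: not a power of two (next: {next_power_of_two(d)})")
```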

Here is a minimal example that reproduces it:

```python
from lightning_attention import lightning_attention
import torch

# Shapes: (batch b, heads h, sequence length n, head dim d)
b = 1
h = 16
n = 64
d = 192  # fails; d = 128 and d = 256 work
q = torch.randn(b, h, n, d).to("cuda")
k = torch.randn(b, h, n, d).to("cuda")
v = torch.randn(b, h, n, d).to("cuda")
slope_rate = torch.ones(h).to("cuda")  # one slope per head
output = lightning_attention(
    q, k, v, True, slope_rate.squeeze(-1).squeeze(-1)
)
print("test succeeded!")
```

It fails with the following error:

```
  File "<string>", line 41, in _fwd_kernel
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 1621, in compile
    next_module = compile(module)
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 1550, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 963, in ast_to_ttir
    return optimize_triton_ir(mod)
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 957, in optimize_triton_ir
    pm.run(mod)
RuntimeError: PassManager::run failed
```

The error is raised at line 370, in `forward`, on the kernel launch `_fwd_kernel[grid](`.

Any advice here?
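A possible stopgap, sketched under assumptions I have not verified: zero-pad the head dimension of q, k and v from 192 up to 256, call the attention function, and slice the output back to 192. Zero columns in q and k leave every dot product q_i . k_j unchanged. Zero columns in v only add output columns that are themselves zero. This holds only if the kernel does no scaling or normalization that depends on the head dimension. `pad_head_dim`, `reference_decay_linear_attention` and `attention_with_padded_head_dim` are hypothetical helpers. The reference function is an assumed decayed causal linear attention, used only to check on CPU that padding is exact. It is not the library's implementation.

```python
import torch
import torch.nn.functional as F


def pad_head_dim(x: torch.Tensor, target: int) -> torch.Tensor:
    """Zero-pad the last (head) dimension of a (b, h, n, d) tensor up to `target`."""
    return F.pad(x, (0, target - x.shape[-1]))


def reference_decay_linear_attention(q, k, v, slope):
    """Hypothetical CPU reference (assumed, not the library's kernel):
    o_i = sum_{j <= i} exp(-slope * (i - j)) * (q_i . k_j) * v_j, per head."""
    n = q.shape[-2]
    idx = torch.arange(n, device=q.device)
    dist = (idx[:, None] - idx[None, :]).to(q.dtype)  # (n, n) entries i - j
    causal = torch.ones(n, n, dtype=torch.bool, device=q.device).tril()
    decay = torch.exp(-slope.view(-1, 1, 1) * dist.clamp(min=0))  # (h, n, n)
    decay = decay.masked_fill(~causal, 0.0)  # drop future positions j > i
    scores = q @ k.transpose(-1, -2)  # (b, h, n, n); unchanged by zero padding
    return (scores * decay) @ v


def attention_with_padded_head_dim(attn_fn, q, k, v, slope):
    """Pad q/k and v head dims to the next power of two, call attn_fn, slice back."""
    d_qk, d_v = q.shape[-1], v.shape[-1]
    target_qk = 1 << (d_qk - 1).bit_length()
    target_v = 1 << (d_v - 1).bit_length()
    out = attn_fn(
        pad_head_dim(q, target_qk),
        pad_head_dim(k, target_qk),
        pad_head_dim(v, target_v),
        slope,
    )
    # The padded v columns are zero, so the extra output columns are zero too.
    return out[..., :d_v]


# CPU check that padding 192 -> 256 leaves the reference result unchanged.
torch.manual_seed(0)
b, h, n, d = 1, 2, 8, 192
q, k, v = (torch.randn(b, h, n, d, dtype=torch.float64) for _ in range(3))
slope = torch.ones(h, dtype=torch.float64)
ref = reference_decay_linear_attention(q, k, v, slope)
padded = attention_with_padded_head_dim(reference_decay_linear_attention, q, k, v, slope)
print(padded.shape, torch.allclose(ref, padded))
```

To try this with the Triton operator, one could pass something like `lambda q, k, v, s: lightning_attention(q, k, v, True, s)` as `attn_fn`. That only helps if the kernel really has no dependence on the head dimension beyond the dot products. It also costs roughly a third more work along the head dimension (256 vs. 192), so native support for 192 would still be preferable.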
