assert d in supports_dim and e in supports_dim ? #3

@XintianHan

Thank you for the nice implementation! It seems that dim=192 is not in supports_dim, so the assertion in lightning_attn_func fails. Is there a reason this size is excluded? Could you add support for dim=192?

I tried this script:

import torch
from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

dtype = torch.bfloat16
device = torch.device("cuda")
b, h, n, d, e = 2, 12, 2048, 192, 192

q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = _build_slope_tensor(h).to(q.device).to(torch.float32)

o = lightning_attn_func(q, k, v, s)

print(o.shape)

and got this error:

    o = lightning_attn_func(q, k, v, s)
  File "/opt/tiger/mariana/lightning-attention-main/lightning_attn/ops/lightning_attn_interface.py", line 10, in lightning_attn_func
    assert d in supports_dim and e in supports_dim
AssertionError
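
Until 192 is supported, one possible stopgap is to zero-pad the head dimensions up to the next supported size and slice the output back. This is only a sketch under assumptions. `ASSUMED_SUPPORTS_DIM` is a hypothetical placeholder, not the library's actual `supports_dim` list. `next_supported_dim`, `pad_last_dim`, and `padded_lightning_attn` are illustrative helpers, not part of lightning_attn. The sketch also assumes the kernel computes an unnormalized product of q, k, and v with no scaling that depends on d. In that case the zero columns in q and k leave q·kᵀ unchanged, and the zero columns in v only add zero output columns.

```python
# Hypothetical placeholder: replace with the real `supports_dim` from lightning_attn.
ASSUMED_SUPPORTS_DIM = (64, 128, 256)


def next_supported_dim(dim, supports_dim=ASSUMED_SUPPORTS_DIM):
    """Return the smallest supported dim that is >= dim."""
    candidates = [s for s in supports_dim if s >= dim]
    if not candidates:
        raise ValueError(f"no supported dim >= {dim} in {tuple(supports_dim)}")
    return min(candidates)


def pad_last_dim(x, target):
    """Zero-pad the last dimension of tensor x on the right, up to `target`."""
    # Imported lazily so next_supported_dim stays usable without torch.
    import torch.nn.functional as F

    return F.pad(x, (0, target - x.shape[-1]))


def padded_lightning_attn(attn_func, q, k, v, s):
    """Call attn_func on zero-padded q, k, v and slice the output back to e."""
    d, e = q.shape[-1], v.shape[-1]
    d_pad, e_pad = next_supported_dim(d), next_supported_dim(e)
    o = attn_func(
        pad_last_dim(q, d_pad),
        pad_last_dim(k, d_pad),
        pad_last_dim(v, e_pad),
        s,
    )
    # Drop the padded output columns (all zeros under the assumption above).
    return o[..., :e]
```

With the script above, this would be called as `o = padded_lightning_attn(lightning_attn_func, q, k, v, s)`. The padding costs extra memory and compute, and the result should be checked against a reference implementation before relying on it.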
