Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lightllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from lightllm.utils.device_utils import is_musa

if is_musa():
    # The imported name is never used (hence the F401 suppression), so this
    # import exists for whatever torchada does at import time.
    # NOTE(review): presumably it lets torch.cuda-style calls work on MUSA
    # builds of torch — confirm against the torchada documentation.
    import torchada # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def _fwd_kernel_token_att1(
).to(tl.int64)
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32)
att_value = tl.sum(q[None, :] * k, 1)
att_value = att_value.to(tl.float32)
att_value *= sm_scale
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
Expand Down
31 changes: 25 additions & 6 deletions lightllm/utils/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,14 @@ def calcu_kernel_best_vsm_count(kernel, num_warps):
return num_sm


@lru_cache(maxsize=1)
def is_musa():
    """Tell whether ``torch.version`` carries a non-None ``musa`` attribute.

    The result is memoized by ``lru_cache``, so the attribute is inspected
    only on the first call.

    Returns:
        bool: True if ``torch.version.musa`` exists and is not None,
        otherwise False.
    """
    # A missing attribute and an attribute set to None are treated the same.
    musa_version = getattr(torch.version, "musa", None)
    return musa_version is not None


@lru_cache(maxsize=None)
def get_current_device_name():
import torch

if torch.cuda.is_available():
if torch.cuda.is_available() or is_musa():
device = torch.cuda.current_device()
gpu_name = torch.cuda.get_device_name(device)
# 4090 trans to 4090 D
Expand All @@ -103,8 +106,6 @@ def init_p2p(device_index):
"""
torch 调用跨卡的to操作后,triton编译的算子便能自动操作跨卡tensor。
"""
import torch

num_gpus = torch.cuda.device_count()
tensor = torch.zeros((1,))
tensor = tensor.to(f"cuda:{device_index}")
Expand All @@ -127,8 +128,26 @@ def has_nvlink():
result = result.decode("utf-8")
# Check if the output contains 'NVLink'
return any(f"NV{i}" in result for i in range(1, 8))
except FileNotFoundError:
# nvidia-smi is not installed, assume no NVLink
return False
except subprocess.CalledProcessError:
# If there's an error while executing nvidia-smi, assume no NVLink
return False


def has_mtlink():
    """Report whether the ``mthreads-gmi`` topology matrix mentions an MT link.

    Runs ``mthreads-gmi topo --matrix`` and scans its UTF-8 decoded output
    for any of the markers ``MT1`` through ``MT7``.

    Returns:
        bool: True if at least one marker appears in the output; False if
        no marker appears, if ``mthreads-gmi`` cannot be found, or if it
        exits with a non-zero status.
    """
    link_markers = [f"MT{level}" for level in range(1, 8)]
    try:
        topo_matrix = subprocess.check_output(["mthreads-gmi", "topo", "--matrix"]).decode("utf-8")
    except (FileNotFoundError, subprocess.CalledProcessError):
        # Either the tool is not installed or it ran and failed; in both
        # cases assume there is no MTLink.
        return False
    for marker in link_markers:
        if marker in topo_matrix:
            return True
    return False


Expand Down