diff --git a/lightllm/__init__.py b/lightllm/__init__.py
index e69de29bb..e9ba6f304 100644
--- a/lightllm/__init__.py
+++ b/lightllm/__init__.py
@@ -0,0 +1,4 @@
+from lightllm.utils.device_utils import is_musa
+
+if is_musa():
+    import torchada  # noqa: F401
diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
index eb5af6fec..45de83e98 100644
--- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
+++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@@ -60,7 +60,8 @@ def _fwd_kernel_token_att1(
         ).to(tl.int64)
         off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
         k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
-        att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32)
+        att_value = tl.sum(q[None, :] * k, 1)
+        att_value = att_value.to(tl.float32)
         att_value *= sm_scale
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index cd48a355b..09d7a680f 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -81,11 +81,14 @@ def calcu_kernel_best_vsm_count(kernel, num_warps):
     return num_sm
 
 
+@lru_cache(maxsize=1)
+def is_musa():
+    return hasattr(torch.version, "musa") and torch.version.musa is not None
+
+
 @lru_cache(maxsize=None)
 def get_current_device_name():
-    import torch
-
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or is_musa():
         device = torch.cuda.current_device()
         gpu_name = torch.cuda.get_device_name(device)
         # 4090 trans to 4090 D
@@ -103,8 +106,6 @@ def init_p2p(device_index):
     """
     After torch performs a cross-device .to() call, Triton-compiled kernels can operate on cross-device tensors automatically.
     """
-    import torch
-
     num_gpus = torch.cuda.device_count()
     tensor = torch.zeros((1,))
     tensor = tensor.to(f"cuda:{device_index}")
@@ -127,8 +128,26 @@ def has_nvlink():
         result = result.decode("utf-8")
         # Check if the output contains 'NVLink'
         return any(f"NV{i}" in result for i in range(1, 8))
+    except FileNotFoundError:
+        # nvidia-smi is not installed, assume no NVLink
+        return False
+    except subprocess.CalledProcessError:
+        # If there's an error while executing nvidia-smi, assume no NVLink
+        return False
+
+
+def has_mtlink():
+    try:
+        # Call mthreads-gmi to get the topology matrix
+        result = subprocess.check_output(["mthreads-gmi", "topo", "--matrix"])
+        result = result.decode("utf-8")
+        # Check if the output contains 'MTLink'
+        return any(f"MT{i}" in result for i in range(1, 8))
+    except FileNotFoundError:
+        # mthreads-gmi is not installed, assume no MTLink
+        return False
     except subprocess.CalledProcessError:
-        # If there's an error (e.g., nvidia-smi is not installed or another issue), assume no NVLink
+        # If there's an error while executing mthreads-gmi, assume no MTLink
         return False
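
Not part of the patch: a minimal usage sketch of how the helpers added above could be combined by a caller to pick the right interconnect probe on MUSA (Moore Threads) hosts versus CUDA hosts. The select_link_check name is illustrative only; the imports match the functions introduced in device_utils.py.

    from lightllm.utils.device_utils import has_mtlink, has_nvlink, is_musa

    def select_link_check() -> bool:
        # On MUSA hosts probe MTLink via mthreads-gmi; otherwise fall back
        # to the NVLink probe via nvidia-smi. Both helpers return False when
        # the corresponding tool is missing or exits with an error.
        return has_mtlink() if is_musa() else has_nvlink()

    if __name__ == "__main__":
        print("fast interconnect available:", select_link_check())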