53 changes: 53 additions & 0 deletions configs/xpu/text/qwen2_5_xpu.yaml
@@ -0,0 +1,53 @@
model:
  model_path: Qwen/Qwen2.5-0.5B-Instruct
  ops_implementation:
    attn_implementation: sdpa

data:
  train_path: fineweb
  train_size: 1000000000000
  dataloader:
    type: native
    drop_last: true
  datasets_type: iterable
  data_type: plaintext
  max_seq_len: 512
  text_keys: text
  dyn_bsz_buffer_size: 200

train:
  accelerator:
    ulysses_size: 1
    fsdp_config:
      fsdp_mode: fsdp1
      full_shard: true
      offload: false
    mixed_precision:
      enable: true
    offload:
      enable_activation: false
    gradient_checkpointing:
      enable: true
  global_batch_size: 8
  micro_batch_size: 1
  bsz_warmup_ratio: 0.007
  optimizer:
    type: adamw
    lr: 3.0e-4
    lr_warmup_ratio: 0.007
    lr_decay_style: constant
    lr_decay_ratio: 1.0
    weight_decay: 0.01
    max_grad_norm: 1.0
  init_device: meta
  enable_full_determinism: false
  empty_cache_steps: 500
  checkpoint:
    output_dir: Qwen2.5-0.5B-Instruct-sft-xpu
    manager: dcp
    save_steps: 100
    save_hf_weights: true
  wandb:
    enable: false
  num_train_epochs: 1
  max_steps: 10
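For a quick local sanity check, the file can be loaded with PyYAML before launching a run. A minimal sketch, assuming the nesting shown above (the three paths checked are the ones the CI script overrides, so they are confirmed by the launch flags); VeOmni's own CLI consumes the file via tasks/train_text.py:

    import yaml

    with open("configs/xpu/text/qwen2_5_xpu.yaml") as f:
        cfg = yaml.safe_load(f)

    # Spot-check fields the XPU smoke test depends on.
    assert cfg["model"]["model_path"] == "Qwen/Qwen2.5-0.5B-Instruct"
    assert cfg["train"]["accelerator"]["fsdp_config"]["fsdp_mode"] == "fsdp1"
    print("max_steps:", cfg["train"]["max_steps"])  # 10 here; the CI script overrides it to 5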
48 changes: 48 additions & 0 deletions tests/special_xpu/run_veomni_e2e_sft_xpu.sh
@@ -0,0 +1,48 @@
#!/bin/bash
# VeOmni trainer e2e SFT smoke test on Intel XPU (2 GPUs)

set -e

# Get the VeOmni root directory
VEOMNI_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$VEOMNI_ROOT"

# XPU environment variables
export ZE_AFFINITY_MASK="0,1"
export CCL_ATL_SHM=1
export CCL_BUFFER_CACHE=0
export CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
export CCL_TOPO_ALGO=0
export RAY_NUM_PRESTART_PYTHON_WORKERS=0

# Model and data setup
MODEL_PATH="${CI_HF_MODELS_DIR:-.}/Qwen/Qwen2.5-0.5B-Instruct"
DATASET_DIR="${CI_DATASET_DIR:-.}"

echo "=========================================="
echo "VeOmni e2e SFT Test on Intel XPU"
echo "=========================================="
echo "Model: Qwen2.5-0.5B-Instruct"
echo "GPUs: 2 (XPU devices 0,1 via ZE_AFFINITY_MASK)"
echo "Config: configs/xpu/text/qwen2_5_xpu.yaml (sdpa attention, fsdp1 trainer smoke)"
echo ""

# Run torchrun with 2 XPU GPUs
torchrun \
    --nnodes=1 \
    --nproc_per_node=2 \
    --master-port=4321 \
    tasks/train_text.py \
    configs/xpu/text/qwen2_5_xpu.yaml \
    --model.model_path "$MODEL_PATH" \
    --data.train_path "$DATASET_DIR/fineweb" \
    --train.checkpoint.output_dir "Qwen2.5-0.5B-Instruct-sft-xpu" \
    --train.accelerator.fsdp_config.fsdp_mode fsdp1 \
    --train.num_train_epochs 1 \
    --train.max_steps 5 \
    --train.wandb.enable false

echo ""
echo "=========================================="
echo "VeOmni XPU e2e test PASSED!"
echo "=========================================="
17 changes: 16 additions & 1 deletion veomni/distributed/moe/moe_layer.py
@@ -23,7 +23,9 @@
from .moe_utils import generate_weights_idx, permute, sort_chunks_by_idxs, unpermute


if not is_torch_npu_available():
group_gemm_same_mn = None
group_gemm_same_nk = None
if not is_torch_npu_available() and torch.cuda.is_available():
Contributor (severity: high):

Guarding the import of the group_gemm kernels with torch.cuda.is_available() will cause a NameError at runtime on XPU devices when EPGroupGemm or EPMergedFc1GroupGemm is used, because their methods reference the imported functions. If MoE is not yet supported on XPU, raise a clear error instead (e.g., a RuntimeError in the forward method), or define the functions as None and check them before use, rather than letting a NameError surface.

    from ...ops.kernels.moe._kernels.kernel.group_gemm import group_gemm_same_mn, group_gemm_same_nk
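The sentinel-plus-check pattern this hunk and the forward guards below adopt, shown in isolation (toy example; group_gemm and fused_forward are made-up names here, not VeOmni code):

    group_gemm = None
    try:
        from operator import matmul as group_gemm  # stand-in for the backend-gated import
    except ImportError:
        pass

    def fused_forward(a, b):
        if group_gemm is None:  # explicit check instead of a NameError
            raise RuntimeError("fused group_gemm unavailable; use moe_implementation='eager'")
        return group_gemm(a, b)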


@@ -149,6 +151,12 @@ def forward(
    ):
        # permute_tokens: [tokens, hidden_dim]
        # cumsum: [local_experts]

        if group_gemm_same_nk is None:
            raise RuntimeError(
                "MoE fused group_gemm kernels are not available for the current device/backend. "
                "Please use moe_implementation='eager' instead of 'fused'."
            )

        # compute linear layer fc1-1
        fc1_1_output = group_gemm_same_nk(
@@ -320,6 +328,13 @@ def forward(
    ):
        # permute_tokens: [tokens, hidden_dim]
        # cumsum: [local_experts]

        if group_gemm_same_nk is None:
            raise RuntimeError(
                "MoE fused group_gemm kernels are not available for the current device/backend. "
                "Please use moe_implementation='eager' instead of 'fused'."
            )

        assert fc1_1_2_weight.shape[1] % 2 == 0, (
            f"Merged fc1_1_2_weight dim 1 must be even, got {fc1_1_2_weight.shape[1]}"
        )
2 changes: 1 addition & 1 deletion veomni/distributed/torch_parallelize.py
@@ -597,7 +597,7 @@ def build_parallelize_model(
    parallel_state = get_parallel_state()

    if not parallel_state.fsdp_enabled:
        if kwargs.get("init_device") not in ["cuda", "npu"]:
        if kwargs.get("init_device") not in ["cuda", "npu", "xpu"]:
            raise ValueError("Only FSDP training supports `init_device=cpu` or `init_device=meta`.")
        if kwargs.pop("enable_fsdp_offload", False):
            raise ValueError("Only FSDP training supports `enable_fsdp_offload`.")
14 changes: 8 additions & 6 deletions veomni/ops/kernels/moe/_kernels/utils/device.py
@@ -21,14 +21,16 @@
def get_device_key() -> str:
    import torch

    if torch.cuda.get_device_capability() == (8, 0):
        return "A100"  # A30 is treated the same way as A100 for the moment.
    if torch.cuda.is_available():
        if torch.cuda.get_device_capability() == (8, 0):
            return "A100"  # A30 is treated the same way as A100 for the moment.

    if torch.cuda.get_device_capability() == (9, 0):
        return "H100"
        if torch.cuda.get_device_capability() == (9, 0):
            return "H100"

    name = get_device_name()
    if name.startswith("NVIDIA "):
    if name and name.startswith("NVIDIA "):
        name = name[len("NVIDIA ") :]

    return name
    # Fallback for non-NVIDIA devices (e.g., Intel XPU, AMD)
    return name if name else "unknown"
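To make the new fallback concrete, a pure-Python mirror of the branching above (illustration only; the sample device names are assumptions, not a supported-device list):

    def device_key_for(name, capability=None):
        if capability == (8, 0):
            return "A100"
        if capability == (9, 0):
            return "H100"
        if name and name.startswith("NVIDIA "):
            name = name[len("NVIDIA ") :]
        return name if name else "unknown"

    assert device_key_for("NVIDIA A100-SXM4-80GB", (8, 0)) == "A100"
    assert device_key_for("Intel(R) Data Center GPU Max 1550") == "Intel(R) Data Center GPU Max 1550"
    assert device_key_for("") == "unknown"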
9 changes: 8 additions & 1 deletion veomni/utils/device.py
@@ -27,17 +27,20 @@

IS_CUDA_AVAILABLE = torch.cuda.is_available()
IS_NPU_AVAILABLE = is_torch_npu_available()
IS_XPU_AVAILABLE = hasattr(torch, "xpu") and torch.xpu.is_available()

if IS_NPU_AVAILABLE:
    torch.npu.config.allow_internal_format = False


def get_device_type() -> str:
    """Get device type based on current machine, currently only support CPU, CUDA, NPU."""
    """Get the device type of the current machine; currently supports CPU, CUDA, NPU, and XPU."""
    if IS_CUDA_AVAILABLE:
        device = "cuda"
    elif IS_NPU_AVAILABLE:
        device = "npu"
    elif IS_XPU_AVAILABLE:
        device = "xpu"
    else:
        device = "cpu"

@@ -71,6 +74,8 @@ def get_dist_comm_backend() -> str:
return "nccl"
elif IS_NPU_AVAILABLE:
return "hccl"
elif IS_XPU_AVAILABLE:
return "xccl"
else:
raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")

@@ -86,6 +91,8 @@ def stream_synchronize() -> None:
        torch.cuda.current_stream().synchronize()
    elif IS_NPU_AVAILABLE:
        torch.npu.current_stream().synchronize()
    elif IS_XPU_AVAILABLE:
        torch.xpu.current_stream().synchronize()
    else:
        synchronize()
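
With these branches in place, the usual device bootstrap works unchanged on Intel hardware. A minimal sketch (import path inferred from this file's location; the process-group call is commented out because it needs a torchrun launcher):

    import torch
    from veomni.utils.device import get_device_type, get_dist_comm_backend

    device_type = get_device_type()        # "xpu" when torch.xpu.is_available()
    x = torch.zeros(4, device=device_type)

    # Under torchrun, the matching backend is one call away:
    # torch.distributed.init_process_group(backend=get_dist_comm_backend())  # "xccl" on XPU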
