From b2120588f0a7079ee9a543df33643d3233545294 Mon Sep 17 00:00:00 2001 From: kahlun Date: Mon, 27 Apr 2026 07:04:01 -0700 Subject: [PATCH 1/4] [hardware] feat: add Intel XPU device support - Add IS_XPU_AVAILABLE flag, XPU branch in get_device_type(), get_dist_comm_backend() (xccl), and stream_synchronize() - Guard torch.cuda.get_device_capability() for non-CUDA devices; let XPU fall through to get_device_name() instead of returning 'unknown' - Guard group_gemm import (CUDA-only Triton kernels) with None defaults to avoid NameError on non-CUDA devices - Accept 'xpu' as valid init_device in non-FSDP path Tested on Intel Arc Pro B60 (Battlemage BMG-G21, 24 GB VRAM): - 1-GPU standalone: 7/7 pass - 2-GPU FSDP2: 8/8 pass (with CCL_ATL_SHM=1) - veRL e2e GRPO (VeOmni engine + vLLM rollout): PASS - Model: Qwen2.5-0.5B-Instruct (494M params, bf16) --- veomni/distributed/moe/moe_layer.py | 4 +++- veomni/distributed/torch_parallelize.py | 2 +- veomni/ops/kernels/moe/_kernels/utils/device.py | 9 +++++---- veomni/utils/device.py | 9 ++++++++- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/veomni/distributed/moe/moe_layer.py b/veomni/distributed/moe/moe_layer.py index 3fadff1f4..9bf514557 100644 --- a/veomni/distributed/moe/moe_layer.py +++ b/veomni/distributed/moe/moe_layer.py @@ -23,7 +23,9 @@ from .moe_utils import generate_weights_idx, permute, sort_chunks_by_idxs, unpermute -if not is_torch_npu_available(): +group_gemm_same_mn = None +group_gemm_same_nk = None +if not is_torch_npu_available() and torch.cuda.is_available(): from ...ops.kernels.moe._kernels.kernel.group_gemm import group_gemm_same_mn, group_gemm_same_nk diff --git a/veomni/distributed/torch_parallelize.py b/veomni/distributed/torch_parallelize.py index 7c4463753..69b67edf0 100644 --- a/veomni/distributed/torch_parallelize.py +++ b/veomni/distributed/torch_parallelize.py @@ -597,7 +597,7 @@ def build_parallelize_model( parallel_state = get_parallel_state() if not parallel_state.fsdp_enabled: - if kwargs.get("init_device") not in ["cuda", "npu"]: + if kwargs.get("init_device") not in ["cuda", "npu", "xpu"]: raise ValueError("Only FSDP training supports `init_device=cpu` or `init_device=meta`.") if kwargs.pop("enable_fsdp_offload", False): raise ValueError("Only FSDP training supports `enable_fsdp_offload`.") diff --git a/veomni/ops/kernels/moe/_kernels/utils/device.py b/veomni/ops/kernels/moe/_kernels/utils/device.py index d90c1e8fa..dbc813fa4 100644 --- a/veomni/ops/kernels/moe/_kernels/utils/device.py +++ b/veomni/ops/kernels/moe/_kernels/utils/device.py @@ -21,11 +21,12 @@ def get_device_key() -> str: import torch - if torch.cuda.get_device_capability() == (8, 0): - return "A100" # A30 is treated the same way as A100 for the moment. + if torch.cuda.is_available(): + if torch.cuda.get_device_capability() == (8, 0): + return "A100" # A30 is treated the same way as A100 for the moment. 
- if torch.cuda.get_device_capability() == (9, 0): - return "H100" + if torch.cuda.get_device_capability() == (9, 0): + return "H100" name = get_device_name() if name.startswith("NVIDIA "): diff --git a/veomni/utils/device.py b/veomni/utils/device.py index 27512200a..2e8094630 100644 --- a/veomni/utils/device.py +++ b/veomni/utils/device.py @@ -27,17 +27,20 @@ IS_CUDA_AVAILABLE = torch.cuda.is_available() IS_NPU_AVAILABLE = is_torch_npu_available() +IS_XPU_AVAILABLE = hasattr(torch, "xpu") and torch.xpu.is_available() if IS_NPU_AVAILABLE: torch.npu.config.allow_internal_format = False def get_device_type() -> str: - """Get device type based on current machine, currently only support CPU, CUDA, NPU.""" + """Get device type based on current machine, currently only support CPU, CUDA, NPU, XPU.""" if IS_CUDA_AVAILABLE: device = "cuda" elif IS_NPU_AVAILABLE: device = "npu" + elif IS_XPU_AVAILABLE: + device = "xpu" else: device = "cpu" @@ -71,6 +74,8 @@ def get_dist_comm_backend() -> str: return "nccl" elif IS_NPU_AVAILABLE: return "hccl" + elif IS_XPU_AVAILABLE: + return "xccl" else: raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.") @@ -86,6 +91,8 @@ def stream_synchronize() -> None: torch.cuda.current_stream().synchronize() elif IS_NPU_AVAILABLE: torch.npu.current_stream().synchronize() + elif IS_XPU_AVAILABLE: + torch.xpu.current_stream().synchronize() else: synchronize() From f893cad16e7299b04cf0d1c602f2381ea1e0deee Mon Sep 17 00:00:00 2001 From: kahlun Date: Wed, 29 Apr 2026 19:53:03 -0700 Subject: [PATCH 2/4] [review] fix: explicit fallback and robust XPU test behavior - device.py: keep explicit get_device_key fallback for non-CUDA devices - moe_layer.py: raise actionable RuntimeError when fused MoE group_gemm is unavailable (e.g., XPU), guiding users to moe_implementation=eager - tests/special_xpu/test_fsdp2_simple_xpu.py: rewritten to true FSDP2 via fully_shard (no FSDPv1 API) - tests/special_xpu/test_fsdp2_simple_xpu.py: removed dead train_step helper - tests/special_xpu/run_veomni_e2e_sft_xpu.sh: clarify this trainer path is fsdp1 smoke; FSDP2 coverage comes from test_fsdp2_simple_xpu.py Validated: - 2-GPU XPU FSDP2 smoke passes with fully_shard (loss decreases) - FSDP2 wrappers present at runtime (FSDPModule count > 0) --- configs/text/qwen2_5_xpu.yaml | 53 +++++++++ tests/special_xpu/run_fsdp2_simple_xpu.sh | 35 ++++++ tests/special_xpu/run_veomni_e2e_sft_xpu.sh | 50 +++++++++ tests/special_xpu/test_fsdp2_simple_xpu.py | 105 ++++++++++++++++++ veomni/distributed/moe/moe_layer.py | 15 +++ .../ops/kernels/moe/_kernels/utils/device.py | 5 +- 6 files changed, 261 insertions(+), 2 deletions(-) create mode 100644 configs/text/qwen2_5_xpu.yaml create mode 100755 tests/special_xpu/run_fsdp2_simple_xpu.sh create mode 100755 tests/special_xpu/run_veomni_e2e_sft_xpu.sh create mode 100644 tests/special_xpu/test_fsdp2_simple_xpu.py diff --git a/configs/text/qwen2_5_xpu.yaml b/configs/text/qwen2_5_xpu.yaml new file mode 100644 index 000000000..6e746dee5 --- /dev/null +++ b/configs/text/qwen2_5_xpu.yaml @@ -0,0 +1,53 @@ +model: + model_path: Qwen/Qwen2.5-0.5B-Instruct + ops_implementation: + attn_implementation: sdpa + +data: + train_path: fineweb + train_size: 1000000000000 + dataloader: + type: native + drop_last: true + datasets_type: iterable + data_type: plaintext + max_seq_len: 512 + text_keys: text + dyn_bsz_buffer_size: 200 + +train: + accelerator: + ulysses_size: 1 + fsdp_config: + fsdp_mode: fsdp1 + full_shard: true + offload: 
false + mixed_precision: + enable: true + offload: + enable_activation: false + gradient_checkpointing: + enable: true + global_batch_size: 8 + micro_batch_size: 1 + bsz_warmup_ratio: 0.007 + optimizer: + type: adamw + lr: 3.0e-4 + lr_warmup_ratio: 0.007 + lr_decay_style: constant + lr_decay_ratio: 1.0 + weight_decay: 0.01 + max_grad_norm: 1.0 + init_device: meta + enable_full_determinism: false + empty_cache_steps: 500 + checkpoint: + output_dir: Qwen2.5-0.5B-Instruct-sft-xpu + manager: dcp + save_steps: 100 + save_hf_weights: true + wandb: + enable: false + num_train_epochs: 1 + max_steps: 10 diff --git a/tests/special_xpu/run_fsdp2_simple_xpu.sh b/tests/special_xpu/run_fsdp2_simple_xpu.sh new file mode 100755 index 000000000..ebeb735ae --- /dev/null +++ b/tests/special_xpu/run_fsdp2_simple_xpu.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Simple FSDP2 XPU test - standalone, no trainer dependencies + +set -e + +echo "==========================================" +echo "VeOmni FSDP2 Standalone Test on Intel XPU" +echo "==========================================" +echo "GPUs: 2 (XPU devices 0,1)" +echo "Model: Qwen2.5-0.5B-Instruct (sdpa attention)" +echo "" + +# XPU environment variables +export ZE_AFFINITY_MASK="0,1" +export CCL_ATL_SHM=1 +export CCL_BUFFER_CACHE=0 +export CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 +export CCL_TOPO_ALGO=0 +export RAY_NUM_PRESTART_PYTHON_WORKERS=0 + +# Run with torchrun (2 XPU GPUs) +torchrun \ + --nnodes=1 \ + --nproc_per_node=2 \ + --master-port=4321 \ + tests/special_xpu/test_fsdp2_simple_xpu.py \ + --model Qwen/Qwen2.5-0.5B-Instruct \ + --batch-size 2 \ + --seq-len 128 \ + --steps 5 + +echo "" +echo "==========================================" +echo "VeOmni FSDP2 XPU test completed!" +echo "==========================================" diff --git a/tests/special_xpu/run_veomni_e2e_sft_xpu.sh b/tests/special_xpu/run_veomni_e2e_sft_xpu.sh new file mode 100755 index 000000000..04f637f6a --- /dev/null +++ b/tests/special_xpu/run_veomni_e2e_sft_xpu.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# VeOmni trainer e2e SFT smoke test on Intel XPU (2 GPUs) +# Note: This path currently runs fsdp1 from config. FSDP2 coverage is in +# tests/special_xpu/test_fsdp2_simple_xpu.py + +set -e + +# Get the VeOmni root directory +VEOMNI_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" +cd "$VEOMNI_ROOT" + +# XPU environment variables +export ZE_AFFINITY_MASK="0,1" +export CCL_ATL_SHM=1 +export CCL_BUFFER_CACHE=0 +export CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 +export CCL_TOPO_ALGO=0 +export RAY_NUM_PRESTART_PYTHON_WORKERS=0 + +# Model and data setup +MODEL_PATH="${CI_HF_MODELS_DIR:-.}/Qwen/Qwen2.5-0.5B-Instruct" +DATASET_DIR="${CI_DATASET_DIR:-.}" + +echo "==========================================" +echo "VeOmni e2e SFT Test on Intel XPU" +echo "==========================================" +echo "Model: Qwen2.5-0.5B-Instruct" +echo "GPUs: 2 (XPU devices 0,1 via ZE_AFFINITY_MASK)" +echo "Config: qwen2_5_xpu.yaml (sdpa attention, fsdp1 trainer smoke)" +echo "" + +# Run torchrun with 2 XPU GPUs +torchrun \ + --nnodes=1 \ + --nproc_per_node=2 \ + --master-port=4321 \ + tasks/train_text.py \ + configs/text/qwen2_5_xpu.yaml \ + --model.model_path "$MODEL_PATH" \ + --data.train_path "$DATASET_DIR/fineweb" \ + --train.checkpoint.output_dir "Qwen2.5-0.5B-Instruct-sft-xpu" \ + --train.accelerator.fsdp_config.fsdp_mode fsdp1 \ + --train.num_train_epochs 1 \ + --train.max_steps 5 \ + --train.wandb.enable false + +echo "" +echo "==========================================" +echo "VeOmni XPU e2e test PASSED!" +echo "==========================================" diff --git a/tests/special_xpu/test_fsdp2_simple_xpu.py b/tests/special_xpu/test_fsdp2_simple_xpu.py new file mode 100644 index 000000000..54bcbc479 --- /dev/null +++ b/tests/special_xpu/test_fsdp2_simple_xpu.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +"""Simple FSDP2 training smoke test on Intel XPU. + +This test intentionally uses FSDP2's ``fully_shard`` API (not FSDPv1's +FullyShardedDataParallel class) so it validates the real FSDP2/XCCL path. +""" + +import argparse +import os + +import torch +import torch.distributed as dist +import torch.optim as optim +from transformers import AutoModelForCausalLM + + +try: + from torch.distributed.fsdp import FSDPModule, fully_shard +except ImportError: + from torch.distributed.fsdp._fully_shard import FSDPModule, fully_shard + + +def setup_distributed() -> tuple[int, int, int]: + dist.init_process_group("xccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = int(os.environ.get("LOCAL_RANK", rank)) + return rank, world_size, local_rank + + +def create_model(model_name: str, device: torch.device) -> torch.nn.Module: + print(f"[Rank {dist.get_rank()}] Loading model: {model_name}") + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + attn_implementation="sdpa", + trust_remote_code=True, + ) + + # FSDP2 path: shard layers then root module with fully_shard. + model = model.to(device) + for layer in model.model.layers: + fully_shard(layer) + fully_shard(model) + + # Ensure we are actually testing FSDP2 wrappers. 
+ fsdp2_count = sum(1 for m in model.modules() if isinstance(m, FSDPModule)) + if fsdp2_count == 0: + raise RuntimeError("FSDP2 wrapping failed: no FSDPModule instances found") + if dist.get_rank() == 0: + print(f"[Rank 0] FSDP2 modules wrapped: {fsdp2_count}") + + return model + + +def create_dummy_batch(batch_size: int = 2, seq_len: int = 128) -> dict[str, torch.Tensor]: + input_ids = torch.randint(0, 32000, (batch_size, seq_len), dtype=torch.long) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + labels = input_ids.clone() + labels[:, :-1] = labels[:, 1:].clone() + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +def main(args: argparse.Namespace) -> None: + rank, world_size, local_rank = setup_distributed() + + if not torch.xpu.is_available(): + raise RuntimeError("XPU is not available") + torch.xpu.set_device(local_rank) + device = torch.device(f"xpu:{local_rank}") + + print(f"[Rank {rank}/{world_size}] Device: {device}") + + model = create_model(args.model, device) + optimizer = optim.AdamW(model.parameters(), lr=1e-4) + + print(f"[Rank {rank}] Starting {args.steps} training steps...") + for step in range(1, args.steps + 1): + batch = create_dummy_batch(batch_size=args.batch_size, seq_len=args.seq_len) + batch = {k: v.to(device) for k, v in batch.items()} + + optimizer.zero_grad() + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + + if rank == 0 and step % max(1, args.steps // 5) == 0: + print(f"Step {step}/{args.steps}: loss={loss.item():.4f}") + + dist.destroy_process_group() + print(f"[Rank {rank}] Training complete!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-0.5B-Instruct") + parser.add_argument("--batch-size", type=int, default=2) + parser.add_argument("--seq-len", type=int, default=128) + parser.add_argument("--steps", type=int, default=5) + main(parser.parse_args()) diff --git a/veomni/distributed/moe/moe_layer.py b/veomni/distributed/moe/moe_layer.py index 9bf514557..015aab737 100644 --- a/veomni/distributed/moe/moe_layer.py +++ b/veomni/distributed/moe/moe_layer.py @@ -151,6 +151,13 @@ def forward( ): # permute_tokens: [tokens, hidden_dim] # cumsum: [local_experts] + + if group_gemm_same_nk is None: + raise RuntimeError( + "MoE fused group_gemm kernels are not available on this device. " + "This typically means you're using Intel XPU or another non-CUDA device. " + "Please use moe_implementation='eager' instead of 'fused'." + ) # compute linear layer fc1-1 fc1_1_output = group_gemm_same_nk( @@ -322,6 +329,14 @@ def forward( ): # permute_tokens: [tokens, hidden_dim] # cumsum: [local_experts] + + if group_gemm_same_nk is None: + raise RuntimeError( + "MoE fused group_gemm kernels are not available on this device. " + "This typically means you're using Intel XPU or another non-CUDA device. " + "Please use moe_implementation='eager' instead of 'fused'." 
+ ) + assert fc1_1_2_weight.shape[1] % 2 == 0, ( f"Merged fc1_1_2_weight dim 1 must be even, got {fc1_1_2_weight.shape[1]}" ) diff --git a/veomni/ops/kernels/moe/_kernels/utils/device.py b/veomni/ops/kernels/moe/_kernels/utils/device.py index dbc813fa4..6176e0537 100644 --- a/veomni/ops/kernels/moe/_kernels/utils/device.py +++ b/veomni/ops/kernels/moe/_kernels/utils/device.py @@ -29,7 +29,8 @@ def get_device_key() -> str: return "H100" name = get_device_name() - if name.startswith("NVIDIA "): + if name and name.startswith("NVIDIA "): name = name[len("NVIDIA ") :] - return name + # Fallback for non-NVIDIA devices (e.g., Intel XPU, AMD) + return name if name else "unknown" From 7a733975fe29c950e2f7b1329a7f619edf282635 Mon Sep 17 00:00:00 2001 From: kahlun Date: Thu, 30 Apr 2026 00:52:26 -0700 Subject: [PATCH 3/4] tests(xpu): remove standalone fsdp2 smoke; generalize moe guard --- tests/special_xpu/run_fsdp2_simple_xpu.sh | 35 ------- tests/special_xpu/run_veomni_e2e_sft_xpu.sh | 2 - tests/special_xpu/test_fsdp2_simple_xpu.py | 105 -------------------- veomni/distributed/moe/moe_layer.py | 6 +- 4 files changed, 2 insertions(+), 146 deletions(-) delete mode 100755 tests/special_xpu/run_fsdp2_simple_xpu.sh delete mode 100644 tests/special_xpu/test_fsdp2_simple_xpu.py diff --git a/tests/special_xpu/run_fsdp2_simple_xpu.sh b/tests/special_xpu/run_fsdp2_simple_xpu.sh deleted file mode 100755 index ebeb735ae..000000000 --- a/tests/special_xpu/run_fsdp2_simple_xpu.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Simple FSDP2 XPU test - standalone, no trainer dependencies - -set -e - -echo "==========================================" -echo "VeOmni FSDP2 Standalone Test on Intel XPU" -echo "==========================================" -echo "GPUs: 2 (XPU devices 0,1)" -echo "Model: Qwen2.5-0.5B-Instruct (sdpa attention)" -echo "" - -# XPU environment variables -export ZE_AFFINITY_MASK="0,1" -export CCL_ATL_SHM=1 -export CCL_BUFFER_CACHE=0 -export CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 -export CCL_TOPO_ALGO=0 -export RAY_NUM_PRESTART_PYTHON_WORKERS=0 - -# Run with torchrun (2 XPU GPUs) -torchrun \ - --nnodes=1 \ - --nproc_per_node=2 \ - --master-port=4321 \ - tests/special_xpu/test_fsdp2_simple_xpu.py \ - --model Qwen/Qwen2.5-0.5B-Instruct \ - --batch-size 2 \ - --seq-len 128 \ - --steps 5 - -echo "" -echo "==========================================" -echo "VeOmni FSDP2 XPU test completed!" -echo "==========================================" diff --git a/tests/special_xpu/run_veomni_e2e_sft_xpu.sh b/tests/special_xpu/run_veomni_e2e_sft_xpu.sh index 04f637f6a..3ecac6d03 100755 --- a/tests/special_xpu/run_veomni_e2e_sft_xpu.sh +++ b/tests/special_xpu/run_veomni_e2e_sft_xpu.sh @@ -1,7 +1,5 @@ #!/bin/bash # VeOmni trainer e2e SFT smoke test on Intel XPU (2 GPUs) -# Note: This path currently runs fsdp1 from config. FSDP2 coverage is in -# tests/special_xpu/test_fsdp2_simple_xpu.py set -e diff --git a/tests/special_xpu/test_fsdp2_simple_xpu.py b/tests/special_xpu/test_fsdp2_simple_xpu.py deleted file mode 100644 index 54bcbc479..000000000 --- a/tests/special_xpu/test_fsdp2_simple_xpu.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python3 -"""Simple FSDP2 training smoke test on Intel XPU. - -This test intentionally uses FSDP2's ``fully_shard`` API (not FSDPv1's -FullyShardedDataParallel class) so it validates the real FSDP2/XCCL path. 
-""" - -import argparse -import os - -import torch -import torch.distributed as dist -import torch.optim as optim -from transformers import AutoModelForCausalLM - - -try: - from torch.distributed.fsdp import FSDPModule, fully_shard -except ImportError: - from torch.distributed.fsdp._fully_shard import FSDPModule, fully_shard - - -def setup_distributed() -> tuple[int, int, int]: - dist.init_process_group("xccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - local_rank = int(os.environ.get("LOCAL_RANK", rank)) - return rank, world_size, local_rank - - -def create_model(model_name: str, device: torch.device) -> torch.nn.Module: - print(f"[Rank {dist.get_rank()}] Loading model: {model_name}") - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.bfloat16, - attn_implementation="sdpa", - trust_remote_code=True, - ) - - # FSDP2 path: shard layers then root module with fully_shard. - model = model.to(device) - for layer in model.model.layers: - fully_shard(layer) - fully_shard(model) - - # Ensure we are actually testing FSDP2 wrappers. - fsdp2_count = sum(1 for m in model.modules() if isinstance(m, FSDPModule)) - if fsdp2_count == 0: - raise RuntimeError("FSDP2 wrapping failed: no FSDPModule instances found") - if dist.get_rank() == 0: - print(f"[Rank 0] FSDP2 modules wrapped: {fsdp2_count}") - - return model - - -def create_dummy_batch(batch_size: int = 2, seq_len: int = 128) -> dict[str, torch.Tensor]: - input_ids = torch.randint(0, 32000, (batch_size, seq_len), dtype=torch.long) - attention_mask = torch.ones_like(input_ids, dtype=torch.long) - labels = input_ids.clone() - labels[:, :-1] = labels[:, 1:].clone() - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels, - } - - -def main(args: argparse.Namespace) -> None: - rank, world_size, local_rank = setup_distributed() - - if not torch.xpu.is_available(): - raise RuntimeError("XPU is not available") - torch.xpu.set_device(local_rank) - device = torch.device(f"xpu:{local_rank}") - - print(f"[Rank {rank}/{world_size}] Device: {device}") - - model = create_model(args.model, device) - optimizer = optim.AdamW(model.parameters(), lr=1e-4) - - print(f"[Rank {rank}] Starting {args.steps} training steps...") - for step in range(1, args.steps + 1): - batch = create_dummy_batch(batch_size=args.batch_size, seq_len=args.seq_len) - batch = {k: v.to(device) for k, v in batch.items()} - - optimizer.zero_grad() - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - - if rank == 0 and step % max(1, args.steps // 5) == 0: - print(f"Step {step}/{args.steps}: loss={loss.item():.4f}") - - dist.destroy_process_group() - print(f"[Rank {rank}] Training complete!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-0.5B-Instruct") - parser.add_argument("--batch-size", type=int, default=2) - parser.add_argument("--seq-len", type=int, default=128) - parser.add_argument("--steps", type=int, default=5) - main(parser.parse_args()) diff --git a/veomni/distributed/moe/moe_layer.py b/veomni/distributed/moe/moe_layer.py index 015aab737..0741323f0 100644 --- a/veomni/distributed/moe/moe_layer.py +++ b/veomni/distributed/moe/moe_layer.py @@ -154,8 +154,7 @@ def forward( if group_gemm_same_nk is None: raise RuntimeError( - "MoE fused group_gemm kernels are not available on this device. " - "This typically means you're using Intel XPU or another non-CUDA device. 
" + "MoE fused group_gemm kernels are not available for the current device/backend. " "Please use moe_implementation='eager' instead of 'fused'." ) @@ -332,8 +331,7 @@ def forward( if group_gemm_same_nk is None: raise RuntimeError( - "MoE fused group_gemm kernels are not available on this device. " - "This typically means you're using Intel XPU or another non-CUDA device. " + "MoE fused group_gemm kernels are not available for the current device/backend. " "Please use moe_implementation='eager' instead of 'fused'." ) From 4fde11dde710587d02399c68fbb64ce5bfdd2f20 Mon Sep 17 00:00:00 2001 From: kahlun Date: Thu, 30 Apr 2026 01:02:03 -0700 Subject: [PATCH 4/4] configs(xpu): move qwen2_5_xpu yaml into dedicated xpu folder --- configs/{ => xpu}/text/qwen2_5_xpu.yaml | 0 tests/special_xpu/run_veomni_e2e_sft_xpu.sh | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename configs/{ => xpu}/text/qwen2_5_xpu.yaml (100%) diff --git a/configs/text/qwen2_5_xpu.yaml b/configs/xpu/text/qwen2_5_xpu.yaml similarity index 100% rename from configs/text/qwen2_5_xpu.yaml rename to configs/xpu/text/qwen2_5_xpu.yaml diff --git a/tests/special_xpu/run_veomni_e2e_sft_xpu.sh b/tests/special_xpu/run_veomni_e2e_sft_xpu.sh index 3ecac6d03..0f40b3d3b 100755 --- a/tests/special_xpu/run_veomni_e2e_sft_xpu.sh +++ b/tests/special_xpu/run_veomni_e2e_sft_xpu.sh @@ -24,7 +24,7 @@ echo "VeOmni e2e SFT Test on Intel XPU" echo "==========================================" echo "Model: Qwen2.5-0.5B-Instruct" echo "GPUs: 2 (XPU devices 0,1 via ZE_AFFINITY_MASK)" -echo "Config: qwen2_5_xpu.yaml (sdpa attention, fsdp1 trainer smoke)" +echo "Config: configs/xpu/text/qwen2_5_xpu.yaml (sdpa attention, fsdp1 trainer smoke)" echo "" # Run torchrun with 2 XPU GPUs @@ -33,7 +33,7 @@ torchrun \ --nproc_per_node=2 \ --master-port=4321 \ tasks/train_text.py \ - configs/text/qwen2_5_xpu.yaml \ + configs/xpu/text/qwen2_5_xpu.yaml \ --model.model_path "$MODEL_PATH" \ --data.train_path "$DATASET_DIR/fineweb" \ --train.checkpoint.output_dir "Qwen2.5-0.5B-Instruct-sft-xpu" \