Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions .ci/scripts/export_model_artifact.sh
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ case "$HF_MODEL" in
PREPROCESSOR_FEATURE_SIZE=""
PREPROCESSOR_OUTPUT=""
;;
SocialLocalMobile/gemma-4-31B-it-HQQ-INT4)
unsloth/gemma-4-31B-it-GGUF)
MODEL_NAME="gemma4_31b"
TASK=""
MAX_SEQ_LEN=""
Expand All @@ -205,7 +205,7 @@ case "$HF_MODEL" in
;;
*)
echo "Error: Unsupported model '$HF_MODEL'"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, unsloth/gemma-4-31B-it-GGUF"
exit 1
;;
esac
Expand Down Expand Up @@ -467,21 +467,27 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
exit 0
fi

# Gemma 4 31B uses a prequantized checkpoint and custom export script
# Gemma 4 31B: download the Q4_K_M GGUF and export via the GGUF loader
if [ "$MODEL_NAME" = "gemma4_31b" ]; then
pip install safetensors huggingface_hub gguf

# Download prequantized model outside OUTPUT_DIR to avoid uploading on failure
# Download GGUF + tokenizer outside OUTPUT_DIR to avoid uploading on failure.
# The unsloth GGUF repo ships the .gguf but no tokenizer.json, so the tokenizer
# is fetched from the (non-GGUF) unsloth/gemma-4-31B-it repo.
LOCAL_MODEL_DIR=$(mktemp -d)
INDUCTOR_CACHE=$(mktemp -d)
trap 'rm -rf "$LOCAL_MODEL_DIR" "$INDUCTOR_CACHE"' EXIT

python -c "from huggingface_hub import snapshot_download; snapshot_download('${HF_MODEL}', local_dir='${LOCAL_MODEL_DIR}')"
GGUF_FILE="gemma-4-31B-it-Q4_K_M.gguf"
python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it-GGUF', '${GGUF_FILE}', local_dir='${LOCAL_MODEL_DIR}')"
python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it', 'tokenizer.json', local_dir='${LOCAL_MODEL_DIR}')"
GGUF_PATH="${LOCAL_MODEL_DIR}/${GGUF_FILE}"

# Sanity check: run inference on the prequantized model
# Sanity check: run inference on the GGUF model
echo "::group::Inference sanity check"
INFERENCE_OUTPUT=$(python -m executorch.examples.models.gemma4_31b.inference \
--prequantized "$LOCAL_MODEL_DIR" \
--gguf "$GGUF_PATH" \
--tokenizer-path "${LOCAL_MODEL_DIR}/tokenizer.json" \
--prompt "What is the capital of France?" \
--max-new-tokens 32 \
--temperature 0 \
Expand All @@ -494,13 +500,13 @@ if [ "$MODEL_NAME" = "gemma4_31b" ]; then
echo "::endgroup::"

# Copy tokenizer for the runner
cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json"
cp "${LOCAL_MODEL_DIR}/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json"

# Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues)
echo "::group::Export"
TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
python -m executorch.examples.models.gemma4_31b.export \
--prequantized "$LOCAL_MODEL_DIR" \
--gguf "$GGUF_PATH" \
--output-dir "${OUTPUT_DIR}"
echo "::endgroup::"

Expand Down
4 changes: 2 additions & 2 deletions .ci/scripts/test_model_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ case "$HF_MODEL" in
AUDIO_FILE=""
IMAGE_PATH=""
;;
SocialLocalMobile/gemma-4-31B-it-HQQ-INT4)
unsloth/gemma-4-31B-it-GGUF)
MODEL_NAME="gemma4_31b"
RUNNER_TARGET="gemma4_31b_runner"
RUNNER_PATH="gemma4_31b"
Expand All @@ -242,7 +242,7 @@ case "$HF_MODEL" in
;;
*)
echo "Error: Unsupported model '$HF_MODEL'"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4"
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, unsloth/gemma-4-31B-it-GGUF"
exit 1
;;
esac
Expand Down
28 changes: 14 additions & 14 deletions .github/workflows/cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ jobs:
name: "dinov2-small-imagenet1k-1-layer"
- repo: "SocialLocalMobile"
name: "Qwen3.5-35B-A3B-HQQ-INT4"
- repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
- repo: "unsloth"
name: "gemma-4-31B-it-GGUF"
quant:
- "non-quantized"
- "quantized-int4-tile-packed"
Expand All @@ -281,12 +281,12 @@ jobs:
quant: "quantized-int4-weight-only"
# Gemma 4 31B uses a prequantized Q4_K_M GGUF checkpoint, only tile-packed
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
repo: "unsloth"
name: "gemma-4-31B-it-GGUF"
quant: "non-quantized"
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
repo: "unsloth"
name: "gemma-4-31B-it-GGUF"
quant: "quantized-int4-weight-only"
# Voxtral Realtime only supports int4-tile-packed on CUDA
- model:
Expand Down Expand Up @@ -342,7 +342,7 @@ jobs:
with:
timeout: 150
secrets-env: EXECUTORCH_HF_TOKEN
runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }}
runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-GGUF') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }}
gpu-arch-type: cuda
gpu-arch-version: "13.0"
use-custom-docker-registry: false
Expand Down Expand Up @@ -424,8 +424,8 @@ jobs:
name: "dinov2-small-imagenet1k-1-layer"
- repo: "SocialLocalMobile"
name: "Qwen3.5-35B-A3B-HQQ-INT4"
- repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
- repo: "unsloth"
name: "gemma-4-31B-it-GGUF"
quant:
- "non-quantized"
- "quantized-int4-tile-packed"
Expand All @@ -447,12 +447,12 @@ jobs:
quant: "quantized-int4-weight-only"
# Gemma 4 31B uses a prequantized Q4_K_M GGUF checkpoint, only tile-packed
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
repo: "unsloth"
name: "gemma-4-31B-it-GGUF"
quant: "non-quantized"
- model:
repo: "SocialLocalMobile"
name: "gemma-4-31B-it-HQQ-INT4"
repo: "unsloth"
name: "gemma-4-31B-it-GGUF"
quant: "quantized-int4-weight-only"
# Voxtral Realtime only supports int4-tile-packed on CUDA
- model:
Expand Down Expand Up @@ -502,7 +502,7 @@ jobs:
quant: "non-quantized"
with:
timeout: 90
runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }}
runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-GGUF') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }}
gpu-arch-type: cuda
gpu-arch-version: "13.0"
use-custom-docker-registry: false
Expand Down
1 change: 1 addition & 0 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ if(CMAKE_CUDA_COMPILER)
_aoti_cuda_shim_sources
runtime/shims/int4mm.cu
runtime/shims/int4_plain_mm.cu
runtime/shims/int6_plain_mm.cu
runtime/shims/int8_plain_mm.cu
runtime/shims/sort.cu
runtime/shims/rand.cu
Expand Down
7 changes: 7 additions & 0 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
"aoti_torch_cuda_randint_low_out": None,
"executorch_cuda::int4_plain_mm": None,
"aoti_torch_cuda_int4_plain_mm": None,
"executorch_cuda::int6_plain_mm": None,
"aoti_torch_cuda_int6_plain_mm": None,
"executorch_cuda::int8_plain_mm": None,
"aoti_torch_cuda_int8_plain_mm": None,
}
Expand Down Expand Up @@ -314,6 +316,11 @@ def get_aoti_compile_options(
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
"AtenTensorHandle, int64_t, AtenTensorHandle*)"
],
torch.ops.executorch_cuda.int6_plain_mm.default: [
"AOTITorchError aoti_torch_cuda_int6_plain_mm("
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
"AtenTensorHandle, int64_t, AtenTensorHandle*)"
],
torch.ops.executorch_cuda.int8_plain_mm.default: [
"AOTITorchError aoti_torch_cuda_int8_plain_mm("
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
Expand Down
Loading
Loading