diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 583be7c620..f0230a8ecd 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -3,3 +3,4 @@ ciflow_push_tags: - ciflow/benchmark - ciflow/tutorials - ciflow/rocm +- ciflow/4xh100 diff --git a/.github/scripts/torchao_model_releases/README.md b/.github/scripts/torchao_model_releases/README.md new file mode 100644 index 0000000000..f4609fc7ee --- /dev/null +++ b/.github/scripts/torchao_model_releases/README.md @@ -0,0 +1,142 @@ +# Scripts for torchao Model Release and Eval + +Note: all commands below should be run in directory: `.github/scripts/torchao_model_releases/` + +## Frequently Used Commands +### Release and Eval Scripts for New Model Releases +``` +MODEL=Qwen/Qwen3-8B +# Releasing all models: INT4, INT8, INT8-INT4 +sh release.sh --model_id $MODEL --push_to_hub --populate_model_card_template + +# INT8-INT4 requires additional steps to export and run so it's skipped from +# general eval here +# Need to set QMODEL_PREFIX properly before running eval +# QMODEL_PREFIX=pytorch/Qwen3-8B +sh eval.sh --model_ids $MODEL "$QMODEL_PREFIX-FP8" "$QMODEL_PREFIX-INT4" + +# Some follow up evals +sh eval.sh --eval_type latency --batch_size 256 "$QMODEL_PREFIX-FP8" +sh eval.sh --eval_type quality --batch_size 256 "$QMODEL_PREFIX-INT8-INT4" + +# Summarize all results +sh summarize_results.sh --model_ids $MODEL "$QMODEL_PREFIX-FP8" "$QMODEL_PREFIX-INT4" "$QMODEL_PREFIX-INT8-INT4" "$QMODEL_PREFIX-AWQ-INT4" +``` + +### AWQ Release and Eval +``` +MODEL=Qwen/Qwen3-8B +TASK=mmlu_abstract_algebra +python quantize_and_upload.py --model_id $MODEL --quant AWQ-INT4 --push_to_hub --task $TASK --calibration_limit 10 --populate_model_card_template +sh eval.sh --model_ids $MODEL "$QMODEL_PREFIX-AWQ-INT4" +``` + +### Update Released Checkpoints in PyTorch +Sometimes we may have to update the checkpoints under a different user name (organization) without changing the model card, e.g. for INT4 +``` +MODEL=Qwen/Qwen3-8B +sh release.sh --model $MODEL --quants INT4 --push_to_hub --push_to_user_id pytorch +``` + +Or AWQ checkpoint: +``` +MODEL=Qwen/Qwen3-8B +TASK=mmlu_abstract_algebra +python quantize_and_upload.py --model_id $MODEL --quant AWQ-INT4--task $TASK --calibration_limit 10 --push_to_hub --push_to_user_id pytorch +``` + +## Release Scripts +### default options +By default, we release FP8, INT4, INT8-INT4 checkpoints, with model card pre-filled with template content, that can be modified later after we have eval results. + +Examples: +``` +# Note: first login with `huggingface-cli login`, the quantized model will be uploaded to +# the logged in user + +# release with default quant options (FP8, INT4, INT8-INT4) +./release.sh --model_id Qwen/Qwen3-8B --push_to_hub + +# release a custom set of quant options +./release.sh --model_id Qwen/Qwen3-8B --quants INT4 FP8 --push_to_hub +``` + +Note: for initial release, please include `--populate_model_card_template` to populate model card template. + +### AWQ-INT4 +[AWQ](https://arxiv.org/abs/2306.00978) is a technique to improve accuracy for weight only quantization. It improves accuracy by preserving "salient" weight channels that has high impact on the accuracy of output, through multiplying the weight channel by a scale, and do the reverse for the correspnoding activation, since activation is not quantized, there is no additional loss from activation, while the quantization loss from weight can be reduced. + +After eval for INT4 checkpoint is done, we might find some task have a large accuracy drop compared to high precision baseline, in that case we can do a calibration for that task, with a few samples, tasks are selected from [lm-eval](https://github.com/EleutherAI/lm-eval\uation-harness/blob/main/lm_eval/tasks/README.md). You can follow [new task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md) to add new tasks to lm-eval. + +Examples: +``` +# release AWQ-INT4 model, calibrated with a specific task +# with some calibration_limit (number of samples) +python quantize_and_upload.py --model_id Qwen/Qwen3-8B --quant AWQ-INT4 --push_to_hub --task bbh --calibration_limit 2 +``` + +### Update checkpoints for a different user_id (e.g. pytorch) +Sometimes we may want to update the checkpoints for a different user id, without changing model card. For this we can use `--push_to_user_id`, e.g. + +``` +sh release.sh --model_id microsoft/Phi-4-mini-instruct --quants FP8 --push_to_hub --push_to_user_id pytorch +``` + +This will update `pytorch/Phi-4-mini-instruct-FP8` without changing the model card. + +## Eval Scripts +After we run the release script for a model, we can find new models in the huggingface hub page for the user, e.g. https://huggingface.co/torchao-testing, the models will have a model card that's filled in with template content, such as information about the model and eval instructions, there are a few things we need to fill in, including 1. peak memory usage, 2. latency when running model with vllm and 3. quality measurement using lm-eval. + +### Single Script +The simplest is just to run all three evals. Please check out `Run Single Evals` section to make sure the environment is setup correctly. This includes: +1. install [vllm](https://github.com/vllm-project/vllm) from source and set `VLLM_DIR` to the soruce directory of vllm +2. install [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) + +``` +sh eval.sh --eval_type all --model_ids Qwen/Qwen3-8B pytorch/Qwen3-8B-INT4 +``` + +If `eval_type` is all, we'll also run summarize results for the list of `model_ids`, summarized results will be found in files: `summary_results_Qwen_Qwen3-8B.log` and `summary_results_pytorch_Qwen3-8B-INT4.log`. + +Then we can fill in the blanks in the model cards of uploaded checkpoints. + +### Separate Scripts +#### Memory Eval +``` +sh eval.sh --eval_type memory --model_ids Qwen/Qwen3-8B +``` + +#### Latency Eval +For latency eval, make sure vllm is installed. +``` +uv pip install vllm +``` + +Or install vllm nightly: +``` +uv pip install vllm --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 +``` + +After environment is setup, we can run eval: +``` +sh eval.sh --eval_type latency --model_ids Qwen/Qwen3-8B --batch_sizes 1,256 +``` + +#### Model Quality Eval +For model quality eval, we need to install lm-eval +``` +uv pip install lm-eval +``` +After environment is setup, we can run eval: +``` +sh eval.sh --eval_type quality --model_ids Qwen/Qwen3-8B --tasks hellaswag,mmlu +``` + +#### Summarize results +After we have finished all evals for each model, we can summarize the results with: +``` +sh summarize_results.sh --model_ids Qwen/Qwen3-8B pytorch/Qwen3-8B-INT4 +``` +Summarized results files for above command: `summary_results_Qwen_Qwen3-8B.log` and `summary_results_pytorch_Qwen3-8B-INT4.log` + +It will look through the current directory to find all the result files from memory, latency and quality evals and combine all the result information into a single file. diff --git a/.github/scripts/torchao_model_releases/eval.sh b/.github/scripts/torchao_model_releases/eval.sh new file mode 100644 index 0000000000..f284b2a0c3 --- /dev/null +++ b/.github/scripts/torchao_model_releases/eval.sh @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +#!/bin/bash +set -e +source eval_env_checks.sh + +usage() { + echo "Usage: $0 --model_ids ... [--eval_type ] [--batch_sizes ] [--tasks ]" + echo "Defaults:" + echo " batch_sizes: 1 256" + echo " tasks: mmlu" + exit 1 +} +MODEL_ID_ARRAY=() +EVAL_TYPE="all" +# these will be parsed in the other scripts +BATCH_SIZES="1 256" # Default for latency eval +TASKS="mmlu" # Default for quality eval +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --eval_type) + shift + if [[ $# -eq 0 ]]; then + echo "Error: --eval_type requires a value" + exit 1 + fi + EVAL_TYPE="$1" + shift + ;; + --model_ids) + shift + # Collect all subsequent arguments that are not another flag + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + MODEL_ID_ARRAY+=("$1") + shift + done + ;; + --batch_sizes) + shift + if [[ $# -eq 0 ]]; then + echo "Error: --batch_sizes requires a value" + exit 1 + fi + BATCH_SIZES="$1" + shift + ;; + --tasks) + shift + if [[ $# -eq 0 ]]; then + echo "Error: --tasks requires a value" + exit 1 + fi + TASKS="$1" + shift + ;; + *) + echo "Unknown argument: $1" + usage + ;; + esac +done +if [[ ${#MODEL_ID_ARRAY[@]} -eq 0 ]]; then + echo "Error: --model_ids is required" + usage +fi + +run_memory() { + check_torch + local model_id="$1" + sh eval_memory.sh --model_ids "$model_id" +} +run_latency() { + check_vllm + local model_id="$1" + sh eval_latency.sh --model_ids "$model_id" --batch_sizes $BATCH_SIZES +} +run_quality() { + check_lm_eval + local model_id="$1" + sh eval_quality.sh --model_ids "$model_id" --tasks $TASKS +} +for MODEL_ID in "${MODEL_ID_ARRAY[@]}"; do + case "$EVAL_TYPE" in + memory) + run_memory "$MODEL_ID" + ;; + latency) + run_latency "$MODEL_ID" + ;; + quality) + run_quality "$MODEL_ID" + ;; + all) + run_quality "$MODEL_ID" + run_memory "$MODEL_ID" + run_latency "$MODEL_ID" + ;; + *) + echo "Unknown eval_type: $EVAL_TYPE" + echo "Valid types are: all, memory, latency, quality" + exit 2 + ;; + esac +done + +# Run summarize_results.sh with MODEL_IDS if eval_type is "all" +if [[ "$EVAL_TYPE" == "all" ]]; then + sh summarize_results.sh --model_ids "${MODEL_ID_ARRAY[@]}" +fi diff --git a/.github/scripts/torchao_model_releases/eval_env_checks.sh b/.github/scripts/torchao_model_releases/eval_env_checks.sh new file mode 100644 index 0000000000..d6eb9c8801 --- /dev/null +++ b/.github/scripts/torchao_model_releases/eval_env_checks.sh @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +check_torch() { + if ! pip show torch > /dev/null 2>&1; then + echo "Error: torch package is NOT installed. please install with `pip install torch`" >&2 + exit 1 + fi +} + +check_vllm() { + if ! pip show vllm > /dev/null 2>&1; then + echo "Error: vllm package is NOT installed. please install with `pip install vllm`" >&2 + exit 1 + fi +} + +check_lm_eval() { + if ! pip show lm_eval > /dev/null 2>&1; then + echo "Error: lm_eval package is NOT installed. please install with `pip install lm_eval`" >&2 + exit 1 + fi +} diff --git a/.github/scripts/torchao_model_releases/eval_latency.sh b/.github/scripts/torchao_model_releases/eval_latency.sh new file mode 100644 index 0000000000..cc987d8d45 --- /dev/null +++ b/.github/scripts/torchao_model_releases/eval_latency.sh @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +#!/bin/bash +set -e +source eval_env_checks.sh +check_vllm + +MODEL_ID_ARRAY=() +BATCH_SIZE_ARRAY=(1) # default can be overwritten by user input +INPUT_LEN="256" # default input length +OUTPUT_LEN="256" # default output length +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --model_ids) + shift + # Collect all subsequent arguments that are not another flag + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + MODEL_ID_ARRAY+=("$1") + shift + done + ;; + --batch_sizes) + shift + BATCH_SIZE_ARRAY=() + # Collect all subsequent arguments that are not another flag + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + BATCH_SIZE_ARRAY+=("$1") + shift + done + ;; + --input_len) + shift + if [[ $# -eq 0 ]]; then + echo "Error: --input_len requires a value" + exit 1 + fi + INPUT_LEN="$1" + shift + ;; + --output_len) + shift + if [[ $# -eq 0 ]]; then + echo "Error: --output_len requires a value" + exit 1 + fi + OUTPUT_LEN="$1" + shift + ;; + *) + echo "Unknown argument: $1" + echo "Usage: $0 --model_id [--batch_sizes ] [--input_len ] [--output_len ]" + exit 1 + ;; + esac +done +if [[ ${#MODEL_ID_ARRAY[@]} -eq 0 ]]; then + echo "Error: --model_ids is required" + echo "Usage: $0 --model_ids ... [--batch_sizes ...] [--input_len ] [--output_len ]" + exit 1 +fi +# Save the original directory +ORIG_DIR="$(pwd)" +# cd to VLLM_DIR +cd $VLLM_DIR +for MODEL_ID in "${MODEL_ID_ARRAY[@]}"; do + echo "======================== Eval Latency $MODEL_ID ===========================" + # Replace all '/' with '_' + SAFE_MODEL_ID="${MODEL_ID//\//_}" + # Loop over batch sizes and print (replace with your eval command) + for BATCH_SIZE in "${BATCH_SIZE_ARRAY[@]}"; do + OUTPUT_FILE="$ORIG_DIR/${SAFE_MODEL_ID}_latency_batch${BATCH_SIZE}_in${INPUT_LEN}_out${OUTPUT_LEN}.log" + echo "Running latency eval for model $MODEL_ID with batch size $BATCH_SIZE with input length: $INPUT_LEN and output length: $OUTPUT_LEN" + VLLM_DISABLE_COMPILE_CACHE=1 vllm bench latency --input-len $INPUT_LEN --output-len $OUTPUT_LEN --model $MODEL_ID --batch-size $BATCH_SIZE > "$OUTPUT_FILE" 2>&1 + echo "Latency eval result saved to $OUTPUT_FILE" + done + echo "======================== Eval Latency $MODEL_ID End =========================" +done + +# cd back to original place +cd $ORIG_DIR diff --git a/.github/scripts/torchao_model_releases/eval_memory.sh b/.github/scripts/torchao_model_releases/eval_memory.sh new file mode 100644 index 0000000000..f181c492f6 --- /dev/null +++ b/.github/scripts/torchao_model_releases/eval_memory.sh @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +#!/bin/bash +set -e +source eval_env_checks.sh +check_torch +MODEL_ID_ARRAY=() +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --model_ids) + shift + # Collect all subsequent arguments that are not another flag + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + MODEL_ID_ARRAY+=("$1") + shift + done + ;; + *) + echo "Unknown argument: $1" + echo "Usage: $0 --model_ids ..." + exit 1 + ;; + esac +done +if [[ ${#MODEL_ID_ARRAY[@]} -eq 0 ]]; then + echo "Usage: $0 --model_ids ..." + exit 1 +fi +for MODEL_ID in "${MODEL_ID_ARRAY[@]}"; do + # Replace all '/' with '_' + SAFE_MODEL_ID="${MODEL_ID//\//_}" + OUTPUT_FILE="$(pwd)/${SAFE_MODEL_ID}_memory.log" + echo "======================== Eval Memory $MODEL_ID ============================" + python eval_peak_memory_usage.py --model_id "$MODEL_ID" > "$OUTPUT_FILE" 2>&1 + echo "Evaluation complete. Output saved to $OUTPUT_FILE" + echo "======================== Eval Memory $MODEL_ID End ========================" +done diff --git a/.github/scripts/torchao_model_releases/eval_peak_memory_usage.py b/.github/scripts/torchao_model_releases/eval_peak_memory_usage.py new file mode 100644 index 0000000000..b2f6762178 --- /dev/null +++ b/.github/scripts/torchao_model_releases/eval_peak_memory_usage.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def eval_peak_memory_usage(model_id: str): + model = AutoModelForCausalLM.from_pretrained( + model_id, device_map="cuda:0", torch_dtype=torch.bfloat16 + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + torch.cuda.reset_peak_memory_stats() + + prompt = "Hey, are you conscious? Can you talk to me?" + messages = [ + { + "role": "system", + "content": "", + }, + {"role": "user", "content": prompt}, + ] + templated_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + print("Prompt:", prompt) + print("Templated prompt:", templated_prompt) + inputs = tokenizer( + templated_prompt, + return_tensors="pt", + ).to("cuda") + generated_ids = model.generate(**inputs, max_new_tokens=128) + output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print("Response:", output_text[0][len(prompt) :]) + + mem = torch.cuda.max_memory_reserved() / 1e9 + print(f"Peak Memory Usage: {mem:.02f} GB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Evaluate a model with the specified parameters." + ) + parser.add_argument( + "--model_id", type=str, help="Huggingface hub model ID of the model." + ) + args = parser.parse_args() + eval_peak_memory_usage(args.model_id) diff --git a/.github/scripts/torchao_model_releases/eval_quality.sh b/.github/scripts/torchao_model_releases/eval_quality.sh new file mode 100644 index 0000000000..dd0ab9c2b2 --- /dev/null +++ b/.github/scripts/torchao_model_releases/eval_quality.sh @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +#!/bin/bash +set -e +source eval_env_checks.sh +check_lm_eval + +MODEL_ID_ARRAY=() +TASK_ARRAY=("mmlu") # default can be overwritten by user input +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --model_ids) + shift + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + MODEL_ID_ARRAY+=("$1") + shift + done + ;; + --tasks) + shift + TASK_ARRAY=() + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + TASK_ARRAY+=("$1") + shift + done + ;; + *) + echo "Unknown argument: $1" + echo "Usage: $0 --model_id [--tasks (comma-separated, e.g. mmlu,arc_challenge, default mmlu)]" + exit 1 + ;; + esac +done +if [[ ${#MODEL_ID_ARRAY[@]} -eq 0 ]]; then + echo "Error: --model_ids is required" + echo "Usage: $0 --model_ids ... [--tasks ...]" + exit 1 +fi +RESULTS_DIR="$(pwd)/quality_eval_results" +for MODEL_ID in "${MODEL_ID_ARRAY[@]}"; do + # Replace all '/' with '_' + SAFE_MODEL_ID="${MODEL_ID//\//_}" + echo "======================== Eval Model Quality $MODLE_ID ======================" + for TASK in "${TASK_ARRAY[@]}"; do + OUTPUT_FILE="$(pwd)/${SAFE_MODEL_ID}_quality_${TASK}.log" + EVAL_CACHE_DB_PREFIX="/tmp/${SAFE_MODEL_ID}_quality_${TASK}" + mkdir -p "${EVAL_CACHE_DB_PREFIX}" + echo "Running model quality (accuracy) evaluation for model $MODEL_ID on task $TASK" + + lm_eval \ + --model hf \ + --model_args pretrained="$MODEL_ID" \ + --tasks "$TASK" \ + --device cuda:0 \ + --use_cache "$EVAL_CACHE_DB_PREFIX" \ + --batch_size auto \ + --output_path "$RESULTS_DIR" > "$OUTPUT_FILE" 2>&1 + + echo "Quality eval output for task '$TASK' saved to $OUTPUT_FILE" + done + echo "======================== Eval Model Quality $MODEL_ID End ==================" +done diff --git a/.github/scripts/torchao_model_releases/quantize_and_upload.py b/.github/scripts/torchao_model_releases/quantize_and_upload.py new file mode 100644 index 0000000000..083787526a --- /dev/null +++ b/.github/scripts/torchao_model_releases/quantize_and_upload.py @@ -0,0 +1,827 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from typing import List + +import torch +from huggingface_hub import ModelCard, get_token, whoami +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +from torchao._models._eval import TransformerEvalWrapper +from torchao.prototype.awq import ( + AWQConfig, +) +from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + ModuleFqnToConfig, + PerAxis, + PerGroup, + PerRow, + quantize_, +) + + +def _get_username(): + token = get_token() + username = whoami(token=token)["name"] + return username + + +def _untie_weights_and_save_locally(model_id): + untied_model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype="auto", device_map="cuda:0" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + + from transformers.modeling_utils import find_tied_parameters + + if getattr( + untied_model.config.get_text_config(decoder=True), "tie_word_embeddings" + ): + setattr( + untied_model.config.get_text_config(decoder=True), + "tie_word_embeddings", + False, + ) + + untied_model._tied_weights_keys = [] + untied_model.lm_head.weight = torch.nn.Parameter( + untied_model.lm_head.weight.clone() + ) + + print("tied weights:", find_tied_parameters(untied_model)) + + MODEL_NAME = model_id.split("/")[-1] + # save locally + save_to_local_path = f"{MODEL_NAME}-untied-weights" + untied_model.save_pretrained(save_to_local_path) + tokenizer.save_pretrained(save_to_local_path) + return save_to_local_path + + +MODEL_CARD = """--- +base_model: {base_model} +tags: +- transformers +- torchao +- {model_type} +license: apache-2.0 +language: +- en +--- + +# {quant} {base_model} model + +- **Developed by:** {username} +- **License:** apache-2.0 +- **Quantized from Model :** {base_model} +- **Quantization Method :** {quant} + +{server_inference_recipe} + +{mobile_inference_recipe} + +# Quantization Recipe + +Install the required packages: +```Shell +pip install torch +pip install git+https://github.com/huggingface/transformers@main +pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 +pip install accelerate +``` + +{untie_embedding_recipe} + +Use the following code to get the quantized model: +```Py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +model_id = "{base_model}" +model_to_quantize = "{untied_model}" + +{quant_code} + +# Push to hub +USER_ID = "YOUR_USER_ID" +MODEL_NAME = model_id.split("/")[-1] +save_to = f"{{USER_ID}}/{{MODEL_NAME}}-{quant}" +quantized_model.push_to_hub(save_to, safe_serialization=False) +tokenizer.push_to_hub(save_to) + +# Manual Testing +prompt = "Hey, are you conscious? Can you talk to me?" +messages = [ + {{ + "role": "system", + "content": "", + }}, + {{"role": "user", "content": prompt}}, +] +templated_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, +) +print("Prompt:", prompt) +print("Templated prompt:", templated_prompt) +inputs = tokenizer( + templated_prompt, + return_tensors="pt", +).to("cuda") +generated_ids = quantized_model.generate(**inputs, max_new_tokens=128) +output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print("Response:", output_text[0][len(prompt):]) +``` + +Note: to `push_to_hub` you need to run +```Shell +pip install -U "huggingface_hub[cli]" +huggingface-cli login +``` +and use a token with write access, from https://huggingface.co/settings/tokens + +# Model Quality +We rely on [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate the quality of the quantized model. Here we only run on mmlu for sanity check. + +| Benchmark | | | +|----------------------------------|----------------|---------------------------| +| | {base_model} | {quantized_model} | +| mmlu | To be filled | To be filled | + + +
+ Reproduce Model Quality Results + +Need to install lm-eval from source: +https://github.com/EleutherAI/lm-evaluation-harness#install + +## baseline +```Shell +lm_eval --model hf --model_args pretrained={base_model} --tasks mmlu --device cuda:0 --batch_size 8 +``` + +## {quant} +```Shell +export MODEL={quantized_model} +lm_eval --model hf --model_args pretrained=$MODEL --tasks mmlu --device cuda:0 --batch_size 8 +``` +
+ + + +{server_peak_memory_usage} + + +{server_model_performance} + +{mobile_export_to_executorch} + +# Paper: TorchAO: PyTorch-Native Training-to-Serving Model Optimization +The model's quantization is powered by **TorchAO**, a framework presented in the paper [TorchAO: PyTorch-Native Training-to-Serving Model Optimization](https://huggingface.co/papers/2507.16099). + +**Abstract:** We present TorchAO, a PyTorch-native model optimization framework leveraging quantization and sparsity to provide an end-to-end, training-to-serving workflow for AI models. TorchAO supports a variety of popular model optimization techniques, including FP8 quantized training, quantization-aware training (QAT), post-training quantization (PTQ), and 2:4 sparsity, and leverages a novel tensor subclass abstraction to represent a variety of widely-used, backend agnostic low precision data types, including INT4, INT8, FP8, MXFP4, MXFP6, and MXFP8. TorchAO integrates closely with the broader ecosystem at each step of the model optimization pipeline, from pre-training (TorchTitan) to fine-tuning (TorchTune, Axolotl) to serving (HuggingFace, vLLM, SGLang, ExecuTorch), connecting an otherwise fragmented space in a single, unified workflow. TorchAO has enabled recent launches of the quantized Llama 3.2 1B/3B and LlamaGuard3-8B models and is open-source at this https URL . + +# Resources +* **Official TorchAO GitHub Repository:** [https://github.com/pytorch/ao](https://github.com/pytorch/ao) +* **TorchAO Documentation:** [https://docs.pytorch.org/ao/stable/index.html](https://docs.pytorch.org/ao/stable/index.html) + + +# Disclaimer +PyTorch has not performed safety evaluations or red teamed the quantized models. Performance characteristics, outputs, and behaviors may differ from the original models. Users are solely responsible for selecting appropriate use cases, evaluating and mitigating for accuracy, safety, and fairness, ensuring security, and complying with all applicable laws and regulations. + +Nothing contained in this Model Card should be interpreted as or deemed a restriction or modification to the licenses the models are released under, including any limitations of liability or disclaimers of warranties provided therein. +""" + + +_int4_quant_code = """ +from torchao.quantization import Int4WeightOnlyConfig +quant_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq") +quantization_config = TorchAoConfig(quant_type=quant_config) +quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config) +tokenizer = AutoTokenizer.from_pretrained(model_id) +""" + +_fp8_quant_code = """ +from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow +quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) +quantization_config = TorchAoConfig(quant_type=quant_config) +quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config) +tokenizer = AutoTokenizer.from_pretrained(model_id) +""" + +_int8_int4_quant_code = """ +from torchao.quantization.quant_api import ( + IntxWeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + ModuleFqnToConfig, +) +from torchao.quantization.granularity import PerGroup, PerAxis +embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), +) +linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), +) +quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}}) +quantization_config = TorchAoConfig(quant_type=quant_config, include_input_output_embeddings=True, modules_to_not_convert=[]) +quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="cuda:0", torch_dtype=torch.bfloat16, quantization_config=quantization_config) +tokenizer = AutoTokenizer.from_pretrained(model_id) +""" + +_awq_int4_quant_code = """ +from torchao.quantization import Int4WeightOnlyConfig, quantize_ +from torchao.prototype.awq import ( + AWQConfig, +) +from torchao._models._eval import TransformerEvalWrapper +model = AutoModelForCausalLM.from_pretrained( + model_to_quantize, + device_map="cuda:0", + torch_dtype=torch.bfloat16, +) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +base_config = Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq") +quant_config = AWQConfig(base_config, step="prepare") +quantize_( + model, + quant_config, +) +TransformerEvalWrapper( + model=model, + tokenizer=tokenizer, + max_seq_length=max_seq_length, +).run_eval( + tasks=tasks, + limit=calibration_limit, +) +quant_config = AWQConfig(base_config, step="convert") +quantize_(model, quant_config) + +quantized_model = model +quant_config = AWQConfig(base_config, step="prepare_for_loading") +quantized_model.config.quantization_config = TorchAoConfig(quant_config) +""" + + +_server_inference_recipe = """ +# Inference with vLLM +Install vllm nightly and torchao nightly to get some recent changes: +``` +pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly +pip install torchao +``` + +## Serving +Then we can serve with the following command: +```Shell +# Server +export MODEL={quantized_model} +VLLM_DISABLE_COMPILE_CACHE=1 vllm serve $MODEL --tokenizer $MODEL -O3 +``` + +```Shell +# Client +curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{{ + "model": "{quantized_model}", + "messages": [ + {{"role": "user", "content": "Give me a short introduction to large language models."}} + ], + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_tokens": 32768 +}}' +``` + +Note: please use `VLLM_DISABLE_COMPILE_CACHE=1` to disable compile cache when running this code, e.g. `VLLM_DISABLE_COMPILE_CACHE=1 python example.py`, since there are some issues with the composability of compile in vLLM and torchao, +this is expected be resolved in pytorch 2.8. + +# Inference with Transformers + +Install the required packages: +```Shell +pip install git+https://github.com/huggingface/transformers@main +pip install torchao +pip install torch +pip install accelerate +``` + +Example: +```Py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "{quantized_model}" + +# load the tokenizer and the model +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="cuda:0" +) + +# prepare the model input +prompt = "Give me a short introduction to large language model." +messages = [ + {{"role": "user", "content": prompt}} +] +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True # Switches between thinking and non-thinking modes. Default is True. +) +model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + +# conduct text completion +generated_ids = model.generate( + **model_inputs, + max_new_tokens=32768 +) +output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() + +# parsing thinking content +try: + # rindex finding 151668 () + index = len(output_ids) - output_ids[::-1].index(151668) +except ValueError: + index = 0 + +thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n") +content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n") + +print("thinking content:", thinking_content) +print("content:", content) +``` +""" + +_server_peak_memory_usage = """ +# Peak Memory Usage + +## Results + +| Benchmark | | | +|------------------|----------------|--------------------------------| +| | {base_model} | {quantized_model} | +| Peak Memory (GB) | To be filled | To be filled (?% reduction) | + + + +
+ Reproduce Peak Memory Usage Results + +We can use the following code to get a sense of peak memory usage during inference: + +```Py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +# use "{base_model}" or "{quantized_model}" +model_id = "{quantized_model}" +quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", torch_dtype=torch.bfloat16) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +torch.cuda.reset_peak_memory_stats() + +prompt = "Hey, are you conscious? Can you talk to me?" +messages = [ + {{ + "role": "system", + "content": "", + }}, + {{"role": "user", "content": prompt}}, +] +templated_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, +) +print("Prompt:", prompt) +print("Templated prompt:", templated_prompt) +inputs = tokenizer( + templated_prompt, + return_tensors="pt", +).to("cuda") +generated_ids = quantized_model.generate(**inputs, max_new_tokens=128) +output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False +) +print("Response:", output_text[0][len(prompt):]) + +mem = torch.cuda.max_memory_reserved() / 1e9 +print(f"Peak Memory Usage: {{mem:.02f}} GB") +``` + +
+""" + +_server_model_performance = """ +# Model Performance + +## Results (A100 machine) +| Benchmark (Latency) | | | +|----------------------------------|----------------|--------------------------| +| | {base_model} | {quantized_model} | +| latency (batch_size=1) | ?s | ?s (?x speedup) | +| latency (batch_size=256) | ?s | ?s (?x speedup) | + +
+ Reproduce Model Performance Results + +## Setup + +Get vllm source code: +```Shell +git clone git@github.com:vllm-project/vllm.git +``` + +Install vllm +``` +VLLM_USE_PRECOMPILED=1 pip install --editable . +``` + +Run the benchmarks under `vllm` root folder: + +## benchmark_latency + +### baseline +```Shell +export MODEL={base_model} +python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model $MODEL --batch-size 1 +``` + +### {quant} +```Shell +export MODEL={quantized_model} +VLLM_DISABLE_COMPILE_CACHE=1 python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model $MODEL --batch-size 1 +``` +
+""" + + +# Mobile Specific recipes + +_mobile_inference_recipe = """ +# Running in a mobile app +The [pte file](https://huggingface.co/{quantized_model}/blob/main/model.pte) can be run with ExecuTorch on a mobile phone. See the [instructions](https://pytorch.org/executorch/main/llm/llama-demo-ios.html) for doing this in iOS. +On iPhone 15 Pro, the model runs at (to be filled) tokens/sec and uses (to be filled) Mb of memory. + +TODO: attach image +""" +_untie_embedding_recipe = """ +## Untie Embedding Weights +We want to quantize the embedding and lm_head differently. Since those layers are tied, we first need to untie the model: + +```Py +from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, +) +import torch + +model_id = "{base_model}" +untied_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="cuda:0") +tokenizer = AutoTokenizer.from_pretrained(model_id) + +print(untied_model) +from transformers.modeling_utils import find_tied_parameters +print("tied weights:", find_tied_parameters(untied_model)) +if getattr(untied_model.config.get_text_config(decoder=True), "tie_word_embeddings"): + setattr(untied_model.config.get_text_config(decoder=True), "tie_word_embeddings", False) + +untied_model._tied_weights_keys = [] +untied_model.lm_head.weight = torch.nn.Parameter(untied_model.lm_head.weight.clone()) + +print("tied weights:", find_tied_parameters(untied_model)) + +USER_ID = "YOUR_USER_ID" +MODEL_NAME = model_id.split("/")[-1] +save_to = f"{{USER_ID}}/{{MODEL_NAME}}-untied-weights" + +# save locally (we use this in the recipe) +save_to_local_path = f"{{MODEL_NAME}}-untied-weights" +untied_model.save_pretrained(save_to_local_path) +tokenizer.save_pretrained(save_to_local_path) + + +# or push to hub +untied_model.push_to_hub(save_to) +tokenizer.push_to_hub(save_to) +``` + +Note: to `push_to_hub` you need to run +```Shell +pip install -U "huggingface_hub[cli]" +huggingface-cli login +``` +and use a token with write access, from https://huggingface.co/settings/tokens + +## Quantization +""" + +_mobile_export_to_executorch = """ +# Exporting to ExecuTorch + +We can run the quantized model on a mobile phone using [ExecuTorch](https://github.com/pytorch/executorch). +Once ExecuTorch is [set-up](https://pytorch.org/executorch/main/getting-started.html), exporting and running the model on device is a breeze. + +ExecuTorch's LLM export scripts require the checkpoint keys and parameters have certain names, which differ from those used in Hugging Face. +So we first use a script that converts the Hugging Face checkpoint key names to ones that ExecuTorch expects: +The following script does this for you. + +[TODO: fix command below where necessary] +```Shell +python -m executorch.examples.models.qwen3.convert_weights $(hf download {quantized_model}) pytorch_model_converted.bin +``` + +Once we have the checkpoint, we export it to ExecuTorch with a max_seq_length/max_context_length of 1024 to the XNNPACK backend as follows. + +[TODO: fix config path in note where necessary] +(Note: ExecuTorch LLM export script requires config.json have certain key names. The correct config to use for the LLM export script is located at examples/models/qwen3/config/4b_config.json within the ExecuTorch repo.) + +[TODO: fix command below where necessary] +```Shell +python -m executorch.examples.models.llama.export_llama \ + --model "qwen3_4b" \ + --checkpoint pytorch_model_converted.bin \ + --params examples/models/qwen3/config/4b_config.json \ + --output_name model.pte \ + -kv \ + --use_sdpa_with_kv_cache \ + -X \ + --xnnpack-extended-ops \ + --max_context_length 1024 \ + --max_seq_length 1024 \ + --dtype fp32 \ + --metadata '{{"get_bos_id":199999, "get_eos_ids":[200020,199999]}}' +``` + +After that you can run the model in a mobile app (see [Running in a mobile app](#running-in-a-mobile-app)). + +(We try to keep these instructions up-to-date, but if you find they do not work, check out our [CI test in ExecuTorch](https://github.com/pytorch/executorch/blob/main/.ci/scripts/test_torchao_huggingface_checkpoints.sh) for the latest source of truth, and let us know we need to update our model card.) +""" + + +def quantize_and_upload( + model_id: str, + quant: str, + tasks: List[str], + calibration_limit: int, + max_seq_length: int, + push_to_hub: bool, + push_to_user_id: str, + populate_model_card_template: bool, +): + _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), + ) + _int8_int4_embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + quant_to_config = { + "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + "INT4": Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", + ), + "INT8-INT4": ModuleFqnToConfig( + { + "_default": _int8_int4_linear_config, + "model.embed_tokens": _int8_int4_embedding_config, + } + ), + } + + quant_to_quant_code = { + "FP8": _fp8_quant_code, + "INT4": _int4_quant_code, + "INT8-INT4": _int8_int4_quant_code, + "AWQ-INT4": _awq_int4_quant_code, + } + + # preparation + model_to_quantize = model_id + if quant == "INT8-INT4": + model_to_quantize = _untie_weights_and_save_locally(model_to_quantize) + + # quantization + + if "AWQ" in quant: + # awq will use torchao API directly + assert quant == "AWQ-INT4", "Only support AWQ-INT4 for now" + model = AutoModelForCausalLM.from_pretrained( + model_to_quantize, + device_map="cuda:0", + torch_dtype=torch.bfloat16, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + base_config = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", + ) + quant_config = AWQConfig(base_config, step="prepare") + quantize_( + model, + quant_config, + ) + TransformerEvalWrapper( + model=model, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + ).run_eval( + tasks=tasks, + limit=calibration_limit, + ) + quant_config = AWQConfig(base_config, step="convert") + quantize_(model, quant_config) + + quantized_model = model + quant_config = AWQConfig(base_config, step="prepare_for_loading") + quantized_model.config.quantization_config = TorchAoConfig(quant_config) + else: + # other quantization are integrated with `from_pretrained` in huggingface transformers + assert quant in quant_to_config, f"Unsupported quant option: {quant}" + quant_config = quant_to_config[quant] + + torchao_config_kwargs = {} + if "INT8-INT4" in quant: + torchao_config_kwargs["modules_to_not_convert"] = [] + torchao_config_kwargs["include_input_output_embeddings"] = True + + quantization_config = TorchAoConfig( + quant_type=quant_config, **torchao_config_kwargs + ) + quantized_model = AutoModelForCausalLM.from_pretrained( + model_to_quantize, + device_map="cuda:0", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + username = _get_username() + + MODEL_NAME = model_id.split("/")[-1] + + save_to_user_id = username if push_to_user_id is None else push_to_user_id + save_to = f"{save_to_user_id}/{MODEL_NAME}-{quant}" + untied_model_path = 'f"{{MODEL_NAME}}-untied-weights"' + is_mobile = quant == "INT8-INT4" + quantized_model_id = save_to + # model card + content = MODEL_CARD.format( + username=username, + base_model=model_id, + quantized_model=quantized_model_id, + model_type=quantized_model.config.model_type, + quant=quant, + quant_code=quant_to_quant_code[quant], + # server specific recipes + server_inference_recipe="" + if is_mobile + else _server_inference_recipe.format(quantized_model=quantized_model_id), + server_peak_memory_usage="" + if is_mobile + else _server_peak_memory_usage.format( + base_model=model_id, quantized_model=quantized_model_id + ), + server_model_performance="" + if is_mobile + else _server_model_performance.format( + base_model=model_id, quantized_model=quantized_model_id, quant=quant + ), + # mobile specific recipes + untied_model=untied_model_path if is_mobile else model_id, + untie_embedding_recipe=_untie_embedding_recipe if is_mobile else "", + mobile_inference_recipe=_mobile_inference_recipe.format( + quantized_model=quantized_model_id + ) + if is_mobile + else "", + mobile_export_to_executorch=_mobile_export_to_executorch.format( + quantized_model=quantized_model_id + ) + if is_mobile + else "", + ) + card = ModelCard(content) + + # Push to hub + if push_to_hub: + quantized_model.push_to_hub(quantized_model_id, safe_serialization=False) + tokenizer.push_to_hub(quantized_model_id) + if populate_model_card_template: + card.push_to_hub(quantized_model_id) + else: + quantized_model.save_pretrained(quantized_model_id, safe_serialization=False) + tokenizer.save_pretrained(quantized_model_id) + + # Manual Testing + prompt = "Hey, are you conscious? Can you talk to me?" + messages = [ + { + "role": "system", + "content": "", + }, + {"role": "user", "content": prompt}, + ] + templated_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + print("Prompt:", prompt) + print("Templated prompt:", templated_prompt) + inputs = tokenizer( + templated_prompt, + return_tensors="pt", + ).to("cuda") + generated_ids = quantized_model.generate(**inputs, max_new_tokens=128) + output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print("Response:", output_text[0][len(prompt) :]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Evaluate a model with the specified parameters." + ) + parser.add_argument( + "--model_id", type=str, help="Huggingface hub model ID of the model." + ) + parser.add_argument( + "--quant", + type=str, + help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4", + ) + parser.add_argument( + "--tasks", + nargs="+", + type=str, + help="lm-eval task to optimize for in awq, we'll select a sample from the task dataset and run awq calibration based on that", + default=["gsm8k"], + ) + parser.add_argument( + "--calibration_limit", + type=int, + default=10, + help="Number of samples to use for calibration. Default is 10.", + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=2048, + help="Maximum sequence length of examples to calibrate and evaluate model on. Default is 2048", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + default=False, + help="Flag to indicate whether push to huggingface hub or not", + ) + parser.add_argument( + "--push_to_user_id", + type=str, + default=None, + help="The user_id to use for pushing the quantized model, only used when --push_to_hub is set", + ) + parser.add_argument( + "--populate_model_card_template", + action="store_true", + default=False, + help="Flag to indicate whether push model card to huggingface hub or not", + ) + args = parser.parse_args() + quantize_and_upload( + args.model_id, + args.quant, + args.tasks, + args.calibration_limit, + args.max_seq_length, + args.push_to_hub, + args.push_to_user_id, + args.populate_model_card_template, + ) diff --git a/.github/scripts/torchao_model_releases/release.sh b/.github/scripts/torchao_model_releases/release.sh new file mode 100755 index 0000000000..81378052af --- /dev/null +++ b/.github/scripts/torchao_model_releases/release.sh @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +#!/bin/bash + +# see README.md for instructions + +# Default quantization options +default_quants=("FP8" "INT4" "INT8-INT4") +push_to_hub="" +push_to_user_id="" +populate_model_card_template="" +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --model_id) + model_id="$2" + shift 2 + ;; + --quants) + shift + quants=() + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + quants+=("$1") + shift + done + ;; + --push_to_hub) + push_to_hub="--push_to_hub" + shift + ;; + --push_to_user_id) + push_to_user_id=("--push_to_user_id $2") + shift 2 + ;; + --populate_model_card_template) + populate_model_card_template="--populate_model_card_template" + shift + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done +# Use default quants if none specified +if [[ -z "$model_id" ]]; then + echo "Error: --model_id is required" + echo "Usage: $0 --model_id [--quants [quant2 ...]] [--push_to_hub] [--push_to_user_id ] [--populate_model_card_template]" + exit 1 +fi +if [[ ${#quants[@]} -eq 0 ]]; then + quants=("${default_quants[@]}") +fi +# Run the python command for each quantization option +for quant in "${quants[@]}"; do + echo "Running: python quantize_and_upload.py --model_id $model_id --quant $quant $push_to_hub $push_to_user_id $populate_model_card_template" + python quantize_and_upload.py --model_id "$model_id" --quant "$quant" $push_to_hub $push_to_user_id $populate_model_card_template +done diff --git a/.github/scripts/torchao_model_releases/summarize_results.sh b/.github/scripts/torchao_model_releases/summarize_results.sh new file mode 100644 index 0000000000..7e9c43b99b --- /dev/null +++ b/.github/scripts/torchao_model_releases/summarize_results.sh @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +#!/bin/bash +set -e +usage() { + echo "Usage: $0 --model_ids ..." + exit 1 +} +MODEL_ID_ARRAY=() +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --model_ids) + shift + # Collect all subsequent arguments that are not another flag + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + MODEL_ID_ARRAY+=("$1") + shift + done + ;; + *) + echo "Unknown argument: $1" + usage + ;; + esac +done +if [[ ${#MODEL_ID_ARRAY[@]} -eq 0 ]]; then + echo "Error: --model_ids is required" + usage + exit 1 +fi +for MODEL_ID in "${MODEL_ID_ARRAY[@]}"; do + SAFE_MODEL_ID="${MODEL_ID//\//_}" + OUTPUT_FILE="summary_results_${SAFE_MODEL_ID}.log" + # Clear or create the output file + > "$OUTPUT_FILE" + + { + echo "===== Summary for model: $MODEL_ID =====" + QUALITY_LOG_PATTERN="${SAFE_MODEL_ID}_quality_*.log" + # Quality logs (multiple files, one per task) + QUALITY_LOGS=( $QUALITY_LOG_PATTERN ) + if [ -e "${QUALITY_LOGS[0]}" ]; then + for Q_LOG in "${QUALITY_LOGS[@]}"; do + # find last appearance of pretrained={MODEL_ID} and + # extract all lines after that + PATTERN="pretrained=${MODEL_ID}" + LAST_LINE=$(grep -n "$PATTERN" "$Q_LOG" | tail -1 | cut -d: -f1) + if [ -n "$LAST_LINE" ]; then + echo "--- Quality log: $Q_LOG (lines starting from $((LAST_LINE + 1))) ---" + tail -n +"$((LAST_LINE + 1))" "$Q_LOG" + else + echo "Pattern not found in $Q_LOG" + fi + done + else + echo "--- No quality logs found matching pattern: $QUALITY_LOG_PATTERN" + fi + + MEMORY_LOG="${SAFE_MODEL_ID}_memory.log" + if [ -f "$MEMORY_LOG" ]; then + echo "--- Memory log (last 1 lines) ---" + tail -n 1 "$MEMORY_LOG" + else + echo "--- Memory log not found: $MEMORY_LOG" + fi + + LATENCY_LOG_PATTERN="${SAFE_MODEL_ID}_latency_batch*_in*_out*.log" + LATENCY_LOGS=( $LATENCY_LOG_PATTERN ) + if [ -e "${LATENCY_LOGS[0]}" ]; then + for LAT_LOG in "${LATENCY_LOGS[@]}"; do + echo "--- Latency log: $LAT_LOG (last 7 lines) ---" + tail -n 7 "$LAT_LOG" + done + else + echo "--- No latency logs found matching pattern: $LATENCY_LOG_PATTERN" + fi + echo "" + echo "===== End of Summary for model: $MODEL_ID =====" + } >> "$OUTPUT_FILE" + echo "Summary of results saved to $OUTPUT_FILE" +done diff --git a/.github/workflows/1xH100_tests.yml b/.github/workflows/1xH100_tests.yml new file mode 100644 index 0000000000..cd5ef73207 --- /dev/null +++ b/.github/workflows/1xH100_tests.yml @@ -0,0 +1,54 @@ +name: Run 1xH100 Tests + +on: + push: + branches: + - main + - 'gh/**' + pull_request: + branches: + - main + - 'gh/**' + +concurrency: + group: 1xH100_tests-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + +jobs: + test: + strategy: + fail-fast: false + matrix: + include: + - name: H100 + runs-on: linux.aws.h100 + torch-spec: '--pre torch torchvision torchaudio fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu126' + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" + permissions: + id-token: write + contents: read + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 90 + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive + script: | + conda create -n venv python=3.9 -y + conda activate venv + export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH + python -m pip install --upgrade pip + pip install uv + pip install ${{ matrix.torch-spec }} + uv pip install -r dev-requirements.txt + pip install . + pytest test/integration --verbose -s + pytest test/dtypes/test_affine_quantized_float.py --verbose -s + python test/quantization/quantize_/workflows/float8/test_float8_tensor.py + ./test/float8/test_everything_single_gpu.sh + pytest test/prototype/mx_formats/ -s diff --git a/.github/workflows/float8_test.yml b/.github/workflows/1xL4_tests.yml similarity index 73% rename from .github/workflows/float8_test.yml rename to .github/workflows/1xL4_tests.yml index 91083df0bf..39175ed0f9 100644 --- a/.github/workflows/float8_test.yml +++ b/.github/workflows/1xL4_tests.yml @@ -1,4 +1,4 @@ -name: Run Float8 Tests +name: Run 1xL4 Tests on: push: @@ -11,7 +11,7 @@ on: - 'gh/**' concurrency: - group: float8_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + group: 1xL4_tests-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} cancel-in-progress: true env: @@ -28,11 +28,6 @@ jobs: torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126' gpu-arch-type: "cuda" gpu-arch-version: "12.6" - - name: H100 - runs-on: linux.aws.h100 - torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126' - gpu-arch-type: "cuda" - gpu-arch-version: "12.4" permissions: id-token: write contents: read @@ -51,8 +46,8 @@ jobs: pip install uv pip install ${{ matrix.torch-spec }} uv pip install -r dev-requirements.txt - uv pip install vllm pip install . - pytest test/float8 --verbose -s pytest test/integration --verbose -s pytest test/dtypes/test_affine_quantized_float.py --verbose -s + ./test/float8/test_everything_single_gpu.sh + python test/quantization/quantize_/workflows/float8/test_float8_tensor.py diff --git a/.github/workflows/4xH100_tests.yml b/.github/workflows/4xH100_tests.yml new file mode 100644 index 0000000000..b19b2f2dcc --- /dev/null +++ b/.github/workflows/4xH100_tests.yml @@ -0,0 +1,49 @@ +name: Run 4xH100 tests + +on: + push: + branches: + - main + tags: + - ciflow/4xh100/* + workflow_dispatch: + +concurrency: + group: 4xH100_tests-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + +jobs: + test: + strategy: + fail-fast: false + matrix: + include: + - name: H100 + runs-on: linux.aws.h100.4 + torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126' + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" + permissions: + id-token: write + contents: read + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 60 + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive + script: | + conda create -n venv python=3.9 -y + conda activate venv + export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH + python -m pip install --upgrade pip + pip install uv + pip install ${{ matrix.torch-spec }} + uv pip install -r dev-requirements.txt + pip install . + ./test/float8/test_everything_multi_gpu.sh + ./test/prototype/mx_formats/test_mx_dtensor.sh diff --git a/.github/workflows/build_wheels_linux.yml b/.github/workflows/build_wheels_linux.yml index a8d96abc8a..f164ed03c5 100644 --- a/.github/workflows/build_wheels_linux.yml +++ b/.github/workflows/build_wheels_linux.yml @@ -5,6 +5,7 @@ on: pull_request: paths: - build/packaging/** + - packaging/** - .github/workflows/build_wheels_linux.yml - setup.py push: diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index f1188fd7d5..0858076551 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -39,7 +39,7 @@ jobs: contents: read uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: - timeout: 120 + timeout: 180 runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} @@ -59,12 +59,6 @@ jobs: fail-fast: false matrix: include: - - name: CUDA 2.5.1 - runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' - gpu-arch-type: "cuda" - gpu-arch-version: "12.6" - dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/" - name: CUDA 2.6 runs-on: linux.g5.12xlarge.nvidia.gpu torch-spec: 'torch==2.6.0' @@ -77,13 +71,13 @@ jobs: gpu-arch-type: "cuda" gpu-arch-version: "12.6" dev-requirements-overrides: "" + - name: CUDA 2.8 + runs-on: linux.g5.12xlarge.nvidia.gpu + torch-spec: 'torch==2.8.0' + gpu-arch-type: "cuda" + gpu-arch-version: "12.6" + dev-requirements-overrides: "" - - name: CPU 2.5.1 - runs-on: linux.4xlarge - torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu' - gpu-arch-type: "cpu" - gpu-arch-version: "" - dev-requirements-overrides: "s/^pytest$/pytest==7.4.0/" - name: CPU 2.6 runs-on: linux.4xlarge torch-spec: 'torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu' @@ -96,10 +90,16 @@ jobs: gpu-arch-type: "cpu" gpu-arch-version: "" dev-requirements-overrides: "" + - name: CPU 2.8 + runs-on: linux.4xlarge + torch-spec: 'torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu' + gpu-arch-type: "cpu" + gpu-arch-version: "" + dev-requirements-overrides: "" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: - timeout: 120 + timeout: 180 runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} diff --git a/.github/workflows/regression_test_aarch64.yml b/.github/workflows/regression_test_aarch64.yml new file mode 100644 index 0000000000..ff10b661a5 --- /dev/null +++ b/.github/workflows/regression_test_aarch64.yml @@ -0,0 +1,139 @@ +name: Run Regression Tests (aarch64) + +on: + push: + branches: + - main + - 'gh/**' + pull_request: + branches: + - main + - 'gh/**' + +jobs: + test-cpu-ops: + strategy: + matrix: + runner: [macos-14, linux.arm64.2xlarge] + runs-on: ${{matrix.runner}} + defaults: + run: + shell: bash -el {0} + steps: + - name: Checkout repo + uses: actions/checkout@v3 + with: + submodules: true + - name: Setup environment + uses: conda-incubator/setup-miniconda@v3 + with: + python-version: "3.10" + miniconda-version: "latest" + activate-environment: venv + - name: Install requirements mac + if: runner.os == 'macOS' + run: | + conda activate venv + # Install executorch first because it installs its own version + # of torch and torchao, which we do not want to use + pip install executorch + pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall + pip install -r dev-requirements.txt + USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install . + - name: Install requirements linux + if: runner.os == 'Linux' + run: | + conda activate venv + pip install coremltools + pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall + pip install -r dev-requirements.txt + BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install . + - name: Run python tests + run: | + conda activate venv + pytest -s test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py + pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py + pytest -s test/prototype/test_embedding.py + pytest -s test/prototype/test_int8_lut_tensor.py + pytest -s test/prototype/test_tensor_conversion.py + pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py + pytest -s test/prototype/test_parq.py + - name: torchao/csrc/cpu - build and run C++ tests + if: runner.os == 'macOS' + run: | + conda activate venv + pushd torchao/csrc/cpu + sh build_and_run_tests.sh + rm -rf cmake-out + popd + - name: torchao/csrc/cpu - build benchmarks + if: runner.os == 'macOS' + run: | + conda activate venv + pushd torchao/csrc/cpu + sh build_and_run_benchmarks.sh build_only + rm -rf cmake-out + popd + - name: torchao/csrc/cpu - build shared_kernels with ExecuTorch + if: runner.os == 'macOS' + run: | + conda activate venv + pushd torchao/csrc/cpu + sh build_shared_kernels.sh executorch + rm -rf cmake-out + popd + + # test-mps-ops: + # strategy: + # matrix: + # runner: [macos-m1-stable] + # runs-on: ${{matrix.runner}} + # steps: + # - name: Print machine info + # run: | + # uname -a + # if [ $(uname -s) == Darwin ]; then + # sysctl machdep.cpu.brand_string + # sysctl machdep.cpu.core_count + # fi + # - name: Checkout repo + # uses: actions/checkout@v3 + # with: + # submodules: true + # - name: Create conda env + # run: | + # conda create -yn test-mps-ops-env python=3.11 + # - name: Activate conda env + # run: | + # source activate base + # conda activate test-mps-ops-env + # - name: Install torch + # run: | + # conda run -n test-mps-ops-env pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" + # - name: Print torch version + # run: | + + # conda run -n test-mps-ops-env python -c "import torch; print(torch.__version__)" + # - name: Install requirements + # run: | + # source activate base + # conda activate test-mps-ops-env + # pip install -r dev-requirements.txt + # pip install pyyaml importlib-metadata + # - name: Print pip freeze + # run: | + # conda run -n test-mps-ops-env pip freeze + # - name: Print current directory + # run: | + # conda run -n test-mps-ops-env python -c "import os; print(os.getcwd())" + # - name: Build ao with experimental mps ops + # run: | + # source activate base + # conda activate test-mps-ops-env + # USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . + # - name: Run mps tests + # run: | + # pushd torchao/experimental/ops/mps/test + # conda run -n test-mps-ops-env python test_lowbit.py + # conda run -n test-mps-ops-env python test_quantizer.py + # popd diff --git a/.github/workflows/regression_test_rocm.yml b/.github/workflows/regression_test_rocm.yml index d43b5f8d10..a9db993c25 100644 --- a/.github/workflows/regression_test_rocm.yml +++ b/.github/workflows/regression_test_rocm.yml @@ -21,21 +21,23 @@ jobs: matrix: include: - name: ROCM Nightly - runs-on: linux.rocm.gpu.mi300.2 + runs-on: linux.rocm.gpu.gfx942.2 torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3' gpu-arch-type: "rocm" gpu-arch-version: "6.3" + docker-image: pytorch/manylinux2_28-builder:rocm6.3 permissions: id-token: write contents: read uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: - timeout: 150 + timeout: 210 no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }} runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} + docker-image: ${{ matrix.docker-image }} submodules: recursive script: | conda create -n venv python=3.9 -y diff --git a/.github/workflows/release_model.yml b/.github/workflows/release_model.yml new file mode 100644 index 0000000000..6b3566e07c --- /dev/null +++ b/.github/workflows/release_model.yml @@ -0,0 +1,46 @@ +name: Release Model + +on: + workflow_dispatch: + inputs: + hf_model_id: + description: 'Model ID for huggingface model to quantize, e.g. google/gemma-3-4b-it' + required: true + type: string + +env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + +jobs: + test: + strategy: + fail-fast: false + matrix: + include: + - name: H100 + runs-on: linux.aws.h100 + torch-spec: '--pre torch torchvision torchaudio fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu126' + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" + permissions: + id-token: write + contents: read + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 90 + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive + script: | + conda create -n venv python=3.9 -y + conda activate venv + export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH + python -m pip install --upgrade pip + pip install uv + pip install ${{ matrix.torch-spec }} + uv pip install -r dev-requirements.txt + pip install . + HF_MODEL_ID=${{ github.event.inputs.hf_model_id }} + cd .github/scripts/torchao_model_releases + ./release.sh --model_id $HF_MODEL_ID --push_to_hub diff --git a/.github/workflows/run_microbenchmarks.yml b/.github/workflows/run_microbenchmarks.yml new file mode 100644 index 0000000000..3c21afa35b --- /dev/null +++ b/.github/workflows/run_microbenchmarks.yml @@ -0,0 +1,69 @@ +name: Microbenchmarks-Perf-Nightly +# Dashboard: https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fao&benchmarkName=micro-benchmark+api + +on: + push: + tags: + - ciflow/benchmark/* + workflow_dispatch: + schedule: + - cron: '0 3 * * *' # Run daily at 7 AM UTC + +jobs: + benchmark: + runs-on: linux.aws.h100 + strategy: + matrix: + torch-spec: + - '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126' + steps: + - uses: actions/checkout@v4 + + - name: Setup miniconda + uses: pytorch/test-infra/.github/actions/setup-miniconda@main + with: + python-version: "3.9" + + - name: Run benchmark + shell: bash + run: | + set -eux + + # Upgrade pip + ${CONDA_RUN} python -m pip install --upgrade pip + + ${CONDA_RUN} ls + ${CONDA_RUN} bash -c 'pwd' + ${CONDA_RUN} bash -c 'echo $PYTHONPATH' + + # Install dependencies + ${CONDA_RUN} pip install ${{ matrix.torch-spec }} + ${CONDA_RUN} pip install -r dev-requirements.txt + ${CONDA_RUN} pip install . + + ${CONDA_RUN} ls + ${CONDA_RUN} bash -c 'pwd' + ${CONDA_RUN} bash -c 'echo $PYTHONPATH' + + # Set PYTHONPATH to current directory (.) if not set, and include the benchmarks directory + ${CONDA_RUN} export PYTHONPATH="${PYTHONPATH:-$(pwd)}:$(pwd)/benchmarks" + + # Create benchmark results directory + mkdir -p ${{ runner.temp }}/benchmark-results + + # Run microbenchmarks for dashboard + ${CONDA_RUN} bash -c ' + export PYTHONPATH="${PYTHONPATH:-$(pwd)}:$(pwd)/benchmarks" + echo "PYTHONPATH is: $PYTHONPATH" + echo "Current directory is: $(pwd)" + python benchmarks/dashboard/ci_microbenchmark_runner.py \ + --config benchmarks/dashboard/microbenchmark_quantization_config.yml \ + --output "$RUNNER_TEMP/benchmark-results/microbenchmark-results.json"' + + - name: Upload the benchmark results to OSS benchmark database for the dashboard + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: ${{ runner.temp }}/benchmark-results + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml deleted file mode 100644 index 4c56ec0c0e..0000000000 --- a/.github/workflows/torchao_experimental_test.yml +++ /dev/null @@ -1,133 +0,0 @@ -name: Run TorchAO Experimental Tests - -on: - push: - branches: - - main - - 'gh/**' - pull_request: - branches: - - main - - 'gh/**' - -jobs: - test-cpu-ops: - strategy: - matrix: - runner: [macos-14, linux.arm64.2xlarge] - runs-on: ${{matrix.runner}} - defaults: - run: - shell: bash -el {0} - steps: - - name: Checkout repo - uses: actions/checkout@v3 - with: - submodules: true - - name: Setup environment - uses: conda-incubator/setup-miniconda@v3 - with: - python-version: "3.10" - miniconda-version: "latest" - activate-environment: venv - - name: Install requirements mac - if: runner.os == 'macOS' - run: | - conda activate venv - # Install executorch first because it installs its own version - # of torch and torchao, which we do not want to use - pip install executorch - pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall - pip install -r dev-requirements.txt - USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install . - - name: Install requirements linux - if: runner.os == 'Linux' - run: | - conda activate venv - pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall - pip install -r dev-requirements.txt - BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install . - - name: Run python tests - run: | - conda activate venv - pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py - python torchao/experimental/tests/test_embedding_xbit_quantizer.py - python torchao/experimental/tests/test_quant_passes.py - - name: Run kernels/cpu/aarch64/tests - if: runner.os == 'macOS' - run: | - conda activate venv - pushd torchao/experimental/kernels/cpu/aarch64/tests - sh build_and_run_tests.sh - rm -rf /tmp/cmake-out - popd - - name: Run torchao/experimental/ops/tests - if: runner.os == 'macOS' - run: | - conda activate venv - pushd torchao/experimental/ops/tests - sh build_and_run_tests.sh - rm -rf /tmp/cmake-out - popd - - name: ET ops build - if: runner.os == 'macOS' - run: | - conda activate venv - pushd torchao/experimental - sh build_torchao_ops.sh executorch - popd - - test-mps-ops: - strategy: - matrix: - runner: [macos-m1-stable] - runs-on: ${{matrix.runner}} - steps: - - name: Print machine info - run: | - uname -a - if [ $(uname -s) == Darwin ]; then - sysctl machdep.cpu.brand_string - sysctl machdep.cpu.core_count - fi - - name: Checkout repo - uses: actions/checkout@v3 - with: - submodules: true - - name: Create conda env - run: | - conda create -yn test-mps-ops-env python=3.11 - - name: Activate conda env - run: | - source activate base - conda activate test-mps-ops-env - - name: Install torch - run: | - conda run -n test-mps-ops-env pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" - - name: Print torch version - run: | - - conda run -n test-mps-ops-env python -c "import torch; print(torch.__version__)" - - name: Install requirements - run: | - source activate base - conda activate test-mps-ops-env - pip install -r dev-requirements.txt - pip install pyyaml importlib-metadata - - name: Print pip freeze - run: | - conda run -n test-mps-ops-env pip freeze - - name: Print current directory - run: | - conda run -n test-mps-ops-env python -c "import os; print(os.getcwd())" - - name: Build ao with experimental mps ops - run: | - source activate base - conda activate test-mps-ops-env - USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . - - name: Run mps tests - run: | - pushd torchao/experimental/ops/mps/test - conda run -n test-mps-ops-env python test_lowbit.py - conda run -n test-mps-ops-env python test_quantizer.py - popd diff --git a/CITATION.cff b/CITATION.cff index 60adc9a9c0..cdc582adea 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -4,6 +4,6 @@ message: "If you use this software, please cite it as below." type: software authors: - given-names: "torchao maintainers and contributors" -url: "https//github.com/pytorch/torchao" +url: "https//github.com/pytorch/ao" license: "BSD-3-Clause" date-released: "2024-10-25" diff --git a/LICENSE b/LICENSE index 56f4d62a47..44018e4daf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,6 @@ Copyright 2023 Meta +All contributions by Arm: +Copyright (c) 2024-2025 Arm Limited and/or its affiliates Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.md b/README.md index d269c3974e..9330900300 100644 --- a/README.md +++ b/README.md @@ -11,20 +11,21 @@
-[![](https://img.shields.io/badge/CodeML_%40_ICML-2025-blue)](https://codeml-workshop.github.io/codeml2025/) +[![](https://img.shields.io/badge/CodeML_%40_ICML-2025-blue)](https://openreview.net/attachment?id=HpqH0JakHf&name=pdf) [![](https://dcbadge.vercel.app/api/server/gpumode?style=flat&label=TorchAO%20in%20GPU%20Mode)](https://discord.com/channels/1189498204333543425/1205223658021458100) [![](https://img.shields.io/github/contributors-anon/pytorch/ao?color=yellow&style=flat-square)](https://github.com/pytorch/ao/graphs/contributors) [![](https://img.shields.io/badge/torchao-documentation-blue?color=DE3412)](https://docs.pytorch.org/ao/stable/index.html) [![license](https://img.shields.io/badge/license-BSD_3--Clause-lightgrey.svg)](./LICENSE) -[Latest News](#-latest-news) | [Overview](#-overview) | [Quick Start](#-quick-start) | [Integrations](#-integrations) | [Inference](#-inference) | [Training](#-training) | [Videos](#-videos) | [Citation](#-citation) +[Latest News](#-latest-news) | [Overview](#-overview) | [Quick Start](#-quick-start) | [Installation](#-installation) | [Integrations](#-integrations) | [Inference](#-inference) | [Training](#-training) | [Videos](#-videos) | [Citation](#-citation)
## 📣 Latest News -- [Jun 25] Our [TorchAO paper](https://codeml-workshop.github.io/codeml2025/) was accepted to CodeML @ ICML 2025! +- [Jun 25] Our [TorchAO paper](https://openreview.net/attachment?id=HpqH0JakHf&name=pdf) was accepted to CodeML @ ICML 2025! +- [May 25] QAT is now integrated into [Axolotl](https://github.com/axolotl-ai-cloud/axolotl) for fine-tuning ([docs](https://docs.axolotl.ai/docs/qat.html))! - [Apr 25] Float8 rowwise training yielded [1.34-1.43x training speedup](https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/) at 2k H100 GPU scale - [Apr 25] TorchAO is added as a [quantization backend to vLLM](https://docs.vllm.ai/en/latest/features/quantization/torchao.html) ([docs](https://docs.vllm.ai/en/latest/features/quantization/torchao.html))! - [Mar 25] Our [2:4 Sparsity paper](https://openreview.net/pdf?id=O5feVk7p6Y) was accepted to SLLM @ ICLR 2025! @@ -59,7 +60,7 @@ Check out our [docs](https://docs.pytorch.org/ao/main/) for more details! From the team that brought you the fast series: * 9.5x inference speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) * 10x inference speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2) -* 3x inference speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3) +* 3x inference speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3) (new: [flux-fast](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/)) * 2.7x inference speedup for FAIR’s Seamless M4T-v2 model with [seamlessv2-fast](https://pytorch.org/blog/accelerating-generative-ai-4/) @@ -70,27 +71,10 @@ First, install TorchAO. We recommend installing the latest stable version: pip install torchao ``` -
- Other installation options - - ``` - # Nightly - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 - - # Different CUDA versions - pip install torchao --index-url https://download.pytorch.org/whl/cu126 # CUDA 12.6 - pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only - - # For developers - USE_CUDA=1 python setup.py develop - ``` - -
- Quantize your model weights to int4! ``` from torchao.quantization import Int4WeightOnlyConfig, quantize_ -quantize_(model, Int4WeightOnlyConfig(group_size=32)) +quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1)) ``` Compared to a `torch.compiled` bf16 baseline, your quantized model should be significantly smaller and faster on a single A100 GPU: ``` @@ -105,14 +89,41 @@ speedup: 6.9x For the full model setup and benchmark details, check out our [quick start guide](https://docs.pytorch.org/ao/stable/quick_start.html). Alternatively, try quantizing your favorite model using our [HuggingFace space](https://huggingface.co/spaces/pytorch/torchao-my-repo)! +## 🛠 Installation + +To install the latest stable version: +``` +pip install torchao +``` + +
+ Other installation options + + ``` + # Nightly + pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + + # Different CUDA versions + pip install torchao --index-url https://download.pytorch.org/whl/cu126 # CUDA 12.6 + pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only + + # For developers + USE_CUDA=1 python setup.py develop + USE_CPP=0 python setup.py develop + ``` +
+ +Please see the [torchao compability table](https://github.com/pytorch/ao/issues/2919) for version requirements for dependencies. + ## 🔗 Integrations TorchAO is integrated into some of the leading open-source libraries including: * HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865) * HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md) +* HuggingFace PEFT for LoRA using TorchAO as their [quantization backend](https://huggingface.co/docs/peft/en/developer_guides/quantization#torchao-pytorch-architecture-optimization) * Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) -* TorchTune for our [QLoRA](https://docs.pytorch.org/torchtune/main/tutorials/qlora_finetune.html), [QAT](https://docs.pytorch.org/torchtune/main/recipes/qat_distributed.html), and [float8 quantized fine-tuning](https://github.com/pytorch/torchtune/pull/2546) recipes +* TorchTune for our NF4 [QLoRA](https://docs.pytorch.org/torchtune/main/tutorials/qlora_finetune.html), [QAT](https://docs.pytorch.org/torchtune/main/recipes/qat_distributed.html), and [float8 quantized fine-tuning](https://github.com/pytorch/torchtune/pull/2546) recipes * TorchTitan for [float8 pre-training](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md) * VLLM for LLM serving: [usage](https://docs.vllm.ai/en/latest/features/quantization/torchao.html), [detailed docs](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html) * SGLang for LLM serving: [usage](https://docs.sglang.ai/backend/server_arguments.html#server-arguments) and the major [PR](https://github.com/sgl-project/sglang/pull/1341). @@ -133,7 +144,7 @@ Quantize any model with `nn.Linear` layers in just one line (Option 1), or load ```python from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig -quantize_(model, Int4WeightOnlyConfig(group_size=128, use_hqq=True)) +quantize_(model, Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1)) ``` #### Option 2: HuggingFace Integration @@ -143,12 +154,12 @@ from transformers import TorchAoConfig, AutoModelForCausalLM from torchao.quantization.quant_api import Int4WeightOnlyConfig # Create quantization configuration -quantization_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True)) +quantization_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1)) # Load and automatically quantize quantized_model = AutoModelForCausalLM.from_pretrained( "microsoft/Phi-4-mini-instruct", - torch_dtype="auto", + dtype="auto", device_map="auto", quantization_config=quantization_config ) @@ -169,12 +180,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup* Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/): ```python -from torchao.quantization import quantize_ -from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) -qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config), -quantize_(my_model, qat_config) +from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig +from torchao.quantization.qat import QATConfig + +# prepare +base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) +quantize_(my_model, QATConfig(base_config, step="prepare")) + +# train model (not shown) + +# convert +quantize_(my_model, QATConfig(base_config, step="convert")) ``` Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py). @@ -268,7 +284,7 @@ If you find the torchao library useful, please cite it in your work as below. @software{torchao, title={TorchAO: PyTorch-Native Training-to-Serving Model Optimization}, author={torchao}, - url={https://github.com/pytorch/torchao}, + url={https://github.com/pytorch/ao}, license={BSD-3-Clause}, month={oct}, year={2024} diff --git a/benchmarks/_models/eval_hf_models.py b/benchmarks/_models/eval_hf_models.py new file mode 100644 index 0000000000..3cd6887ab6 --- /dev/null +++ b/benchmarks/_models/eval_hf_models.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import itertools +import subprocess + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +from benchmarks.microbenchmarks.utils import string_to_config +from torchao.quantization import * # noqa: F401, F403 + + +def quantize_model_and_save(model_id, quant_config, output_dir="results"): + """Quantize the model and save it to the output directory.""" + print("Quantizing model with config: ", quant_config) + if quant_config is None: + quantization_config = None + else: + quantization_config = TorchAoConfig(quant_type=quant_config) + quantized_model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + dtype=torch.bfloat16, + quantization_config=quantization_config, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + quantized_model.save_pretrained(output_dir, safe_serialization=False) + tokenizer.save_pretrained(output_dir, safe_serialization=False) + return quantized_model, tokenizer + + +def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8): + """Run the lm_eval command using subprocess.""" + tasks_str = ",".join(tasks_list) + command = [ + "lm_eval", + "--model", + "hf", + "--model_args", + f"pretrained={model_dir}", + "--tasks", + f"{tasks_str}", + "--device", + f"{device}", + "--batch_size", + f"{batch_size}", + ] + subprocess.run(command, check=True) + + +def get_model_size_in_bytes(model, ignore_embeddings=False): + """ + Returns the model size in bytes. The option to ignore embeddings + is useful for models with disproportionately large embeddings compared + to other model parameters that get quantized/sparsified. + """ + + def flat_size(tensor): + if hasattr(tensor, "__tensor_flatten__"): + size = 0 + # 0th element is a list of attributes that + # hold tensors + for attr_name in tensor.__tensor_flatten__()[0]: + sub_tensor = getattr(tensor, attr_name) + size += flat_size(sub_tensor) + return size + else: + return tensor.numel() * tensor.element_size() + + model_size = 0 + for _, child in model.named_children(): + if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings): + for p in itertools.chain( + child.parameters(recurse=False), child.buffers(recurse=False) + ): + model_size += flat_size(p) + model_size += get_model_size_in_bytes(child, ignore_embeddings) + return model_size + + +def run( + model_id, + quantization, + tasks, + device, + batch_size, + model_output_dir, +): + print(f"Running model {model_id} with quantization {quantization}") + model_name = model_id.split("/")[-1] + model_output_dir = f"quantized_model/{model_name}-{quantization}" + quant_config = string_to_config(quantization, None) + quantized_model, tokenizer = quantize_model_and_save( + model_id, quant_config=quant_config, output_dir=model_output_dir + ) + print("Compiling model ....") + quantized_model = torch.compile( + quantized_model, + mode="reduce-overhead", + fullgraph=True, + ) + run_lm_eval( + model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size + ) + model_size = get_model_size_in_bytes(quantized_model, ignore_embeddings=True) / 1e9 + print(f"Model size: {model_size:.2f} GB") + + +if __name__ == "__main__": + try: + import lm_eval # noqa: F401 + except: + print( + "lm_eval is required to run this script. Please install it using pip install lm-eval." + ) + exit(0) + + # Set up argument parser + parser = argparse.ArgumentParser( + description="Quantize a model and evaluate its throughput." + ) + parser.add_argument( + "--model_id", + type=str, + default="meta-llama/Llama-3.1-8B", + help="The model ID to use.", + ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help="The quantization method to use.", + ) + parser.add_argument( + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", + ) + parser.add_argument( + "--device", type=str, default="cuda:0", help="Device to run the model on." + ) + parser.add_argument( + "--batch_size", type=str, default="auto", help="Batch size for lm_eval." + ) + parser.add_argument( + "--prompt", + type=str, + default="What are we having for dinner?", + help="Prompt for model throughput evaluation.", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=10, + help="Max new tokens to generate for throughput evaluation.", + ) + parser.add_argument( + "--num_runs", + type=int, + default=5, + help="Number of runs to average over for throughput evaluation.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="quantized_models", + help="Output directory for quantized model.", + ) + args = parser.parse_args() + + # Use parsed arguments + run( + model_id=args.model_id, + quantization=args.quantization, + tasks=args.tasks, + device=args.device, + batch_size=args.batch_size, + model_output_dir=args.output_dir, + ) diff --git a/benchmarks/_models/eval_hf_models.sh b/benchmarks/_models/eval_hf_models.sh new file mode 100644 index 0000000000..d71d16e422 --- /dev/null +++ b/benchmarks/_models/eval_hf_models.sh @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +# For llama3.1-8B +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-4 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-8 --tasks wikitext hellaswag + + +# For llama3.2-3B +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-4 --tasks wikitext hellaswag +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-8 --tasks wikitext hellaswag diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index cdc6f6fe5a..8eb6ddde11 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -10,56 +10,36 @@ import torch from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, _replace_with_custom_fn_if_matches_filter, - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, quantize_, ) from torchao.quantization.subclass import ( Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - unwrap_tensor_subclass, -) def _int8wo_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_woqtensors(mod, **kwargs) + quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False) def _int8da_int8w_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_( - mod, - int8_dynamic_activation_int8_weight(**kwargs), - set_inductor_config=False, - ) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_dqtensors(mod, **kwargs) + quantize_( + mod, + Int8DynamicActivationInt8WeightConfig(**kwargs), + set_inductor_config=False, + ) def _int4wo_api(mod, **kwargs): - if TORCH_VERSION_AT_LEAST_2_4: - kwargs_copy = kwargs.copy() - if "groupsize" in kwargs_copy: - kwargs_copy["group_size"] = kwargs_copy["groupsize"] - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int4_woqtensors(mod, **kwargs) + kwargs_copy = kwargs.copy() + if "groupsize" in kwargs_copy: + kwargs_copy["group_size"] = kwargs_copy["groupsize"] + del kwargs_copy["groupsize"] + quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False) class ToyLinearModel(torch.nn.Module): @@ -95,11 +75,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs """ from torchao.quantization.quant_api import ( _get_subclass_inserter, - _in_features_greater_than_16, _is_linear, ) from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight + def _in_features_greater_than_16(mod, *args): + return hasattr(mod, "in_features") and mod.in_features > 16 + if filter_fn is None: filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( *args @@ -195,13 +177,12 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): ) -if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available(): +if __name__ == "__main__" and torch.cuda.is_available(): all_shapes = [ (20, 2048, 2048), ] print("_int8da_int8w_api") - from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( @@ -209,7 +190,6 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): ) print("_int8wo_api") - from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( @@ -217,8 +197,7 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): ) print("_int4wo_api") - kwargs = {"groupsize": 32} - from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors + kwargs = {"groupsize": 32, "version": 1} for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py index 809202170a..ffdd63ec8d 100644 --- a/benchmarks/benchmark_blockwise_scaled_linear_triton.py +++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py @@ -13,7 +13,7 @@ from triton.testing import do_bench from torchao.float8.float8_utils import compute_error - from torchao.prototype.blockwise_fp8.blockwise_quantization import ( + from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_quant, diff --git a/benchmarks/dashboard/ci_microbenchmark_runner.py b/benchmarks/dashboard/ci_microbenchmark_runner.py new file mode 100644 index 0000000000..29971692ba --- /dev/null +++ b/benchmarks/dashboard/ci_microbenchmark_runner.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +CI Microbenchmark Runner for PyTorch OSS Benchmark Database + +This script runs microbenchmarks for a given config file +and outputs results in the format required by the PyTorch OSS benchmark database. +It reuses functionality from benchmark_runner.py and only adds CI-specific code. + +Usage: + python ci_microbenchmark_runner.py --config benchmark_config.yml + +The YAML file should contain all necessary configuration parameters for the benchmarks. +""" + +import argparse +import json +import platform +from typing import Any, Dict, List + +import torch + +from benchmarks.microbenchmarks.benchmark_inference import run as run_inference +from benchmarks.microbenchmarks.benchmark_runner import ( + load_benchmark_configs, +) +from benchmarks.microbenchmarks.utils import clean_caches + + +def create_benchmark_result( + benchmark_name: str, + shape: List[int], + metric_name: str, + metric_values: List[float], + quant_type: str, + device: str, + torch_compile_mode: str, + metric_extra_info: Dict[str, Any] = {}, +) -> Dict[str, Any]: + """Create a benchmark result in the PyTorch OSS benchmark database format. + + Args: + benchmark_name: Name of the benchmark + shape: List of shape dimensions [M, K, N] + metric_name: Name of the metric + metric_values: List of metric values + quant_type: Quantization type + device: Device type (cuda/cpu) + + Returns: + Dictionary containing the benchmark result in the required format + """ + print( + f"Creating benchmark result for {benchmark_name} with shape {shape} and metric {metric_name}" + ) + + # Map device to benchmark device name + benchmark_device = ( + torch.cuda.get_device_name(0) + if device == "cuda" + else platform.processor() + if device == "cpu" + else "unknown" + ) + + # Format shape as M-K-N + mkn_name = f"{shape[0]}-{shape[1]}-{shape[2]}" if len(shape) == 3 else "unknown" + + return { + "benchmark": { + "name": "micro-benchmark api", + "mode": "inference", + "dtype": quant_type, + "extra_info": { + "device": device, + "arch": benchmark_device, + "torch_compile_mode": torch_compile_mode, + }, + }, + "model": { + "name": mkn_name, # name in M-K-N format + "type": "micro-benchmark custom layer", # type + "origins": ["torchao"], + }, + "metric": { + "name": f"{metric_name}", # name with unit + "benchmark_values": metric_values, # benchmark_values + "target_value": 0.0, # TODO: Will need to define the target value + "extra_info": { + **metric_extra_info, + }, + }, + "runners": [], + "dependencies": {}, + } + + +def run_ci_benchmarks(config_path: str) -> List[Dict[str, Any]]: + """Run benchmarks using configurations from YAML file and return results in OSS format. + + Args: + config_path: Path to the benchmark configuration file + + Returns: + List of benchmark results in the PyTorch OSS benchmark database format + """ + # Load configuration using existing function + configs = load_benchmark_configs(argparse.Namespace(config=config_path)) + results = [] + + # Run benchmarks for each config + for config in configs: + # Run benchmark using existing function + clean_caches() + result = run_inference(config) + + if result is not None: + # Create benchmark result in OSS format + speedup_result = create_benchmark_result( + benchmark_name="TorchAO Quantization Benchmark", + shape=[config.m, config.k, config.n], + metric_name="Fwd Speedup (x)", + metric_values=[result.compile_speedup_on_baseline], + quant_type=config.quantization, + device=config.device, + torch_compile_mode=config.torch_compile_mode, + ) + results.append(speedup_result) + baseline_time_result = create_benchmark_result( + benchmark_name="TorchAO Quantization Benchmark", + shape=[config.m, config.k, config.n], + metric_name="Bfloat16 Fwd Time (ms)", + metric_values=[result.baseline_model_compiled_inference_time_in_ms], + quant_type=config.quantization, + device=config.device, + torch_compile_mode=config.torch_compile_mode, + metric_extra_info={ + "unit": "ms", + }, + ) + results.append(baseline_time_result) + quantize_time_result = create_benchmark_result( + benchmark_name="TorchAO Quantization Benchmark", + shape=[config.m, config.k, config.n], + metric_name="Quantized Fwd Time (ms)", + metric_values=[result.quantized_model_compiled_inference_time_in_ms], + quant_type=config.quantization, + device=config.device, + torch_compile_mode=config.torch_compile_mode, + metric_extra_info={ + "unit": "ms", + }, + ) + results.append(quantize_time_result) + allocated_memory_result = create_benchmark_result( + benchmark_name="TorchAO Quantization Benchmark", + shape=[config.m, config.k, config.n], + metric_name="Allocated Memory (MB)", + metric_values=[result.memory_stats["allocated_bytes.all.peak"]], + quant_type=config.quantization, + device=config.device, + torch_compile_mode=config.torch_compile_mode, + metric_extra_info={ + "unit": "MB", + }, + ) + results.append(allocated_memory_result) + + return results + + +def main(): + torch.manual_seed(42) + parser = argparse.ArgumentParser( + description="Run microbenchmarks and output results in PyTorch OSS benchmark database format" + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to benchmark configuration file", + ) + parser.add_argument( + "--output", + type=str, + default="benchmark_results.json", + help="Path to output JSON file", + ) + args = parser.parse_args() + + # Run benchmarks + results = run_ci_benchmarks(args.config) + + # Save results to JSON file + with open(args.output, "w") as f: + json.dump(results, f, indent=2) + + print(f"Benchmark results saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dashboard/microbenchmark_quantization_config.yml b/benchmarks/dashboard/microbenchmark_quantization_config.yml new file mode 100644 index 0000000000..8156422668 --- /dev/null +++ b/benchmarks/dashboard/microbenchmark_quantization_config.yml @@ -0,0 +1,20 @@ +# Benchmark configuration for microbenchmarks +benchmark_mode: "inference" +quantization_config_recipe_names: # Will run a baseline inference for model by default, without quantization for comparison + - "int8wo" + - "int8dq" + - "float8dq-tensor" + - "float8dq-row" + - "float8wo" +output_dir: "benchmarks/microbenchmarks/results" +model_params: + - name: "small_bf16_linear" + matrix_shapes: + - name: "small_sweep" + min_power: 10 + max_power: 15 + high_precision_dtype: "torch.bfloat16" + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "linear" + enable_memory_profiler: true diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index a7b1e17934..6d55bcc173 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -23,7 +23,7 @@ ScalingType, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_tensor import ScaledMMConfig +from torchao.float8.float8_training_tensor import ScaledMMConfig # estimating TOPs for matmuls in fp32, fp16, fp8 # assuming A * B = C, with A being M * K, B being K * N, C being M * N diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index 30ea2eab39..c6499e692d 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -10,45 +10,15 @@ import pandas as pd import torch import torch.nn as nn -import torch.utils.benchmark as benchmark from utils import ( - get_gpu_kernel_gemm_time_s, + do_benchmarks, get_name_to_shapes_iter, ) from torchao.ops import mx_fp4_bf16 from torchao.prototype.mx_formats.mx_tensor import to_mx from torchao.testing.training.roofline_utils import get_specs - - -def benchmark_fn_in_sec(f, *args, **kwargs): - # Manual warmup - for _ in range(4): - f(*args, **kwargs) - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - measurement = t0.blocked_autorange() - return measurement.mean - - -def do_benchmarks( - tops, - peak_tops, - use_gpu_kernel_time, - f, - *args, - **kwargs, -): - if use_gpu_kernel_time: - # just the gemm GPU kernel - time_sec = get_gpu_kernel_gemm_time_s(f, *args, **kwargs) - else: - # e2e time including kernel launch overhead - time_sec = benchmark_fn_in_sec(f, *args, **kwargs) - tops_sec = float(tops) / time_sec - pct_top_peak = tops_sec / peak_tops - return time_sec, tops_sec, pct_top_peak +from torchao.utils import is_MI300 @torch.inference_mode() @@ -76,7 +46,8 @@ def run( specs = get_specs() bf16_peak_tops = specs["bf16_peak_tops"] fp8_peak_tops = specs["fp8_peak_tops"] - fp4_peak_tops = specs["fp4_peak_tops"] + fp4_peak_tops = specs.get("fp4_peak_tops", 0.0) # only on sm120 + print(f"recipe: {recipe}") print(f"gpu_name: {torch.cuda.get_device_name(0)}") print( f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}" @@ -87,8 +58,8 @@ def run( "M", "K", "N", + "ref_time_s", "time_s", - "speedup", "fp8_speedup", ) results = [] @@ -137,7 +108,10 @@ def run( else: # raw float8 matmul (upper bound for what we can achive in eager mode) # TODO(future): add e5m2 - d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype + e4m3_dtype = torch.float8_e4m3fn + if torch.version.hip and torch.cuda.is_available() and is_MI300(): + e4m3_dtype = torch.float8_e4m3fnuz + d1, d2, d3 = e4m3_dtype, e4m3_dtype, dtype A = A_hp.to(d1) B = B_hp_t.to(d2).contiguous().T peak_tops = fp8_peak_tops @@ -175,6 +149,16 @@ def do_matmul_nvfp4(A, B): nonlocal scale_b return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype) + def do_grouped_mm(A, B): + return torch._grouped_mm(A, B, use_fast_accum=fast_accum) + + def do_scaled_grouped_mm(A, B): + nonlocal scale_a + nonlocal scale_b + return torch._scaled_grouped_mm( + A, B, scale_a, scale_b, use_fast_accum=fast_accum + ) + if recipe == "mxfp4_cutlass": do_matmul = do_matmul_mxfp4 elif recipe == "nvfp4": diff --git a/benchmarks/float8/bench_padding.py b/benchmarks/float8/bench_padding.py index eed8a5b542..62a161637b 100644 --- a/benchmarks/float8/bench_padding.py +++ b/benchmarks/float8/bench_padding.py @@ -12,7 +12,7 @@ from torch._inductor.utils import do_bench_using_profiling from tqdm import tqdm -from torchao.float8.float8_tensor import ( +from torchao.float8.float8_training_tensor import ( GemmInputRole, LinearMMConfig, ScaledMMConfig, diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py new file mode 100644 index 0000000000..121b9fc7d3 --- /dev/null +++ b/benchmarks/float8/float8_inference_roofline.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +This is a script to estimate the benefit from converting a `torch.nn.Linear` +layer to float8 given a single saturated GPU, by estimating the difference +in e2e GPU kernel time between: +1. bf16 gemms in fwd and +2. float8 gemms in fwd and float8 overhead + +The gemm times are estimated either from direct measurements via benchmarks, +or with a roofline estimation based on TOPS and peak compute bandwidth of an +NVIDIA H100 or B200. + +The float8 overhead times are estimated by counting memory reads and writes +based on the specified float8 scaling, and estimating that we can achieve +a certain % of machine peak memory bandwidth when performing these reads and writes. +""" + +import copy +from typing import Optional + +import fire +import pandas as pd +import sympy +import torch +import torch.nn as nn +import tqdm +from torch.profiler import ProfilerActivity, profile +from utils import ( + get_gpu_kernel_gemm_time_s, + get_name_to_shapes_iter, + profiler_output_to_filtered_time_by_kernel_name, +) + +import torchao +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + PerRow, + quantize_, +) +from torchao.quantization.quantize_.common import KernelPreference +from torchao.testing.training.roofline_utils import ( + get_inference_float8_mem_sympy, + get_inference_gemm_time_sympy, +) +from torchao.utils import is_MI300 + + +@torch.no_grad() +def get_gpu_kernel_time(m, x): + # warm up + for _ in range(2): + __ = m(x) + + # capture a profiling run + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + n_iter = 5 + with profile(activities=activities) as prof: + for _ in range(n_iter): + __ = m(x) + torch.cuda.synchronize() + # get the gpu kernel time and aggregate it + num_leaf_tensors = 1 + len(list(m.parameters())) + ref_times = profiler_output_to_filtered_time_by_kernel_name( + prof, n_iter, num_leaf_tensors + ) + total_time_s = sum(v for v in ref_times.values()) / 1e6 / n_iter + return total_time_s + + +def get_gemm_times( + M: int, + K: int, + N: int, + fast_accum: bool, + float8_recipe_name: Optional[str], +): + device = torch.device("cuda") + + # bf16 time + x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) + # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() + w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) + + e4m3_dtype = torch.float8_e4m3fn + if torch.version.hip and torch.cuda.is_available() and is_MI300(): + e4m3_dtype = torch.float8_e4m3fnuz + d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 + A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) + B = ( + torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) + .view(d2) + .t() + .contiguous() + .t() + ) + if float8_recipe_name in ("rowwise"): + scale_a = torch.ones(M, 1, device=device) + scale_b = torch.ones(1, N, device=device) + else: + assert False, "unsupported" + + def do_matmul(A, B): + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum + ) + + f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) + + return bf16_time_s, f8_time_s + + +def run( + outfile: str, + do_benchmarks: bool = True, + shape_gen_name: str = "pow2", + n_limit: Optional[int] = None, + float8_recipe_name: Optional[str] = None, +): + """ + Args: + * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked + * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep` + * `n_limit (optional)`: if specified, only runs `n_limit` iterations + """ + + assert float8_recipe_name is not None, "unsupported" + + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"torch version: {torch.__version__}") + print(f"torchao version: {torchao.__version__}") + print(f"do_benchmarks: {do_benchmarks}") + print(f"shape_gen_name: {shape_gen_name}") + print(f"float8_recipe_name: {float8_recipe_name}") + + M, K, N = sympy.symbols("M K N") + + fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( + M, + K, + N, + float8_recipe_name, + ) + bf16_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.bfloat16, None, None + ) + fp8_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.float8_e4m3fn, float8_recipe_name, None + ) + print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) + print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) + print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) + print() + + headers = [ + "fwd_M", + "fwd_K", + "fwd_N", + # roofline - gemm time (fwd + bwd, 3 gemms) + "r_bf16_gemm_s", + "r_fp8_gemm_s", + # roofline - fp8 overhead time (by counting reads/writes in the ideal case) + "r_fp8_ovhd_s", + # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid) + "r_fp8_gemm_and_ovhd_s", + "r_fp8_gemm_and_ovhd_spdp", + # benchmarks - gemm time (fwd + bwd, 3 gemms) + "b_bf16_gemm_s", + "b_fp8_gemm_s", + # benchmarks - e2e LNLinearSigmoid time fwd + bwd + "b_bf16_e2e_s", + "b_fp8_e2e_s", + # note that e2e speedup is not the same as the roofline speedup: + # 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time) + # 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid) + # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple + # we don't break them out and don't have a roofline for them. + "b_fp8_e2e_spdp", + # how well benchmarked gemms match roofline predicted gemms + "rb_bf16_gemm_ratio", + "rb_fp8_gemm_ratio", + ] + results = [] + + name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None) + + for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)): + if n_limit is not None and idx >= n_limit: + break + + # use roofline model to estimate gemm time + # note: cast from sympy.core.numbers.Float to float to make pandas formatting work + r_bf16_gemm_time_s = float( + bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) + r_fp8_gemm_time_s = float( + fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) + + # if enabled, also measured observed gemm time + b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 + rb_bf16_gemm_ratio = -1 + rb_fp8_gemm_ratio = -1 + + if do_benchmarks: + # TODO(future): make the bf16 gemm times exactly match the e2e + # benchmarks, there is a slight deviation, probably related to gemm + # operand memory formats/transpositions below not exactly matching + # what PyTorch core is doing for `torch.mm` + # input @ weight_t = output + bf16_g1, f8_g1 = get_gemm_times( + M_val, + K_val, + N_val, + True, + float8_recipe_name, + ) + b_bf16_gemm_time_s = bf16_g1 + b_fp8_gemm_time_s = f8_g1 + rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s + rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s + + # note: cast from sympy.core.numbers.Float to float to make pandas formatting work + r_fp8_ovhd_time_s = float( + fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) + + b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0 + if do_benchmarks: + # create the model + m_orig = ( + nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16() + ) + x = torch.randn( + M_val, K_val, dtype=torch.bfloat16, device="cuda" + ).requires_grad_() + + # get the bf16 gpu kernel time + torch._dynamo.reset() + m_bf16 = torch.compile(copy.deepcopy(m_orig)) + b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x) + + # get the float8 dynamic scaling gpu kernel time + torch._dynamo.reset() + + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + # for now, use TORCH. In the future might be interesting + # to benchmark AUTO and FBGEMM. + kernel_preference=KernelPreference.TORCH, + ) + m_fp8_dyn = copy.deepcopy(m_orig) + quantize_(m_fp8_dyn, config) + + m_fp8_dyn = torch.compile(m_fp8_dyn) + b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x) + + results.append( + [ + M_val, + K_val, + N_val, + # roofline - gemm + r_bf16_gemm_time_s, + r_fp8_gemm_time_s, + # roofline - fp8 overhead + r_fp8_ovhd_time_s, + # roofline - gemm + overhead, and speedup + r_fp8_gemm_time_s + r_fp8_ovhd_time_s, + r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s), + # benchmarks - gemm + b_bf16_gemm_time_s, + b_fp8_gemm_time_s, + # benchmarks - e2e, and speedup + b_bf16_e2e_time_s, + b_fp8_e2e_time_s, + b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20), + # gemm ratios + rb_bf16_gemm_ratio, + rb_fp8_gemm_ratio, + ] + ) + + pd.set_option("display.precision", 2) + df = pd.DataFrame(results, columns=headers) + print(df) + df.to_csv(outfile) + print("done") + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 5a8419cde8..f37a932822 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -48,7 +48,6 @@ import sympy import torch import torch.nn as nn -import torch.utils.benchmark as benchmark import tqdm from torch.profiler import ProfilerActivity, profile from utils import ( @@ -57,6 +56,7 @@ profiler_output_to_filtered_time_by_kernel_name, ) +import torchao from torchao.float8 import ( Float8LinearConfig, convert_to_float8_training, @@ -67,6 +67,7 @@ get_float8_mem_sympy, get_gemm_time_sympy, ) +from torchao.utils import is_MI300 class LNLinearSigmoid(torch.nn.Module): @@ -83,20 +84,6 @@ def forward(self, x): return x -# TODO(next): hook this up - - -def benchmark_fn_in_sec(f, *args, **kwargs): - # Manual warmup - for _ in range(4): - f(*args, **kwargs) - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - measurement = t0.blocked_autorange() - return measurement.mean - - def get_gpu_kernel_time(m, x, grad_output): # warm up for _ in range(2): @@ -175,7 +162,12 @@ def get_gemm_times( if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight": f8_time_s = bf16_time_s else: - d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16 + e4m3_dtype = torch.float8_e4m3fn + if torch.version.hip and torch.cuda.is_available() and is_MI300(): + e4m3_dtype = torch.float8_e4m3fnuz + d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 + # TODO(future PR): create more realistic tensors here for more accurate + # gemm benchmarking A = torch.zeros(M, K, device=device, dtype=d1) B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() if float8_recipe_name == "tensorwise": @@ -184,7 +176,7 @@ def get_gemm_times( elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"): scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) - elif mx_recipe_name == "mxfp8_cublas": + elif mx_recipe_name in ("mxfp8_cublas", "mxfp8_cublas_rceil"): scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) else: @@ -232,6 +224,8 @@ def run( float8_recipe_name = "tensorwise" print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"torch version: {torch.__version__}") + print(f"torchao version: {torchao.__version__}") print(f"do_benchmarks: {do_benchmarks}") print(f"shape_gen_name: {shape_gen_name}") print(f"float8_recipe_name: {float8_recipe_name}") @@ -248,9 +242,11 @@ def run( mx_recipe_name, enable_fusion_modeling, ) - bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None) + bf16_gemm_time_sympy = get_gemm_time_sympy( + M, K, N, torch.bfloat16, None, None, None + ) fp8_gemm_time_sympy = get_gemm_time_sympy( - M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name + M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name, None ) print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) diff --git a/benchmarks/float8/training/torchtitan_benchmark.sh b/benchmarks/float8/training/llama3.sh similarity index 72% rename from benchmarks/float8/training/torchtitan_benchmark.sh rename to benchmarks/float8/training/llama3.sh index c1995ee39a..caab96662a 100755 --- a/benchmarks/float8/training/torchtitan_benchmark.sh +++ b/benchmarks/float8/training/llama3.sh @@ -17,9 +17,10 @@ LOG_FILE="/tmp/float8_training_log.txt" # validate user has specified torchtitan root directory if [ -z "${TORCHTITAN_ROOT}" ]; then echo "Error: TORCHTITAN environment variable is not set. Please set it before running this script." - echo "Usage: TORCHTITAN_ROOT= ./float8_training_benchmark.sh" + echo "Usage: TORCHTITAN_ROOT= ./llama3.sh" echo "Optional parameters configurable via environment variables:" echo " * FLOAT8_RECIPE_WITH_BEST_SETTINGS: "rowwise" or "tensorwise". if set, use float8 training in torchtitan with the specified recipe, including the additional settings which are optimal for that recipe. otherwise, use bf16 mixed precision training." + echo " * MX_RECIPE: any valid MX recipe name. Note: only one of FLOAT8_RECIPE_WITH_BEST_SETTINGS and MX_RECIPE can be set." echo " * LOCAL_BATCH_SIZE: defaults to 1." echo " * STEPS: defaults to 100." echo " * EXTRA_ARGS: additional arguments to pass to the torchtitan training script." @@ -27,12 +28,19 @@ if [ -z "${TORCHTITAN_ROOT}" ]; then fi # validate recipe name -if [ -n "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" ]; then +if [ -n "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" ] && [ -n "${MX_RECIPE}" ]; then + echo "Error: both FLOAT8_RECIPE_WITH_BEST_SETTINGS and MX_RECIPE are set, please only set one of them." >&2 + exit 1 +elif [ -n "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" ]; then if [ "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" == "tensorwise" ]; then FLOAT8_ARGS="--model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp" else FLOAT8_ARGS="--model.converters="float8" --float8.recipe_name=${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" fi +elif [ -n "${MX_RECIPE}" ]; then + FLOAT8_ARGS="--model.converters="mx" --mx.recipe_name=${MX_RECIPE}" +else + FLOAT8_ARGS="" fi @@ -45,13 +53,13 @@ cd ${TORCHTITAN_ROOT} echo "float8 args: ${FLOAT8_ARGS}" # run the command with the specified arguments -CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.local-batch-size=${LOCAL_BATCH_SIZE} --training.compile ${FLOAT8_ARGS} ${EXTRA_ARGS} 2>&1 | tee ${LOG_FILE} +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.local-batch-size=${LOCAL_BATCH_SIZE} --compile.enable ${FLOAT8_ARGS} ${EXTRA_ARGS} 2>&1 | tee ${LOG_FILE} # return to original working directory cd $original_dir # parse logs to calculate top line metrics -python parse_torchtitan_logs.py --log-file ${LOG_FILE} +python benchmarks/float8/training/parse_torchtitan_logs.py --log-file ${LOG_FILE} # clean up logs rm ${LOG_FILE} diff --git a/benchmarks/float8/training/llama4.sh b/benchmarks/float8/training/llama4.sh new file mode 100755 index 0000000000..216d1f918a --- /dev/null +++ b/benchmarks/float8/training/llama4.sh @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +#!/bin/bash +# This script can be used to launch a torchtitan float8 training run +# with the given parameters, + +# script arguments +LOCAL_BATCH_SIZE=${LOCAL_BATCH_SIZE:-1} +STEPS=${STEPS:-100} + +# temporary log file which is deleted after performance data is parsed out and metrics are calculated. +LOG_FILE="/tmp/float8_training_log.txt" + +# validate user has specified torchtitan root directory +if [ -z "${TORCHTITAN_ROOT}" ]; then + echo "Error: TORCHTITAN environment variable is not set. Please set it before running this script." + echo "Usage: TORCHTITAN_ROOT= ./torchtitan_llama4.sh" + echo " * EXTRA_ARGS: additional arguments to pass to the torchtitan training script." + exit 1 +fi + +# remember current directory to return to it later +original_dir=$(pwd) + +# navigate to torchtitan root dir +cd ${TORCHTITAN_ROOT} + +# run the command with the specified arguments +CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ${TORCHTITAN_ROOT}/run_train.sh ${EXTRA_ARGS} 2>&1 | tee ${LOG_FILE} + +# return to original working directory +cd $original_dir + +# parse logs to calculate top line metrics +python parse_torchtitan_logs.py --log-file ${LOG_FILE} + +# clean up logs +rm ${LOG_FILE} diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index 6c3051937d..55c9ad21a3 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -9,6 +9,7 @@ import re from typing import Optional +import torch.utils.benchmark as benchmark from torch.profiler import ProfilerActivity, profile @@ -211,6 +212,42 @@ def get_name_to_shapes_iter( raise AssertionError(f"unknown shape_gen_name {shape_gen_name}") +def get_name_to_moe_shapes_iter( + shape_gen_name: str, + M: Optional[int] = None, + K: Optional[int] = None, + N: Optional[int] = None, + E: Optional[int] = None, +): + M = 16640 if M is None else M + if shape_gen_name == "llama4_17bx16e": + # num_experts=16, dim=5120 + names_to_shapes = { + # M, K, N, E + "moe.experts.w1": (M, 5120, 8192, 16), + "moe.experts.w2": (M, 8192, 5120, 16), + } + return names_to_shapes.items() + elif shape_gen_name == "llama4_17bx128e": + # num_experts=128, dim=5120 + names_to_shapes = { + # M, K, N, E + "moe.experts.w1": (M, 5120, 4 * 5120, 128), + "moe.experts.w2": (M, 4 * 5120, 5120, 128), + } + return names_to_shapes.items() + elif shape_gen_name == "custom": + assert M is not None and K is not None and N is not None and E is not None, ( + "M, K, N, E must be specified for custom shape_gen" + ) + name_to_shapes = { + 1: (M, K, N, E), + } + return name_to_shapes.items() + + raise AssertionError(f"unknown shape_gen_name {shape_gen_name}") + + # copy-pasta from https://github.com/vkuzo/pytorch_scripts/blob/main/add_inductor_metadata_to_perf_trace.py def update_triton_kernels_in_prof_chome_trace_with_torch_logs( perf_trace_file: str, @@ -351,7 +388,43 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs): prof, n_iter, num_leaf_tensors=0 ) # there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds - assert len(data) == 1 + assert len(data) == 1, f"unexpected data: {data}" key, value = next(iter(data.items())) - assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16") + assert key in ( + "aten::mm", + "aten::_scaled_mm", + "torchao::mx_fp4_bf16", + "aten::_grouped_mm", + "aten::_scaled_grouped_mm", + ) return value / 1e6 / n_iter + + +def benchmark_fn_in_sec(f, *args, **kwargs): + # Manual warmup + for _ in range(4): + f(*args, **kwargs) + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + measurement = t0.blocked_autorange() + return measurement.mean + + +def do_benchmarks( + tops, + peak_tops, + use_gpu_kernel_time, + f, + *args, + **kwargs, +): + if use_gpu_kernel_time: + # just the gemm GPU kernel + time_sec = get_gpu_kernel_gemm_time_s(f, *args, **kwargs) + else: + # e2e time including kernel launch overhead + time_sec = benchmark_fn_in_sec(f, *args, **kwargs) + tops_sec = float(tops) / time_sec + pct_top_peak = tops_sec / peak_tops + return time_sec, tops_sec, pct_top_peak diff --git a/benchmarks/inference/bench_float8_inference.py b/benchmarks/inference/bench_float8_inference.py new file mode 100644 index 0000000000..593e2425d7 --- /dev/null +++ b/benchmarks/inference/bench_float8_inference.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import fire +import torch +import torch.nn as nn +from torch._inductor.utils import do_bench_using_profiling + +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + PerRow, + quantize_, +) + + +def benchmark_fn_in_usec(f, *args, **kwargs): + no_args = lambda: f(*args, **kwargs) + time = do_bench_using_profiling(no_args) + return time * 1e3 + + +def run(torch_compile_mode: str = "default"): + M, K, N = 1024, 2048, 4096 + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + m = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)) + quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + m = torch.compile(m, mode=torch_compile_mode) + # warm up + with torch.no_grad(): + _ = m(x) + # measure + with torch.no_grad(): + time_us = benchmark_fn_in_usec(m, x) + print("time_us", time_us) + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index f300bbab23..eb9564d7d7 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -71,6 +71,7 @@ Currently, quantization string is in same format as the one being passed in llam - `int8wo`: 8-bit weight-only quantization - `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size - `int4wo-{group_size}-hqq`: 4-bit weight-only quantization with HQQ +- `gemlitewo-{bit_width}-{group_size}`: 4 or 8 bit integer quantization and utilizes the gemlite triton kernel ### Model Types - `linear`: Simple linear layer diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index 4ea5d05105..4a6525d52d 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -13,6 +13,7 @@ import os from copy import deepcopy from pathlib import Path +from typing import Dict, Tuple import torch @@ -34,15 +35,72 @@ create_model_and_input_data, ) +# ----------------------------------------------------------------------------- +# Baseline caching +# +# ``_BASELINE_CACHE`` maps a unique key constructed using _make_cache_key(config) -> (model_type, m, k, n, high_precision_dtype, device, torch_compile_mode) to a tuple +# ``(eager_baseline_time, compile_baseline_time)``. See ``_make_cache_key`` for the key +# construction. Users should not access this cache directly; it is +# internal to this module. +# Eg: (linear, 1024, 1024, 1024, torch.bfloat16, cuda, default) -> (95.00, 56.00) +# The cache is used to store the baseline inference time for a given configuration, which is further used to calculate speedup metrics. +# This helps in removing multiple baseline calculations, which in turn helps in reducing the benchmarking time. +# ----------------------------------------------------------------------------- + +_BASELINE_CACHE: Dict[Tuple, Tuple[float, float]] = {} + + +def _make_cache_key(config: BenchmarkConfig) -> Tuple: + """Create a key for caching based on benchmark configuration. + + Parameters that affect baseline performance are included: + + * model type (e.g. ``linear`` or ``transformer_block``) + * shape dimensions (m, k, n) + * high precision dtype (bf16, fp16, etc.) + * device (cuda, cpu, mps) + * compile settings (whether compile is enabled and compile mode) + + Sparsity and quantization settings are deliberately excluded + because the baseline (non‑quantized, non‑sparse) performance is + independent of those attributes. + """ + return ( + config.model_type, + config.m, + config.k, + config.n, + config.high_precision_dtype, + config.device, + config.torch_compile_mode, + ) + def run(config: BenchmarkConfig) -> BenchmarkResult: - """Run inference benchmarks""" + """ + Run inference benchmarks. + + The function first checks if a baseline for the given configuration + already exists in the internal cache. If not, it measures the baseline + inference time and stores the result. When the baseline is cached, + the function reuses the cached baselines to calculate speedup metrics. + + Args: + config (BenchmarkConfig): Benchmark configuration. + + Returns: + BenchmarkResult: Result of the benchmark. + """ try: clean_caches() # Clean caches # Create output directory if it doesn't exist Path(config.output_dir).mkdir(parents=True, exist_ok=True) + # Prepare result container + result = BenchmarkResult(config=config) + + # Create model and input data base_model, input_data = create_model_and_input_data( config.model_type, config.m, @@ -51,28 +109,47 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: high_precision_dtype=config.high_precision_dtype, device=config.device, ) - # Copy base model for quantizing - m_copy = deepcopy(base_model) - # Run benchmarks - result = BenchmarkResult(config=config) + # Generate a cache key for the current configuration + cache_key = _make_cache_key(config) - # Store result in model for memory profiling - base_model._benchmark_result = result - - # Run baseline benchmarking - base_model = base_model.eval().to(config.device) - if config.use_torch_compile: - print("Compiling baseline model....") - base_model = torch.compile( - base_model, mode=config.torch_compile_mode, fullgraph=True + # Check if the baseline for this configuration has been computed + if cache_key not in _BASELINE_CACHE: + # Switch model to eval and move to device + m_copy = deepcopy(base_model) + m_copy = m_copy.eval().to(config.device) + print("Benchmarking eager baseline inference.....") + eager_baseline_time = model_inference_time_in_ms( + model=m_copy, input_data=input_data + ) + + print("Benchmarking compile baseline inference.....") + m_copy = torch.compile( + m_copy, mode=config.torch_compile_mode, fullgraph=True + ) + compile_baseline_time = model_inference_time_in_ms( + model=m_copy, input_data=input_data ) - # Benchmark time to run an inference call for baseline model - print("Benchmarking baseline inference.....") - result.baseline_inference_time_in_ms = model_inference_time_in_ms( - model=base_model, input_data=input_data - ) + # Store uncompiled model, input and baseline time + _BASELINE_CACHE[cache_key] = (eager_baseline_time, compile_baseline_time) + + result.baseline_model_eager_inference_time_in_ms = eager_baseline_time + result.baseline_model_compiled_inference_time_in_ms = compile_baseline_time + else: + # Retrieve cached values + cached_eager_time, cached_compile_time = _BASELINE_CACHE[cache_key] + result.baseline_model_eager_inference_time_in_ms = cached_eager_time + result.baseline_model_compiled_inference_time_in_ms = cached_compile_time + + # At this point, ``base_model`` is an uncompiled model ready for quantization, + # and ``input_data`` is the corresponding input tensor. The baseline time + # has been stored in ``result.baseline_inference_time_in_ms``. + + # Copy base model for quantizing/sparsifying + m_copy = deepcopy(base_model) + + # Determine quantization/sparsity configuration ao_base_config = string_to_config( config.quantization, config.sparsity, @@ -101,24 +178,39 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: m_copy = m_copy.eval().to(config.device) quantize_(m_copy, ao_base_config) - if config.use_torch_compile: - print("Compiling quantized model....") - m_copy = torch.compile( - m_copy, mode=config.torch_compile_mode, fullgraph=True - ) - # Store result in model for memory profiling m_copy._benchmark_result = result - # Benchmark time to run an inference call for quantized model - print("Benchmarking quantized model.....") - result.model_inference_time_in_ms = model_inference_time_in_ms( + # Measure inference time for quantized model + print("Benchmarking eager quantized model.....") + result.quantized_model_eager_inference_time_in_ms = model_inference_time_in_ms( model=m_copy, input_data=input_data ) - # Calculate speedup w.r.t. baseline - result.speedup = round( - result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2 + # Measure inference time for compiled quantized model + print("Benchmarking quantized model.....") + m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True) + result.quantized_model_compiled_inference_time_in_ms = ( + model_inference_time_in_ms(model=m_copy, input_data=input_data) + ) + + # Compute eager speedup relative to baseline + result.eager_speedup_on_baseline = round( + result.baseline_model_eager_inference_time_in_ms + / result.quantized_model_eager_inference_time_in_ms, + ndigits=2, + ) + # Compute compile speedup relative to baseline + result.compile_speedup_on_baseline = round( + result.baseline_model_compiled_inference_time_in_ms + / result.quantized_model_compiled_inference_time_in_ms, + ndigits=2, + ) + # Compute compile speedup for quantized model relative to eager quantized model + result.compile_speedup_on_eager = round( + result.quantized_model_eager_inference_time_in_ms + / result.quantized_model_compiled_inference_time_in_ms, + ndigits=2, ) # Run profiler if enabled @@ -149,13 +241,15 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: os.makedirs(memory_profiler_dir, exist_ok=True) # Save memory profile with .pickle extension - result.memory_profile_path = generate_memory_profile( - model=m_copy, - input_data=input_data, - profile_file_path=os.path.join( - memory_profiler_dir, - f"{config._file_name}_memory_profile.pickle", - ), + result.memory_profile_path, result.memory_stats = ( + generate_memory_profile( + model=m_copy, + input_data=input_data, + profile_file_path=os.path.join( + memory_profiler_dir, + f"{config._file_name}_memory_profile.pickle", + ), + ) ) if result.memory_profile_path: @@ -163,9 +257,9 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: result.memory_profile_path ) except ValueError as e: - if "not enough values to unpack" in e: + if "not enough values to unpack" in str(e): print( - "Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists." + "Failed due to existing bugs, re‑run the code to generate memory profile. Please raise an issue if it persists." ) except Exception as e: print(f"Error running memory profiler: {e}") diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 8066b71714..45a0534ee0 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -139,9 +139,6 @@ def get_quantization_sparsity_recipes( """ config_recipes = set() - # Always include baseline without sparsity - config_recipes.add(("baseline", None)) - # Add all quantization techniques without sparsity for quant_config in quantization_recipes: config_recipes.add((quant_config, None)) diff --git a/benchmarks/microbenchmarks/profiler.py b/benchmarks/microbenchmarks/profiler.py index c226216871..1decb620ee 100644 --- a/benchmarks/microbenchmarks/profiler.py +++ b/benchmarks/microbenchmarks/profiler.py @@ -91,6 +91,7 @@ def generate_memory_profile(model, input_data, profile_file_path): # Create parent directory if it doesn't exist os.makedirs(os.path.dirname(profile_file_path), exist_ok=True) + memory_stats = dict() try: torch.cuda.empty_cache() @@ -130,11 +131,19 @@ def generate_memory_profile(model, input_data, profile_file_path): print(f"Attempt {i + 1}/5: {e}, retrying...") time.sleep(3.0) + # Record memory stats + _memory_stats = torch.cuda.memory_stats() + memory_stats = { + "allocated_bytes.all.peak": _memory_stats["allocated_bytes.all.peak"] / 1e6, + "active_bytes.all.peak": _memory_stats["active_bytes.all.peak"] / 1e6, + "reserved_bytes.all.peak": _memory_stats["reserved_bytes.all.peak"] / 1e6, + } + except Exception as e: print(f"Error in memory profiling: {e}") # Return the file path for consistency with other profiler functions - return profile_file_path + return profile_file_path, memory_stats def visualize_memory_profile(profile_file_path): diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 4fd5eb2018..40db49e223 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -13,7 +13,6 @@ model_params: min_power: 14 max_power: 16 high_precision_dtype: "torch.bfloat16" - use_torch_compile: true torch_compile_mode: "max-autotune" device: "cuda" model_type: "linear" @@ -27,7 +26,6 @@ model_params: [2048, 4096, 1024], ] high_precision_dtype: "torch.bfloat16" - use_torch_compile: true torch_compile_mode: "max-autotune" device: "cuda" model_type: "ln_linear_sigmoid" @@ -41,7 +39,6 @@ model_params: [2048, 4096, 1024], # For transformer_block, k is the hidden dimension ] high_precision_dtype: "torch.bfloat16" - use_torch_compile: true torch_compile_mode: "max-autotune" device: "cuda" model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition) @@ -58,7 +55,6 @@ model_params: min_power: 10 # 1024 max_power: 11 # 2048 high_precision_dtype: "torch.bfloat16" - use_torch_compile: true torch_compile_mode: "max-autotune" device: "cuda" model_type: "linear" diff --git a/benchmarks/microbenchmarks/test/test_benchmark_inference.py b/benchmarks/microbenchmarks/test/test_benchmark_inference.py index 22863dcbcf..f3e853866d 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_inference.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_inference.py @@ -21,7 +21,6 @@ def setUp(self): sparsity="semi-sparse", params={ "high_precision_dtype": "torch.float32", - "use_torch_compile": False, "device": "cpu", "model_type": "linear", }, @@ -46,7 +45,9 @@ def test_run_inference(self, mock_string_to_config): result = run(self.config) self.assertIsInstance(result, BenchmarkResult) - self.assertTrue(hasattr(result, "model_inference_time_in_ms")) + self.assertTrue( + hasattr(result, "quantized_model_compiled_inference_time_in_ms") + ) @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config") def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config): @@ -57,14 +58,14 @@ def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config): # Test with semi-sparse config mock_string_to_config.return_value = Int4WeightOnlyConfig( - layout=MarlinSparseLayout() + layout=MarlinSparseLayout(), + version=1, ) config = BenchmarkConfig( quantization="marlin", sparsity="semi-sparse", params={ "high_precision_dtype": "torch.float32", - "use_torch_compile": False, "device": "cpu", "model_type": "linear", }, @@ -75,7 +76,9 @@ def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config): ) result = run(config) self.assertIsInstance(result, BenchmarkResult) - self.assertTrue(hasattr(result, "model_inference_time_in_ms")) + self.assertTrue( + hasattr(result, "quantized_model_compiled_inference_time_in_ms") + ) @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config") def test_run_inference_with_block_sparsity(self, mock_string_to_config): @@ -92,7 +95,6 @@ def test_run_inference_with_block_sparsity(self, mock_string_to_config): sparsity="block", params={ "high_precision_dtype": "torch.float32", - "use_torch_compile": False, "device": "cpu", "model_type": "linear", }, @@ -103,7 +105,9 @@ def test_run_inference_with_block_sparsity(self, mock_string_to_config): ) result = run(config) self.assertIsInstance(result, BenchmarkResult) - self.assertTrue(hasattr(result, "model_inference_time_in_ms")) + self.assertTrue( + hasattr(result, "quantized_model_compiled_inference_time_in_ms") + ) if __name__ == "__main__": diff --git a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py index 7f904b5bd3..d0c36d8cfe 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py @@ -178,7 +178,7 @@ def test_memory_profiler_enabled(self): ) # Generate memory profile - result_path = generate_memory_profile( + result_path, memory_stats = generate_memory_profile( self.model, self.input_data, memory_profile_path ) @@ -270,13 +270,12 @@ def test_memory_profiler_cuda_unavailable(self): f"{config.name}_{self.m}_{self.k}_{self.n}_memory_profile.json", ) - # Generate memory profile - result = generate_memory_profile( - self.model, self.input_data, memory_profile_path - ) - # Should return None when CUDA is unavailable - self.assertIsNone(result) + self.assertIsNone( + generate_memory_profile( + self.model, self.input_data, memory_profile_path + ) + ) # Should not create file when CUDA is unavailable self.assertFalse(os.path.exists(memory_profile_path)) diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py index 2f7e5ba541..f7e54e4bec 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -39,7 +39,6 @@ def setUp(self): } ], "high_precision_dtype": "torch.bfloat16", - "use_torch_compile": True, "torch_compile_mode": "max-autotune", "device": "cpu", "model_type": "linear", @@ -130,7 +129,6 @@ def test_get_param_combinations(self): self.assertEqual(len(shapes), 1) self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024])) self.assertEqual(params["high_precision_dtype"], "torch.bfloat16") - self.assertEqual(params["use_torch_compile"], True) @patch("argparse.Namespace") def test_load_benchmark_configs(self, mock_args): diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 06f557a8f4..64af5b67e6 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -33,7 +33,6 @@ def setUp(self): self.test_params = { "name": "test_model", "high_precision_dtype": "torch.bfloat16", - "use_torch_compile": True, "torch_compile_mode": "max-autotune", "device": "cpu", "model_type": "linear", @@ -57,7 +56,6 @@ def test_benchmark_config(self): self.assertEqual(config.k, 1024) self.assertEqual(config.n, 1024) self.assertEqual(config.high_precision_dtype, torch.bfloat16) - self.assertEqual(config.use_torch_compile, True) self.assertEqual(config.torch_compile_mode, "max-autotune") self.assertEqual(config.device, "cpu") self.assertEqual(config.model_type, "linear") @@ -76,7 +74,7 @@ def test_benchmark_result(self): result = BenchmarkResult(config=config) self.assertEqual(result.config, config) - self.assertEqual(result.model_inference_time_in_ms, 0.0) + self.assertEqual(result.quantized_model_compiled_inference_time_in_ms, 0.0) def test_get_default_device(self): # Test CPU fallback diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index cbd864d6fe..d7300a6a81 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -18,6 +18,7 @@ Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, @@ -72,18 +73,13 @@ def __init__( self.high_precision_dtype = self._parse_precision( params.get("high_precision_dtype", "torch.bfloat16") ) - self.use_torch_compile = bool(params.get("use_torch_compile", False)) - self.torch_compile_mode = ( - params.get("torch_compile_mode", "default") - if self.use_torch_compile - else None - ) + self.torch_compile_mode = params.get("torch_compile_mode", "default") self.device = get_default_device(params.get("device", None)) self.model_type = params.get("model_type", "linear") self.output_dir = f"{output_dir}/{self.benchmark_mode}" self.name = params.get( "name", - f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.use_torch_compile else ''}", + f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile'}", ) self.enable_profiler = bool(params.get("enable_profiler", False)) self.enable_memory_profiler = bool(params.get("enable_memory_profiler", False)) @@ -107,7 +103,6 @@ def to_dict(self) -> Dict[str, Any]: "k": self.k, "n": self.n, "high_precision_dtype": self.high_precision_dtype, - "use_torch_compile": self.use_torch_compile, "torch_compile_mode": self.torch_compile_mode, "device": self.device, "model_type": self.model_type, @@ -124,9 +119,13 @@ def __init__( ): self.config = config self.output_dir = config.output_dir - self.baseline_inference_time_in_ms = 0.0 - self.model_inference_time_in_ms = 0.0 - self.speedup = 0.0 + self.baseline_model_eager_inference_time_in_ms = 0.0 + self.quantized_model_eager_inference_time_in_ms = 0.0 + self.baseline_model_compiled_inference_time_in_ms = 0.0 + self.quantized_model_compiled_inference_time_in_ms = 0.0 + self.eager_speedup_on_baseline = 0.0 + self.compile_speedup_on_baseline = 0.0 + self.compile_speedup_on_eager = 0.0 self.profiler_json_path: Optional[str] = None self.memory_profile_path: Optional[str] = None self.memory_visualization_path: Optional[str] = None @@ -136,9 +135,13 @@ def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for main function""" result_dict = { **self.config.to_dict(), - "baseline_inference_time_in_ms": self.baseline_inference_time_in_ms, - "model_inference_time_in_ms": self.model_inference_time_in_ms, - "speedup": self.speedup, + "baseline_model_eager_inference_time_in_ms": self.baseline_model_eager_inference_time_in_ms, + "quantized_model_eager_inference_time_in_ms": self.quantized_model_eager_inference_time_in_ms, + "baseline_model_compiled_inference_time_in_ms": self.baseline_model_compiled_inference_time_in_ms, + "quantized_model_compiled_inference_time_in_ms": self.quantized_model_compiled_inference_time_in_ms, + "eager speedup on baseline": self.eager_speedup_on_baseline, + "compile speedup on baseline": self.compile_speedup_on_baseline, + "eager vs compile speedup": self.compile_speedup_on_eager, "profiler_json_path": self.profiler_json_path, "memory_profile_path": self.memory_profile_path, "memory_visualization_path": self.memory_visualization_path, @@ -203,7 +206,7 @@ def string_to_config( 128, 256, ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" - return Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) + return Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1) elif "int8adq-int4w-symm" in quantization: from torchao.dtypes import CutlassInt4PackedLayout @@ -226,7 +229,7 @@ def string_to_config( elif sparsity is not None and ("semi" in sparsity or "2:4" in sparsity): from torchao.dtypes import MarlinSparseLayout - return Int4WeightOnlyConfig(layout=MarlinSparseLayout()) + return Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1) if "fp6" in quantization: return FPXWeightOnlyConfig(3, 2) elif "uintx" in quantization: @@ -257,7 +260,6 @@ def string_to_config( "int8_dynamic_activation_intx_weight requires using high_precision_dtype=torch.float32" ) - from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import ( Int8DynamicActivationIntxWeightConfig, @@ -275,8 +277,7 @@ def string_to_config( weight_mapping_type=MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC, - weight_scale_dtype=torch.bfloat16, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + intx_packing_format="opaque_torchao_auto", ) elif "float8wo" in quantization: return Float8WeightOnlyConfig() @@ -291,6 +292,23 @@ def string_to_config( else: granularity = PerTensor() return Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + if "gemlitewo" in quantization: + params = quantization.split("-") + bit_width = int(params[1]) if len(params) > 1 else 4 + group_size = ( + int(params[2]) + if len(params) > 2 and bit_width == 4 + else None + if bit_width == 8 + else 64 + ) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" + return GemliteUIntXWeightOnlyConfig(group_size=group_size, bit_width=bit_width) return None @@ -390,9 +408,13 @@ def print_results(results: List[BenchmarkResult]): result.config.quantization or "baseline", result.config.sparsity or "none", f"{result.config.shape_name} ({result.config.m}, {result.config.k}, {result.config.n})", - f"{result.baseline_inference_time_in_ms:.2f}", - f"{result.model_inference_time_in_ms:.2f}", - f"{result.speedup:.2f}x", + f"{result.baseline_model_eager_inference_time_in_ms:.2f}", + f"{result.quantized_model_eager_inference_time_in_ms:.2f}", + f"{result.eager_speedup_on_baseline:.2f}x", + f"{result.baseline_model_compiled_inference_time_in_ms:.2f}", + f"{result.quantized_model_compiled_inference_time_in_ms:.2f}", + f"{result.compile_speedup_on_baseline:.2f}x", + f"{result.compile_speedup_on_eager:.2f}x", str(result.config.enable_profiler), ] @@ -404,9 +426,13 @@ def print_results(results: List[BenchmarkResult]): "Quantization", "Sparsity", "Shape", - "Baseline Inference Time (ms)", - "Inference Time (ms)", - "Speedup", + "Eager Baseline Inference Time (ms)", + "Eager Model Inference Time (ms)", + "Eager Speedup", + "Compile Baseline Inference Time (ms)", + "Compile Model Inference Time (ms)", + "Compile Speedup", + "Eager vs Compile Speedup", "Profiler Enabled", ] diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index 56fbaf1c01..a9d8b18ae7 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -4,13 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Tuple +from typing import Tuple import fire import torch import triton -from torch._inductor.utils import do_bench_using_profiling +from triton.testing import do_bench +from torchao.prototype.mx_formats.config import ScaleCalculationMode from torchao.prototype.mx_formats.kernels import ( triton_to_mxfp8_dim1, ) @@ -53,22 +54,29 @@ def scale_dim0_dim1_reference( return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1 -def to_mx_dim0_reference(x_hp, block_size): - scale_d0, data_d0 = to_mx(x_hp, torch.float8_e4m3fn, block_size) +def to_mx_dim0_reference( + x_hp, + block_size, + scaling_mode=ScaleCalculationMode.FLOOR, + target_dtype=torch.float8_e4m3fn, +): + scale_d0, data_d0 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode) return data_d0, scale_d0 -def to_mx_dim1_reference(x_hp, block_size): +def to_mx_dim1_reference( + x_hp, + block_size, + scaling_mode=ScaleCalculationMode.FLOOR, + target_dtype=torch.float8_e4m3fn, +): x_hp = x_hp.t().contiguous() - scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size) + scale_d1, data_d1 = to_mx(x_hp, target_dtype, block_size, scaling_mode=scaling_mode) return data_d1.t(), scale_d1 -def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: - """Thin wrapper around do_bench_using_profiling""" - no_args = lambda: func(*args, **kwargs) - time = do_bench_using_profiling(no_args) - return time * 1e3 +def benchmark_cuda_function_in_microseconds(f, *args): + return do_bench(lambda: f(*args), return_mode="median") * 1e3 def run( @@ -82,7 +90,19 @@ def run( print(f"torch version: {torch.__version__}") print(f"triton version: {triton.__version__}") print(f"mode: {mode}") - assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton") + assert mode in ( + "dim0", + "dim1", + "dim0_dim1", + "dim0_mxfp8_floor", + "dim0_mxfp4_floor", + "dim0_mxfp8_rceil", + "dim1_mxfp8_floor", + "dim1_mxfp8_rceil", + "dim1_mxfp8_triton_floor", + "dim1_mxfp8_cuda_floor", + "dim1_mxfp8_cuda_rceil", + ) x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 @@ -141,7 +161,7 @@ def run( ) bps = bytes_rw / (time_us / 1e6) - elif mode == "dim0_mx": + elif mode == "dim0_mxfp8_floor": to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference) y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE) @@ -159,7 +179,50 @@ def run( bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) - elif mode == "dim1_mx": + elif mode == "dim0_mxfp4_floor": + to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference) + y_d0, s_d0 = to_mx_dim0_reference_c( + x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2 + ) + + for _ in range(2): + __ = to_mx_dim0_reference_c( + x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2 + ) + time_us = benchmark_cuda_function_in_microseconds( + lambda x, b: to_mx_dim0_reference_c( + x, BLOCK_SIZE, target_dtype=torch.float4_e2m1fn_x2 + ), + x, + BLOCK_SIZE, + ) + + # TODO(future PR): make to_mx return float4 directly + assert y_d0.dtype == torch.uint8 + assert s_d0.dtype == torch.float8_e8m0fnu + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + + elif mode == "dim0_mxfp8_rceil": + to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference) + y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL) + + for _ in range(2): + __ = to_mx_dim0_reference_c(x, BLOCK_SIZE) + time_us = benchmark_cuda_function_in_microseconds( + lambda x, b: to_mx_dim0_reference_c(x, BLOCK_SIZE), + x, + BLOCK_SIZE, + ) + + assert y_d0.dtype == torch.float8_e4m3fn + assert s_d0.dtype == torch.float8_e8m0fnu + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + + elif mode == "dim1_mxfp8_floor": to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference) y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE) @@ -177,7 +240,25 @@ def run( bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) - elif mode == "dim1_mx_triton": + elif mode == "dim1_mxfp8_rceil": + to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference) + y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL) + + for _ in range(2): + __ = to_mx_dim1_reference_c(x, BLOCK_SIZE) + time_us = benchmark_cuda_function_in_microseconds( + lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE), + x, + BLOCK_SIZE, + ) + + assert y_d1.dtype == torch.float8_e4m3fn + assert s_d1.dtype == torch.float8_e8m0fnu + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + + elif mode == "dim1_mxfp8_triton_floor": y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE) for _ in range(2): @@ -194,6 +275,58 @@ def run( bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) + elif mode == "dim1_mxfp8_cuda_floor": + from torchao.prototype import mxfp8_cuda + + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="floor" + ) + + for _ in range(2): + __ = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="floor" + ) + + time_us = benchmark_cuda_function_in_microseconds( + lambda x: mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="floor" + ), + x, + ) + + assert y_d1.dtype == torch.float8_e4m3fn + assert s_d1.dtype == torch.float8_e8m0fnu + + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + + elif mode == "dim1_mxfp8_cuda_rceil": + from torchao.prototype import mxfp8_cuda + + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="rceil" + ) + + for _ in range(2): + __ = mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="rceil" + ) + + time_us = benchmark_cuda_function_in_microseconds( + lambda x: mxfp8_cuda.quantize( + x, rowwise=False, colwise=True, scaling_mode="rceil" + ), + x, + ) + + assert y_d1.dtype == torch.float8_e4m3fn + assert s_d1.dtype == torch.float8_e8m0fnu + + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + else: raise AssertionError(f"unknown mode {mode}") diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py new file mode 100644 index 0000000000..14f47fe5f5 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_act_quant_lhs, + triton_fp8_blockwise_weight_quant_transposed_rhs, + triton_fp8_gemm_1x128_128x128, +) + +device = torch.device("cuda") + +# This benchmark requires CUDA 12.9+ +assert torch.version.cuda is not None, "CUDA is not available" +cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")) +assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required" + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + out_dtype: torch.dtype + m: int + n: int + k: int + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_mm_us: float + fp8_triton_us: float + fp8_scaled_mm_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + mnk_list = [ + # Llama4 shapes + (16640, 5120, 8192), + (16640, 8192, 5120), + ] + out_dtypes = [torch.bfloat16] + configs = [] + for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): + m, n, k = mnk + configs.append( + ExperimentConfig( + out_dtype=out_dtype, + m=m, + n=n, + k=k, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + # Simulate `grad_input = grad_output @ weight` + M, N, K = config.m, config.n, config.k + A = torch.randn(M, K, dtype=config.out_dtype, device="cuda") + B = torch.randn(N, K, dtype=config.out_dtype, device="cuda") + A_q, A_s = triton_fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn) + B_t_q, B_t_s = triton_fp8_blockwise_weight_quant_transposed_rhs( + B, dtype=torch.float8_e4m3fn + ) + + def warmup(func, *args, **kwargs): + for _ in range(10): + func(*args, **kwargs) + + # Warmup then run bf16 torch.mm + warmup(torch.mm, A, B.t()) + + bf16_mm_us = benchmark_cuda_function_in_microseconds(torch.mm, A, B.t()) + + # Warm up then run triton bench + warmup( + triton_fp8_gemm_1x128_128x128, + A_q, + B_t_q, + 1.0 / A_s, + 1.0 / B_t_s, + out_dtype=config.out_dtype, + ) + + fp8_triton_us = benchmark_cuda_function_in_microseconds( + triton_fp8_gemm_1x128_128x128, + A_q, + B_t_q, + 1.0 / A_s, + 1.0 / B_t_s, + out_dtype=config.out_dtype, + ) + + # Warm up then run torch bench + # scaled_mm requires A_s and B_t_s be in column-major format + A_s = A_s.t().contiguous().t() + + warmup( + torch._scaled_mm, + A_q, + B_t_q, + 1.0 / A_s, + 1.0 / B_t_s, + out_dtype=config.out_dtype, + ) + + fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds( + torch._scaled_mm, + A_q, + B_t_q, + 1.0 / A_s, + 1.0 / B_t_s, + out_dtype=config.out_dtype, + ) + + return ExperimentResult( + bf16_mm_us=bf16_mm_us, + fp8_triton_us=fp8_triton_us, + fp8_scaled_mm_us=fp8_scaled_mm_us, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "M", + "N", + "K", + "out_dtype", + "bf16_mm_us", + "fp8_triton_us", + "fp8_scaled_mm_us", + "bf16 tflops/sec", + "triton tflops/sec", + "scaled_mm tflops/sec", + ] + rows = [] + for experiment in experiments: + m, n, k = experiment.config.m, experiment.config.n, experiment.config.k + flops = 2 * m * n * k + bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6) + triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6) + scaled_mm_tflops_per_sec = (flops / 1e12) / ( + experiment.result.fp8_scaled_mm_us / 1e6 + ) + rows.append( + [ + m, + n, + k, + experiment.config.out_dtype, + experiment.result.bf16_mm_us, + experiment.result.fp8_triton_us, + experiment.result.fp8_scaled_mm_us, + bf16_mm_tflops_per_sec, + triton_tflops_per_sec, + scaled_mm_tflops_per_sec, + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py new file mode 100644 index 0000000000..5d429db302 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_act_quant_rhs, + triton_fp8_blockwise_act_quant_transposed_lhs, + triton_fp8_gemm_1x128_128x1, +) + +device = torch.device("cuda") + +# This benchmark requires CUDA 12.9+ +assert torch.version.cuda is not None, "CUDA is not available" +cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")) +assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required" + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + out_dtype: torch.dtype + m: int + n: int + k: int + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_mm_us: float + fp8_triton_us: float + fp8_scaled_mm_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + mnk_list = [ + # Llama4 shapes + (16640, 5120, 8192), + (16640, 8192, 5120), + ] + out_dtypes = [torch.bfloat16] + configs = [] + for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): + m, n, k = mnk + configs.append( + ExperimentConfig( + out_dtype=out_dtype, + m=m, + n=n, + k=k, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + # Simulate `grad_weight = grad_output_t @ input` + M, N, K = config.m, config.n, config.k + A = torch.randn(M, N, dtype=config.out_dtype, device="cuda") + B = torch.randn(M, K, dtype=config.out_dtype, device="cuda") + A_t_q, A_t_s = triton_fp8_blockwise_act_quant_transposed_lhs( + A, dtype=torch.float8_e4m3fn + ) + B_q, B_s = triton_fp8_blockwise_act_quant_rhs(B, dtype=torch.float8_e4m3fn) + + def warmup(func, *args, **kwargs): + for _ in range(10): + func(*args, **kwargs) + + # Warmup then run bf16 torch.mm + warmup(torch.mm, A.t(), B) + + bf16_mm_us = benchmark_cuda_function_in_microseconds(torch.mm, A.t(), B) + + # Warm up then run triton bench + warmup( + triton_fp8_gemm_1x128_128x1, + A_t_q, + B_q, + 1.0 / A_t_s, + 1.0 / B_s, + out_dtype=config.out_dtype, + ) + + fp8_triton_us = benchmark_cuda_function_in_microseconds( + triton_fp8_gemm_1x128_128x1, + A_t_q, + B_q, + 1.0 / A_t_s, + 1.0 / B_s, + out_dtype=config.out_dtype, + ) + + # Warm up then run torch bench + warmup( + torch._scaled_mm, + A_t_q, + B_q, + 1.0 / A_t_s, + 1.0 / B_s, + out_dtype=config.out_dtype, + ) + + fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds( + torch._scaled_mm, + A_t_q, + B_q, + 1.0 / A_t_s, + 1.0 / B_s, + out_dtype=config.out_dtype, + ) + + return ExperimentResult( + bf16_mm_us=bf16_mm_us, + fp8_triton_us=fp8_triton_us, + fp8_scaled_mm_us=fp8_scaled_mm_us, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "M", + "N", + "K", + "out_dtype", + "bf16_mm_us", + "fp8_triton_us", + "fp8_scaled_mm_us", + "bf16 tflops/sec", + "triton tflops/sec", + "scaled_mm tflops/sec", + ] + rows = [] + for experiment in experiments: + m, n, k = experiment.config.m, experiment.config.n, experiment.config.k + flops = 2 * m * n * k + bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6) + triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6) + scaled_mm_tflops_per_sec = (flops / 1e12) / ( + experiment.result.fp8_scaled_mm_us / 1e6 + ) + rows.append( + [ + m, + n, + k, + experiment.config.out_dtype, + experiment.result.bf16_mm_us, + experiment.result.fp8_triton_us, + experiment.result.fp8_scaled_mm_us, + bf16_mm_tflops_per_sec, + triton_tflops_per_sec, + scaled_mm_tflops_per_sec, + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py b/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py new file mode 100644 index 0000000000..7aefb9b546 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import argparse +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from triton.testing import do_bench + +from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd +from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear + +device = torch.device("cuda") + +# This benchmark requires CUDA 12.9+ +assert torch.version.cuda is not None, "CUDA is not available" +cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")) +assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required" + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + out_dtype: torch.dtype + m: int + n: int + k: int + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_linear_us: float + fp8_triton_linear_us: float + fp8_scaled_mm_linear_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + mnk_list = [ + # Llama4 shapes + (16640, 5120, 8192), + (16640, 8192, 5120), + ] + out_dtypes = [torch.bfloat16] + configs = [] + for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): + m, n, k = mnk + configs.append( + ExperimentConfig( + out_dtype=out_dtype, + m=m, + n=n, + k=k, + ) + ) + return configs + + +def run_experiment( + config: ExperimentConfig, profile=False, use_compile=False +) -> ExperimentResult: + M, N, K = config.m, config.n, config.k + inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda") + bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda") + fp8_triton_linear = Float8BlockwiseLinear( + K, N, dtype=config.out_dtype, device="cuda", use_triton=True + ) + fp8_scaled_mm_linear = Float8BlockwiseLinear( + K, N, dtype=config.out_dtype, device="cuda", use_triton=False + ) + + def warmup(func, *args, **kwargs): + for _ in range(3): + func(*args, **kwargs) + + # bfloat16 bench and profile + labels = inputs.new_empty(M, N).fill_(1.0) + bf16_linear_us = bench_fwd_bwd_microseconds( + bf16_linear, + inputs, + labels=labels, + use_compile=use_compile, + ) + if profile: + print("Profiling bf16_linear") + profile_fwd_bwd( + bf16_linear, + inputs, + labels=labels, + profile_name="bf16_linear_profile", + use_compile=use_compile, + ) + + # FP8 triton bench and profile + fp8_triton_linear_us = bench_fwd_bwd_microseconds( + fp8_triton_linear, + inputs, + labels=labels, + ) + if profile: + print("Profiling fp8_triton_linear") + profile_fwd_bwd( + fp8_triton_linear, + inputs, + labels=labels, + profile_name="fp8_triton_linear_profile", + ) + + # FP8 torch._scaled_mm bench and profile + fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds( + fp8_scaled_mm_linear, + inputs, + labels=labels, + use_compile=use_compile, + ) + if profile: + print("Profiling fp8_scaled_mm_linear") + profile_fwd_bwd( + fp8_scaled_mm_linear, + inputs, + labels=labels, + profile_name="fp8_scaled_mm_linear_profile", + use_compile=use_compile, + ) + + return ExperimentResult( + bf16_linear_us=bf16_linear_us, + fp8_triton_linear_us=fp8_triton_linear_us, + fp8_scaled_mm_linear_us=fp8_scaled_mm_linear_us, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "M", + "N", + "K", + "out_dtype", + "bf16_mm_linear_us", + "fp8_triton_linear_us", + "fp8_scaled_mm_linear_us", + ] + rows = [] + for experiment in experiments: + m, n, k = experiment.config.m, experiment.config.n, experiment.config.k + rows.append( + [ + m, + n, + k, + experiment.config.out_dtype, + experiment.result.bf16_linear_us, + experiment.result.fp8_triton_linear_us, + experiment.result.fp8_scaled_mm_linear_us, + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + +def main(args: argparse.Namespace): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config, profile=args.profile, use_compile=args.compile) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling") + parser.add_argument("--compile", action="store_true", help="Enable compilation") + args = parser.parse_args() + main(args) diff --git a/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py b/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py new file mode 100644 index 0000000000..9c49033a9d --- /dev/null +++ b/benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py +import argparse +import itertools +import logging +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm + +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.float8.config import ScalingGranularity +from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.prototype.moe_training.kernels.mxfp8 import ( + torch_to_blocked_2d_M_groups, + torch_to_blocked_per_group_3d, +) +from torchao.prototype.moe_training.utils import generate_jagged_offs +from torchao.prototype.mx_formats.mx_tensor import to_mx + +device = torch.device("cuda") + + +@dataclass(frozen=True) +class ExperimentConfig: + e: int + m: int + n: int + k: int + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_us: float + fp8_rowwise_us: float + mxfp8_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + # Llama4 shapes + M = [16640] + K = [2048, 5120, 8192] + N = [2048, 5120, 8192] + E = [1, 2, 4, 8] + configs = [] + for e, m, n, k in itertools.product( + E, + M, + N, + K, + ): + configs.append( + ExperimentConfig( + e=e, + m=m, + n=n, + k=k, + ) + ) + return configs + + +def run_experiment( + config: ExperimentConfig, args: argparse.Namespace +) -> ExperimentResult: + e, m, n, k = config.e, config.m, config.n, config.k + + # define test inputs + A = torch.randn( + (m, k), + dtype=torch.bfloat16, + device=device, + ) + B_t = torch.randn( + (e, n, k), + dtype=torch.bfloat16, + device=device, + requires_grad=True, + ).transpose(-2, -1) + + # Configure groups + n_groups = e + Mg = A.shape[0] + alignment_size = 16 + offs = generate_jagged_offs(n_groups, Mg, multiple_of=alignment_size) + + # benchmark bf16 grouped mm + bf16_us = benchmark_cuda_function_in_microseconds( + torch._grouped_mm, + A, + B_t, + offs, + out_dtype=torch.bfloat16, + ) + + # bench fp8 rowwise grouped mm + if torch.cuda.get_device_capability() != (9, 0): + logging.warning( + f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" + ) + fp8_rowwise_us = float("inf") + else: + fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs) + + # benchmark mxfp8 grouped mm + if torch.cuda.get_device_capability() != (10, 0): + logging.warning( + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" + ) + mxfp8_us = float("inf") + else: + mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs) + + return ExperimentResult( + bf16_us=round(bf16_us, 3), + fp8_rowwise_us=round(fp8_rowwise_us, 3), + mxfp8_us=round(mxfp8_us, 3), + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "E", + "M", + "N", + "K", + "bf16_time_us", + "fp8_rowwise_time_us", + "mxfp8_time_us", + "bf16_tflops", + "fp8_rowwise_tflops", + "mxfp8_tflops", + "fp8_rowwise_speedup", + "mxfp8_speedup", + ] + rows = [] + for experiment in experiments: + # calculate tflops + e, m, n, k = ( + experiment.config.e, + experiment.config.m, + experiment.config.n, + experiment.config.k, + ) + flops = 2 * e * m * n * k + bf16_tflops = (flops / 1e12) / (experiment.result.bf16_us / 1e6) + fp8_rowwise_tflops = (flops / 1e12) / (experiment.result.fp8_rowwise_us / 1e6) + mxfp8_tflops = (flops / 1e12) / (experiment.result.mxfp8_us / 1e6) + rows.append( + [ + experiment.config.e, + experiment.config.m, + experiment.config.n, + experiment.config.k, + experiment.result.bf16_us, + experiment.result.fp8_rowwise_us, + experiment.result.mxfp8_us, + round(bf16_tflops, 3), + round(fp8_rowwise_tflops, 3), + round(mxfp8_tflops, 3), + f"{experiment.result.bf16_us / experiment.result.fp8_rowwise_us:.2f}x", + f"{experiment.result.bf16_us / experiment.result.mxfp8_us:.2f}x", + ] + ) + print(tabulate(rows, headers=headers)) + + +# benchmark fp8 grouped mm +def bench_fp8_rowwise_grouped_mm(A, B_t, offs) -> float: + # Convert A to float8, row-major for left operand of grouped GEMM. + A_scales = tensor_to_scale( + A, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=True, + ) + A_scaled = A.to(torch.float32) * A_scales + A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + + # Convert B_t to float8, column-major for right operand of grouped GEMM. + B_t_scales = tensor_to_scale( + B_t, + torch.float8_e4m3fn, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-2, + round_scales_to_power_of_2=True, + ) + B_t_scaled = B_t.to(torch.float32) * B_t_scales + B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) + + # Bench the gemm + fp8_us = benchmark_cuda_function_in_microseconds( + torch._scaled_grouped_mm, + A_fp8_row_major, + B_t_fp8_col_major, + A_scales.squeeze(1).reciprocal(), + B_t_scales.squeeze(1).reciprocal(), + offs, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + return fp8_us + + +def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float: + # A_mx shape: (M, K) + # A_scale shape: (M, K//block_size) + A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size) + + # B_mx shape: (E, N, K) + # B_scale shape: (E, N, K//block_size) + B_scales, B_fp8 = to_mx( + B_t.transpose(-2, -1), + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + # Convert scales for each group to blocked format. + Mg, K = A_fp8.shape + A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups( + A_scales, offs, K + ) + B_scales_blocked = torch_to_blocked_per_group_3d(B_scales) + + # From this, we compute `group_sizes` and `starting_row_after_padding`: + # group_sizes = [32, 32, 64] + # starting_row_after_padding = [0, 32, 64, 128] + zero = torch.tensor([0], dtype=offs.dtype, device=offs.device) + group_sizes = torch.diff(offs, prepend=zero).to(torch.int64) + + # Run the grouped mm + mxfp8_us = benchmark_cuda_function_in_microseconds( + torch.ops.fbgemm.mx8mx8bf16_grouped_stacked, + A_fp8, + B_fp8, + A_scales_blocked, + B_scales_blocked, + group_sizes, + starting_row_after_padding=starting_row_after_padding, + ) + return mxfp8_us + + +def main(args: argparse.Namespace): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config, args) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + args = arg_parser.parse_args() + main(args) diff --git a/benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py b/benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py new file mode 100644 index 0000000000..0ff13759d2 --- /dev/null +++ b/benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +###################################################################### +# +# To run these benchmarks, use the following command: +# +# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py +# +####################################################################### + +import argparse +import copy +import logging +import os + +import pytest +import torch +from torch import distributed as dist +from torch import nn +from torch.distributed._composable.fsdp import fully_shard +from torch.nn import functional as F + +from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) +from torchao.quantization.quant_api import quantize_ + +# this benchmark requires torchtitan +try: + from torchtitan.distributed.expert_parallel import ( + set_token_group_alignment_size_m, + ) + from torchtitan.models.moe import MoE, MoEArgs +except ImportError: + pytest.skip( + "torchtitan not installed, skipping MoE tests.", allow_module_level=True + ) + + +def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile: bool): + assert torch.cuda.is_available() + assert recipe_name in ["fp8_rowwise", "mxfp8"] + recipe = MoEScalingType[recipe_name.upper()] + if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != ( + 9, + 0, + ): + logging.warning( + f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" + ) + return + + elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != ( + 10, + 0, + ): + logging.warning( + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" + ) + return + + # setup distributed for fsdp + setup_distributed() + + # define model args + target_fqns = ["experts"] + model_args = MoEArgs( + num_experts=16, + ) + init_std = 0.02 + device = torch.device("cuda") + + # reference bf16 MoE using llama4 shapes + dim, hidden_dim = 5120, 8192 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() + torch.manual_seed(42) + ref_model.init_weights(init_std, device) + + # target MoE for testing conversion + model = copy.deepcopy(ref_model) + + # Token group alignment size must be 16 for fp8 rowwise training + alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16 + set_token_group_alignment_size_m(alignment_size) + + # assert starting params are identical for both models + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + assert torch.equal(param1, param2) + + # convert MoE to float8 training + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in target_fqns: + if target_fqn in cur_fqn: + return True + return False + + # quantize test model + config = MoETrainingConfig(scaling_type=recipe) + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + + # FSDP2 + fully_shard(model) + fully_shard(ref_model) + + # inputs (llama4 shapes) + batch, seq = 1, 16640 + ref_x = torch.randn( + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device + ) + x = ref_x.detach().clone().requires_grad_(True) + + def warmup(model, input): + for _ in range(3): + out = model(input) + loss = F.mse_loss(out, torch.ones_like(out)) + loss.backward() + torch.cuda.synchronize() + + labels = torch.ones_like(x) + + # TODO: bench with fullgraph=True if/when it is supported + bf16_us = bench_fwd_bwd_microseconds( + ref_model, + ref_x, + labels=labels, + use_compile=use_compile, + fullgraph=False, + ) + print(f"BF16 time: {bf16_us} us") + if enable_profile: + print("Profiling bf16 training") + profile_fwd_bwd(ref_model, ref_x, labels=labels, profile_name="bf16_profile") + + scaled_us = bench_fwd_bwd_microseconds( + model, + x, + labels=labels, + use_compile=use_compile, + fullgraph=False, + ) + print(f"Scaled time: {scaled_us} us") + if enable_profile: + print("Profiling quantized training") + profile_fwd_bwd(model, x, labels=labels, profile_name=f"{recipe_name}_profile") + + print(f"Speedup: {bf16_us / scaled_us:.3f}x") + dist.destroy_process_group() + + +def setup_distributed(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark MoE layer with FSDP2") + parser.add_argument( + "--profile", + action="store_true", + help="Enable PyTorch profiling and save results to file", + ) + parser.add_argument( + "--recipe", type=str, help="[fp8_rowwise, mxfp8]", required=True + ) + parser.add_argument( + "--compile", + action="store_true", + help="use torch.compile", + ) + args = parser.parse_args() + bench_moe_training_fsdp( + recipe_name=args.recipe, + enable_profile=args.profile, + use_compile=args.compile, + ) diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py new file mode 100644 index 0000000000..a28d981e8a --- /dev/null +++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py +import argparse +import itertools +import logging +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm + +from benchmarks.utils import ( + bench_fwd_bwd_microseconds, + bench_fwd_microseconds, + profile_fwd_bwd, +) +from torchao.prototype.moe_training import _scaled_grouped_mm +from torchao.prototype.moe_training.conversion_utils import MoEScalingType +from torchao.prototype.moe_training.utils import generate_jagged_offs + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + +# Dynamic shapes hurt performance +torch._dynamo.config.automatic_dynamic_shapes = False + + +@dataclass(frozen=True) +class ExperimentConfig: + high_precision_dtype: torch.dtype + MNKG: tuple[int] + recipe: MoEScalingType + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_fwd_bwd_us: float + scaled_fwd_bwd_us: float + scaled_fwd_bwd_speedup: float + bf16_fwd_us: float + scaled_fwd_us: float + scaled_fwd_speedup: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + MNKG_list = [ + # Llama4 16e with various experts per device (i.e., different EP degrees) + (16384, 8192, 5120, 1), + (16384, 8192, 5120, 2), + (16384, 8192, 5120, 4), + (16384, 8192, 5120, 8), + (128000, 8192, 5120, 1), + (128000, 8192, 5120, 2), + (128000, 8192, 5120, 4), + (128000, 8192, 5120, 8), + # DSV3 236B with various experts per device (i.e., different EP degrees) + (16384, 1536, 5120, 1), + (16384, 1536, 5120, 2), + (16384, 1536, 5120, 4), + (16384, 1536, 5120, 8), + (128000, 1536, 5120, 1), + (128000, 1536, 5120, 2), + (128000, 1536, 5120, 4), + (128000, 1536, 5120, 8), + # DSV3 671B with various experts per device (i.e., different EP degrees) + (16384, 2048, 7168, 1), + (16384, 2048, 7168, 2), + (16384, 2048, 7168, 4), + (16384, 2048, 7168, 8), + (128000, 2048, 7168, 1), + (128000, 2048, 7168, 2), + (128000, 2048, 7168, 4), + (128000, 2048, 7168, 8), + ] + recipes = [MoEScalingType.FP8_ROWWISE, MoEScalingType.MXFP8] + high_precision_dtypes = [torch.bfloat16] + configs = [] + for MNKG, recipe, high_precision_dtype in itertools.product( + MNKG_list, + recipes, + high_precision_dtypes, + ): + configs.append( + ExperimentConfig( + MNKG=MNKG, + recipe=recipe, + high_precision_dtype=high_precision_dtype, + ) + ) + return configs + + +def run_experiment( + config: ExperimentConfig, args: argparse.Namespace +) -> ExperimentResult: + total_M, N, K, G = config.MNKG + + # define test inputs + A = torch.randn( + (total_M, K), + dtype=config.high_precision_dtype, + device=device, + requires_grad=True, + ) + B_t = torch.randn( + (G, N, K), + dtype=config.high_precision_dtype, + device=device, + requires_grad=True, + ).transpose(-2, -1) + + # - configure input to be row-major with groups divided along the column dimension, + # representing the left operand of grad_weight = grad_output_t @ input + # that occurs in the backward pass of the differentiable scaled grouped mm. + # - the transposed tensor in col-major format with groups along the row dimension, + # which represents the right operand. + token_group_alignment_size = 32 if config.recipe == MoEScalingType.MXFP8 else 16 + offs = generate_jagged_offs(G, total_M, multiple_of=token_group_alignment_size) + + labels = torch.ones( + (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16 + ) + + # fwd_bwd bf16 benchmark + profiling + bf16_fwd_bwd_us = bench_fwd_bwd_microseconds( + torch._grouped_mm, + A, + B_t, + offs, + labels=labels, + use_compile=args.compile, + fullgraph=False, + ) + if args.profile: + profile_fwd_bwd( + torch._grouped_mm, + A, + B_t, + offs, + labels=labels, + use_compile=args.compile, + fullgraph=False, + profile_name="bf16_profile", + ) + + # fwd_bwd scaled benchmark + profiling + scaled_fwd_bwd_us = bench_fwd_bwd_microseconds( + _scaled_grouped_mm, + A, + B_t, + offs, + scaling_type=config.recipe, + labels=labels, + use_compile=args.compile, + fullgraph=False, + ) + if args.profile: + profile_fwd_bwd( + _scaled_grouped_mm, + A, + B_t, + offs, + scaling_type=config.recipe, + labels=labels, + use_compile=args.compile, + profile_name="scaled_profile", + fullgraph=False, + ) + + # Forward pass benchmarks + bf16_fwd_us = bench_fwd_microseconds( + torch._grouped_mm, + A, + B_t, + offs, + use_compile=args.compile, + fullgraph=True, + ) + scaled_fwd_us = bench_fwd_microseconds( + _scaled_grouped_mm, + A, + B_t, + offs, + scaling_type=config.recipe, + use_compile=args.compile, + fullgraph=True, + ) + + return ExperimentResult( + bf16_fwd_bwd_us=round(bf16_fwd_bwd_us, 3), + scaled_fwd_bwd_us=round(scaled_fwd_bwd_us, 3), + scaled_fwd_bwd_speedup=round(bf16_fwd_bwd_us / scaled_fwd_bwd_us, 3), + bf16_fwd_us=round(bf16_fwd_us, 3), + scaled_fwd_us=round(scaled_fwd_us, 3), + scaled_fwd_speedup=round(bf16_fwd_us / scaled_fwd_us, 3), + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "M,N,K,G", + "recipe", + "bf16_fwd_bwd_us", + "scaled_fwd_bwd_us", + "scaled_fwd_bwd_speedup", + "bf16_fwd_us", + "scaled_fwd_us", + "scaled_fwd_speedup", + ] + rows = [] + for experiment in experiments: + rows.append( + [ + str(experiment.config.MNKG), + experiment.config.recipe, + experiment.result.bf16_fwd_bwd_us, + experiment.result.scaled_fwd_bwd_us, + f"{experiment.result.scaled_fwd_bwd_speedup}x", + experiment.result.bf16_fwd_us, + experiment.result.scaled_fwd_us, + f"{experiment.result.scaled_fwd_speedup}x", + ] + ) + print(tabulate(rows, headers=headers)) + + +def main(args: argparse.Namespace): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + if ( + config.recipe == MoEScalingType.FP8_ROWWISE + and torch.cuda.get_device_capability() != (9, 0) + ): + logging.warning( + f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" + ) + continue + + elif ( + config.recipe == MoEScalingType.MXFP8 + and torch.cuda.get_device_capability() != (10, 0) + ): + logging.warning( + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" + ) + continue + + result = run_experiment(config, args) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--compile", action="store_true") + arg_parser.add_argument("--profile", action="store_true") + args = arg_parser.parse_args() + main(args) diff --git a/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py b/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_per_group_colwise_scales.py similarity index 50% rename from torchao/prototype/moe_training/benchmarks/benchmark_kernels.py rename to benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_per_group_colwise_scales.py index 37701e6545..2e164b344b 100644 --- a/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py +++ b/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_per_group_colwise_scales.py @@ -6,21 +6,20 @@ # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py import itertools -import time from dataclasses import dataclass from typing import List import torch from tabulate import tabulate from tqdm import tqdm +from triton.testing import do_bench from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, ) from torchao.prototype.moe_training.utils import ( - _to_2d_jagged_float8_tensor_colwise, - _to_2d_jagged_float8_tensor_rowwise, + generate_jagged_offs, + torch_to_float8_per_group_colwise, ) device = torch.device("cuda") @@ -38,8 +37,10 @@ class ExperimentConfig: @dataclass(frozen=True) class ExperimentResult: - torch_time_us: float + torch_loop_time_us: float triton_time_us: float + torch_mem_bw_gbps: float + triton_mem_bw_gbps: float @dataclass(frozen=True) @@ -49,8 +50,8 @@ class Experiment: def get_configs() -> List[ExperimentConfig]: - input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)] - n_groups_list = [4, 8, 16] + input_shapes = [(16640, 5120)] # (Mg, K) + n_groups_list = [1, 16, 64] high_precision_dtypes = [torch.bfloat16] configs = [] for input_shape, n_groups, high_precision_dtype in itertools.product( @@ -67,94 +68,104 @@ def get_configs() -> List[ExperimentConfig]: def run_experiment(config: ExperimentConfig) -> ExperimentResult: - # define test inputs - input_tensor = torch.randn( - *config.input_shape, - dtype=config.high_precision_dtype, - device=device, + # Define test inputs + Mg, K = config.input_shape + + # Column major input tensor. + # Right operand in grad_weight = grad_output_t @ input + input_tensor = ( + torch.randn( + Mg, + K, + dtype=config.high_precision_dtype, + device=device, + ) + .transpose(-2, -1) + .contiguous() + .transpose(-2, -1) ) - input_row_major = input_tensor.clone().detach() - input_col_major = input_tensor.clone().detach().t() # - configure input to be row-major with groups divided along the column dimension, # representing the left operand of grad_weight = grad_output_t @ input # that occurs in the backward pass of the differentiable scaled grouped mm. # - the transposed tensor in col-major format with groups along the row dimension, # which represents the right operand. - group_size = input_row_major.shape[1] // config.n_groups n_groups = config.n_groups - offs = torch.arange( - group_size, - group_size * n_groups + 1, - group_size, - device=device, - dtype=torch.int32, - ) + offs = generate_jagged_offs(n_groups, Mg, multiple_of=16) def warmup(func, *args, **kwargs): for _ in range(10): func(*args, **kwargs) - def run_torch( - input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor - ): - _ = _to_2d_jagged_float8_tensor_rowwise( - input_row_major, - offs, - target_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - _ = _to_2d_jagged_float8_tensor_colwise( - input_col_major, - offs, - target_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) + # Bench torch per group colwise + torch_to_float8_per_group_colwise_c = torch.compile( + torch_to_float8_per_group_colwise + ) + warmup( + torch_to_float8_per_group_colwise_c, + input_tensor, + offs, + target_dtype=torch.float8_e4m3fn, + ) + torch_loop_time_us = benchmark_cuda_function_in_microseconds( + torch_to_float8_per_group_colwise_c, + input_tensor, + offs, + target_dtype=torch.float8_e4m3fn, + ) - def run_triton( - input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor - ): - _ = triton_fp8_row_major_jagged_rowwise_scales( - input_row_major, - offs, - output_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - _ = triton_fp8_col_major_jagged_colwise_scales( - input_col_major, - offs, - output_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) + # Bench triton per group colwise + warmup( + triton_fp8_per_group_colwise_scales, + input_tensor, + offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_fp8_per_group_colwise_scales, + input_tensor, + offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) - # bench torch - compiled_run_torch = torch.compile(run_torch) - warmup(compiled_run_torch, input_row_major, input_col_major, offs) - start_time_ns = time.perf_counter_ns() - compiled_run_torch(input_row_major, input_col_major, offs) - torch_time_ns = time.perf_counter_ns() - start_time_ns - torch_time_us = torch_time_ns / 1e3 - - # bench triton - warmup(run_triton, input_row_major, input_col_major, offs) - start_time_ns = time.perf_counter_ns() - run_triton(input_row_major, input_col_major, offs) - triton_time_ns = time.perf_counter_ns() - start_time_ns - triton_time_us = triton_time_ns / 1e3 + # Mem bw calculations + bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8 + num_elements = input_tensor.numel() + read_bytes = ( + 2 * num_elements * bytes_per_input_el # read input tensor twice + + 4 * (n_groups * K) # read scales tensor once, 4 bytes per fp32 scale + ) + write_bytes = ( + # 1 byte per output elem in fp8 + num_elements + + + # write scales tensor, 4 bytes per fp32 scale (we actually do this write once per blong along the reduction dim using atomics, but this is an approximation) + 4 * (n_groups * K) + ) + read_write_bytes = read_bytes + write_bytes + torch_mem_bw_gbps = (read_write_bytes) / (torch_loop_time_us / 1e6) / 1e9 + triton_mem_bw_gbps = (read_write_bytes) / (triton_time_us / 1e6) / 1e9 return ExperimentResult( - torch_time_us=torch_time_us, + torch_loop_time_us=torch_loop_time_us, triton_time_us=triton_time_us, + torch_mem_bw_gbps=torch_mem_bw_gbps, + triton_mem_bw_gbps=triton_mem_bw_gbps, ) def print_results(experiments: List[Experiment]): headers = [ - "input_shape", + "Mg,K", "n_groups", "high_precision_dtype", - "torch_time_us", + "torch_loop_time_us", "triton_time_us", + "torch_mem_bw_gbps", + "triton_mem_bw_gbps", + "triton_speedup", ] rows = [] for experiment in experiments: @@ -166,13 +177,20 @@ def print_results(experiments: List[Experiment]): input_shape, experiment.config.n_groups, experiment.config.high_precision_dtype, - experiment.result.torch_time_us, + experiment.result.torch_loop_time_us, experiment.result.triton_time_us, + round(experiment.result.torch_mem_bw_gbps, 3), + round(experiment.result.triton_mem_bw_gbps, 3), + f"{experiment.result.torch_loop_time_us / experiment.result.triton_time_us:.2f}x", ] ) print(tabulate(rows, headers=headers)) +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + def main(): torch.random.manual_seed(123) configs = get_configs() diff --git a/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_per_group_rowwise_scales.py b/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_per_group_rowwise_scales.py new file mode 100644 index 0000000000..af14e6a4bc --- /dev/null +++ b/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_per_group_rowwise_scales.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, +) +from torchao.prototype.moe_training.utils import ( + generate_jagged_offs, + torch_to_float8_per_group_rowwise, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + high_precision_dtype: torch.dtype + input_shape: tuple[int] + n_groups: int + + +@dataclass(frozen=True) +class ExperimentResult: + torch_loop_time_us: float + triton_time_us: float + triton_transpose_us: float + torch_mem_bw_gbps: float + triton_mem_bw_gbps: float + triton_transpose_mem_bw_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + input_shapes = [(16640, 8192)] # (Mg, N) + n_groups_list = [1, 16, 64] + high_precision_dtypes = [torch.bfloat16] + configs = [] + for input_shape, n_groups, high_precision_dtype in itertools.product( + input_shapes, n_groups_list, high_precision_dtypes + ): + configs.append( + ExperimentConfig( + input_shape=input_shape, + n_groups=n_groups, + high_precision_dtype=high_precision_dtype, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + # define test inputs + Mg, N = config.input_shape + + # Left operand in grad_weight = grad_output_t @ input + grad_out = torch.randn( + Mg, + N, + dtype=config.high_precision_dtype, + device=device, + ) + grad_out_t = grad_out.transpose(-2, -1) + + # - configure input to be row-major with groups divided along the column dimension, + # representing the left operand of grad_weight = grad_output_t @ input + # that occurs in the backward pass of the differentiable scaled grouped mm. + # - the transposed tensor in col-major format with groups along the row dimension, + # which represents the right operand. + n_groups = config.n_groups + offs = generate_jagged_offs(n_groups, Mg, multiple_of=16) + + def warmup(func, *args, **kwargs): + for _ in range(10): + func(*args, **kwargs) + + # Bench torch per group rowwise + torch_to_float8_per_group_rowwise_c = torch.compile( + torch_to_float8_per_group_rowwise + ) + warmup( + torch_to_float8_per_group_rowwise_c, + grad_out_t, + offs, + target_dtype=torch.float8_e4m3fn, + ) + torch_loop_time_us = benchmark_cuda_function_in_microseconds( + torch_to_float8_per_group_rowwise_c, + grad_out_t, + offs, + target_dtype=torch.float8_e4m3fn, + ) + + # Bench triton per group rowwise scaling kernel + warmup( + triton_fp8_per_group_rowwise_scales, + grad_out_t, + offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_fp8_per_group_rowwise_scales, + grad_out_t, + offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + + # Bench method where we compute colwise scales on grad_output (equivalent to rowwise scales on grad_output_t) + def run_triton_transpose_method( + grad_out, offs, output_dtype, round_scales_to_power_of_2 + ): + # Restride input as column major. + # Note this is the transpose of grad_output_t, which is what we are trying to compute per group rowwise scales for. + grad_out = grad_out.t().contiguous().t() + # Compute per group colwise scales, writing to column major format. + fp8_data, scales = triton_fp8_per_group_colwise_scales( + grad_out, offs, output_dtype, round_scales_to_power_of_2 + ) + return fp8_data.t(), scales.t() + + run_triton_c = torch.compile(run_triton_transpose_method) + warmup( + run_triton_c, + grad_out, + offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + triton_transpose_us = benchmark_cuda_function_in_microseconds( + run_triton_c, + grad_out, + offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + + # Mem bw calculations + bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8 + num_elements = grad_out_t.numel() + + read_bytes = ( + 2 * num_elements * bytes_per_input_el # read input tensor twice + + 4 * (n_groups * N) # read scales tensor once, 4 bytes per fp32 scale + ) + write_bytes = ( + # 1 byte per output elem in fp8 + num_elements + + + # write scales tensor, 4 bytes per fp32 scale (we actually do this write once per blong along the reduction dim using atomics, but this is an approximation) + 4 * (n_groups * N) + ) + + read_write_bytes = read_bytes + write_bytes + torch_mem_bw_gbps = (read_write_bytes) / (torch_loop_time_us / 1e6) / 1e9 + triton_mem_bw_gbps = (read_write_bytes) / (triton_time_us / 1e6) / 1e9 + + # Transpose method has extra reads/writes: + to_col_major_read_write_bytes = ( + 2 * num_elements * bytes_per_input_el + ) # read once, write once when converting input to column major + triton_transpose_mem_bw_gbps = ( + (read_write_bytes + to_col_major_read_write_bytes) + / (triton_transpose_us / 1e6) + / 1e9 + ) + return ExperimentResult( + torch_loop_time_us=torch_loop_time_us, + triton_time_us=triton_time_us, + triton_transpose_us=triton_transpose_us, + torch_mem_bw_gbps=torch_mem_bw_gbps, + triton_mem_bw_gbps=triton_mem_bw_gbps, + triton_transpose_mem_bw_gbps=triton_transpose_mem_bw_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "Mg,N", + "n_groups", + "torch_loop_time_us", + "triton_time_us", + "triton_transpose_us", + "torch_mem_bw_gbps", + "triton_mem_bw_gbps", + "triton_transpose_mem_bw_gbps", + "triton_speedup", + "triton_transpose_speedup", + ] + rows = [] + for experiment in experiments: + input_shape = ( + f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})" + ) + rows.append( + [ + input_shape, + experiment.config.n_groups, + experiment.result.torch_loop_time_us, + experiment.result.triton_time_us, + experiment.result.triton_transpose_us, + round(experiment.result.torch_mem_bw_gbps, 3), + round(experiment.result.triton_mem_bw_gbps, 3), + round(experiment.result.triton_transpose_mem_bw_gbps, 3), + f"{experiment.result.torch_loop_time_us / experiment.result.triton_time_us:.2f}x", + f"{experiment.result.torch_loop_time_us / experiment.result.triton_transpose_us:.2f}x", + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_rowwise_3d_transpose_rhs.py b/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_rowwise_3d_transpose_rhs.py new file mode 100644 index 0000000000..dc65af85c5 --- /dev/null +++ b/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_rowwise_3d_transpose_rhs.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.prototype.moe_training.kernels.float8_rowwise import ( + triton_fp8_rowwise_3d_transpose_rhs, + triton_fp8_rowwise_3d_transpose_rhs_fused_reduction, +) +from torchao.prototype.moe_training.utils import ( + torch_to_3d_rowwise_float8_transpose_rhs, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + high_precision_dtype: torch.dtype + input_shape: tuple[int] + power_of_2_scales: bool + + +@dataclass(frozen=True) +class ExperimentResult: + torch_time_us: float + triton_atomic_time_us: float + triton_reduction_time_us: float + torch_mem_bw_gbps: float + triton_atomic_mem_bw_gbps: float + triton_reduction_mem_bw_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + # Llama4 shapes (E, N, K) + input_shapes = [ + (1, 8192, 5120), # w1, w3 + (1, 5120, 8192), # w2 + (16, 8192, 5120), # w1, w3 + (16, 5120, 8192), # w2 + (128, 8192, 5120), # w1, w3 + (128, 5120, 8192), # w2 + ] + high_precision_dtypes = [torch.bfloat16] + power_of_2_scales = [True] + configs = [] + for input_shape, high_precision_dtype, power_of_2_scale in itertools.product( + input_shapes, high_precision_dtypes, power_of_2_scales + ): + configs.append( + ExperimentConfig( + input_shape=input_shape, + high_precision_dtype=high_precision_dtype, + power_of_2_scales=power_of_2_scale, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + # Expert weights will be passed in transposed and column major in practice + input_tensor = torch.randn( + *config.input_shape, + dtype=config.high_precision_dtype, + device=device, + ).transpose(-2, -1) + + def warmup(func, *args, **kwargs): + for _ in range(10): + func(*args, **kwargs) + + def run_torch(input_tensor: torch.Tensor): + out = torch_to_3d_rowwise_float8_transpose_rhs( + input_tensor, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=config.power_of_2_scales, + ) + return out + + def run_triton_atomic(input_tensor: torch.Tensor): + out = triton_fp8_rowwise_3d_transpose_rhs( + input_tensor, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=config.power_of_2_scales, + ) + return out + + def run_triton_reduction(input_tensor: torch.Tensor): + out = triton_fp8_rowwise_3d_transpose_rhs_fused_reduction( + input_tensor, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=config.power_of_2_scales, + ) + return out + + # bench torch + compiled_run_torch = torch.compile(run_torch) + warmup(run_torch, input_tensor) + torch_time_us = benchmark_cuda_function_in_microseconds( + compiled_run_torch, + input_tensor, + ) + + # bench triton atomic method + run_triton_atomic_c = torch.compile(run_triton_atomic) + warmup(run_triton_atomic_c, input_tensor) + triton_atomic_time_us = benchmark_cuda_function_in_microseconds( + run_triton_atomic_c, + input_tensor, + ) + + # bench triton reduction method + run_triton_reduction_c = torch.compile(run_triton_reduction) + warmup(run_triton_reduction_c, input_tensor) + triton_reduction_time_us = benchmark_cuda_function_in_microseconds( + run_triton_reduction_c, + input_tensor, + ) + + # mem bw calculations - excluding scales to simplify calculation + # but still get an accurate estimate. + bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + num_elements = input_tensor.numel() + + read_bytes = num_elements * bytes_per_input_el + write_bytes = num_elements * bytes_per_output_el + + # Both torch.compile codegen and the triton kernel read the input tensor twice + # (once for scale calculations, once for scaling + casting). + torch_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_atomic_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / ( + triton_atomic_time_us / 1e6 + ) + triton_reduction_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / ( + triton_reduction_time_us / 1e6 + ) + + return ExperimentResult( + torch_time_us=torch_time_us, + triton_atomic_time_us=triton_atomic_time_us, + triton_reduction_time_us=triton_reduction_time_us, + torch_mem_bw_gbps=torch_mem_bw_gbps, + triton_atomic_mem_bw_gbps=triton_atomic_mem_bw_gbps, + triton_reduction_mem_bw_gbps=triton_reduction_mem_bw_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape", + "power_of_2_scales", + "torch_time_us", + "triton_atomic_time_us", + "triton_reduction_time_us", + "torch_mem_bw_gbps", + "triton_atomic_mem_bw_gbps", + "triton_reduction_mem_bw_gbps", + "triton_atomic_speedup", + "triton_reduction_speedup", + ] + rows = [] + for experiment in experiments: + input_shape = f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1], experiment.config.input_shape[2]})" + rows.append( + [ + input_shape, + experiment.config.power_of_2_scales, + experiment.result.torch_time_us, + experiment.result.triton_atomic_time_us, + experiment.result.triton_reduction_time_us, + round(experiment.result.torch_mem_bw_gbps, 3), + round(experiment.result.triton_atomic_mem_bw_gbps, 3), + round(experiment.result.triton_reduction_mem_bw_gbps, 3), + f"{experiment.result.torch_time_us / experiment.result.triton_atomic_time_us:.2f}x", + f"{experiment.result.torch_time_us / experiment.result.triton_reduction_time_us:.2f}x", + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args): + return do_bench(lambda: f(*args), return_mode="median") * 1e3 + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py new file mode 100644 index 0000000000..b57ca81d4c --- /dev/null +++ b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm + +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d +from torchao.prototype.moe_training.scaled_grouped_mm import ( + _to_mxfp8_dim1_3d, +) +from torchao.prototype.mx_formats.mx_tensor import to_mx + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: tuple[int] + + +@dataclass(frozen=True) +class ExperimentResult: + # time + to_mx_us: float + cuda_2d_us: float + cuda_3d_us: float + # mem bw + to_mx_gbps: float + cuda_2d_gbps: float + cuda_3d_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + # Llama4 shapes. Input activations are scaled along K dim. + input_shapes = [ + (1, 8192, 5120), + (2, 8192, 5120), + (4, 8192, 5120), + (8, 8192, 5120), + (16, 8192, 5120), + (64, 8192, 5120), + ] + configs = [] + for shape in input_shapes: + configs.append( + ExperimentConfig( + input_shape=shape, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + block_size = 32 + input_shape = config.input_shape + input_tensor = torch.randn( + *input_shape, + dtype=torch.bfloat16, + device=device, + ) + + def using_to_mx(x: torch.Tensor) -> torch.Tensor: + # Reference implementation + s_d1_ref, y_d1_ref = to_mx( + # Transpose (E,N,K) to (E,K,N) so N is final dim, + # since to_mx scales along that dim + x.transpose(-2, -1).contiguous(), + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + # Transpose tensors and scales back so we have effectively + # quantized input shape (E, N, K) along N + y_d1_ref = y_d1_ref.transpose(-2, -1) + s_d1_ref = s_d1_ref.transpose(-2, -1) + return y_d1_ref, s_d1_ref + + # bench to_mx + using_to_mx_c = torch.compile(using_to_mx) + scales_to_mx, data_to_mx = using_to_mx_c(input_tensor) + to_mx_time_us = benchmark_cuda_function_in_microseconds( + using_to_mx_c, + input_tensor, + ) + + # bench 2d dim1 kernel then transforming to col major + using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d) + scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor) + time_cuda_2d_us = benchmark_cuda_function_in_microseconds( + using_cuda_2d_c, + input_tensor, + ) + + # bench 3d cuda kernel + data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(input_tensor) + time_cuda_3d_us = benchmark_cuda_function_in_microseconds( + mxfp8_quantize_cuda_3d, + input_tensor, + ) + + # mem bw calculations + bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = ( + data_cuda_3d.numel() * bytes_per_output_el + + scales_cuda_3d.numel() * bytes_per_scale_el + ) + + to_mx_gbps = ((read_bytes + write_bytes) / 1e9) / (to_mx_time_us / 1e6) + cuda_2d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_2d_us / 1e6) + cuda_3d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_3d_us / 1e6) + + return ExperimentResult( + # time + to_mx_us=to_mx_time_us, + cuda_2d_us=time_cuda_2d_us, + cuda_3d_us=time_cuda_3d_us, + # mem bw + to_mx_gbps=to_mx_gbps, + cuda_2d_gbps=cuda_2d_gbps, + cuda_3d_gbps=cuda_3d_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape", + "to_mx_us", + "cuda_2d_us", + "cuda_3d_us", + "to_mx_gbps", + "cuda_2d_gbps", + "cuda_3d_gbps", + ] + rows = [] + for experiment in experiments: + rows.append( + [ + str(experiment.config.input_shape), + experiment.result.to_mx_us, + experiment.result.cuda_2d_us, + experiment.result.cuda_3d_us, + round(experiment.result.to_mx_gbps, 3), + round(experiment.result.cuda_2d_gbps, 3), + round(experiment.result.cuda_3d_gbps, 3), + ] + ) + print(tabulate(rows, headers=headers)) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py new file mode 100644 index 0000000000..b02124b782 --- /dev/null +++ b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm + +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.moe_training.kernels.mxfp8 import ( + compute_blocked_scale_offsets_for_M_groups, + torch_to_blocked_2d_M_groups, + triton_mx_block_rearrange_2d_M_groups, +) +from torchao.prototype.moe_training.utils import generate_jagged_offs + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: tuple[int] + num_groups: int + + +@dataclass(frozen=True) +class ExperimentResult: + torch_time_us: float + triton_time_us: float + torch_mem_bw_gbps: float + triton_mem_bw_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + # Llama4 shapes. Input activations are scaled along K dim. + block_size = 32 + input_shapes = [ + (16640, 5120 // block_size), + ] + num_groups = [16] + configs = [] + for shape, groups in itertools.product( + input_shapes, + num_groups, + ): + configs.append( + ExperimentConfig( + input_shape=shape, + num_groups=groups, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + input_shape, num_groups = config.input_shape, config.num_groups + input_tensor = torch.randint( + low=0, + high=256, + size=input_shape, + dtype=torch.uint8, + device=device, + ) + + Mg, K = input_shape + input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32) + _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups( + input_group_offsets + ) + + # bench torch + compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups) + torch_out_scales, torch_group_offs = compiled_run_torch( + input_tensor, input_group_offsets, K + ) + torch_time_us = benchmark_cuda_function_in_microseconds( + compiled_run_torch, + input_tensor, + input_group_offsets, + K, + ) + + # bench triton + triton_out_scales = triton_mx_block_rearrange_2d_M_groups( + input_tensor, + input_group_offsets, + output_group_offsets, + ) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_mx_block_rearrange_2d_M_groups, + input_tensor, + input_group_offsets, + output_group_offsets, + ) + + # mem bw calculations + bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = triton_out_scales.numel() * bytes_per_output_el + + torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + torch_time_us=torch_time_us, + triton_time_us=triton_time_us, + torch_mem_bw_gbps=torch_mem_bw_gbps, + triton_mem_bw_gbps=triton_mem_bw_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape", + "torch_time_us", + "triton_time_us", + "torch_mem_bw_gbps", + "triton_mem_bw_gbps", + "triton_speedup", + ] + rows = [] + for experiment in experiments: + input_shape = ( + f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})" + ) + rows.append( + [ + input_shape, + experiment.result.torch_time_us, + experiment.result.triton_time_us, + round(experiment.result.torch_mem_bw_gbps, 3), + round(experiment.result.triton_mem_bw_gbps, 3), + f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x", + ] + ) + print(tabulate(rows, headers=headers)) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py new file mode 100644 index 0000000000..296270fe62 --- /dev/null +++ b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_per_group_3d.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm + +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.moe_training.kernels.mxfp8 import ( + torch_to_blocked_per_group_3d, + triton_mx_block_rearrange_per_group_3d, +) + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: tuple[int] + + +@dataclass(frozen=True) +class ExperimentResult: + torch_time_us: float + triton_time_us: float + torch_mem_bw_gbps: float + triton_mem_bw_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + # Llama4 shapes. Input activations are scaled along K dim. + block_size = 32 + input_shapes = [ + # w1, w3 scaled along K (fwd) + (1, 8192, 5120 // block_size), + (2, 8192, 5120 // block_size), + (4, 8192, 5120 // block_size), + (8, 8192, 5120 // block_size), + (16, 8192, 5120 // block_size), + # w2 scaled along K (fwd) + (1, 5120, 8192 // block_size), + (2, 5120, 8192 // block_size), + (4, 5120, 8192 // block_size), + (8, 5120, 8192 // block_size), + (16, 5120, 8192 // block_size), + ] + configs = [] + for shape in input_shapes: + configs.append( + ExperimentConfig( + input_shape=shape, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + input_tensor = torch.randint( + low=0, + high=256, + size=config.input_shape, + dtype=torch.uint8, + device=device, + ) + + def warmup(fn, *args, **kwargs): + for _ in range(5): + fn(*args, **kwargs) + + E, N, K = config.input_shape + + # bench torch + compiled_run_torch = torch.compile(torch_to_blocked_per_group_3d) + warmup(compiled_run_torch, input_tensor) + torch_time_us = benchmark_cuda_function_in_microseconds( + compiled_run_torch, + input_tensor, + ) + + # bench triton + triton_out_scales = triton_mx_block_rearrange_per_group_3d(input_tensor) + warmup(triton_mx_block_rearrange_per_group_3d, input_tensor) + triton_time_us = benchmark_cuda_function_in_microseconds( + triton_mx_block_rearrange_per_group_3d, + input_tensor, + ) + + # mem bw calculations + bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = triton_out_scales.numel() * bytes_per_output_el + + torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6) + triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6) + + return ExperimentResult( + torch_time_us=torch_time_us, + triton_time_us=triton_time_us, + torch_mem_bw_gbps=torch_mem_bw_gbps, + triton_mem_bw_gbps=triton_mem_bw_gbps, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape", + "torch_time_us", + "triton_time_us", + "torch_mem_bw_gbps", + "triton_mem_bw_gbps", + "triton_speedup", + ] + rows = [] + for experiment in experiments: + input_shape = f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]}, {experiment.config.input_shape[2]})" + rows.append( + [ + input_shape, + experiment.result.torch_time_us, + experiment.result.triton_time_us, + round(experiment.result.torch_mem_bw_gbps, 3), + round(experiment.result.triton_mem_bw_gbps, 3), + f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x", + ] + ) + print(tabulate(rows, headers=headers)) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 6a1f4e8efb..2e1243d1d9 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -166,8 +166,8 @@ def insert_rmsnorm(module: torch.nn.Module): insert_rmsnorm(model.layers) # don't apply int8_mixed_precision to LM head, since it can cause convergence issue. - # TODO: might want to do the same for int8_weight_only to standardize. - if args.quantize == "int8_weight_only": + # TODO: might want to do the same for Int8WeightOnlyConfig to standardize. + if args.quantize == "Int8WeightOnlyConfig": quantize_( model, int8_weight_only_quantized_training(), set_inductor_config=False ) diff --git a/benchmarks/utils.py b/benchmarks/utils.py new file mode 100644 index 0000000000..c59142d571 --- /dev/null +++ b/benchmarks/utils.py @@ -0,0 +1,76 @@ +import torch +from torch.nn import functional as F +from triton.testing import do_bench + + +def bench_fwd_bwd_microseconds( + fn, *args, labels=None, use_compile=False, fullgraph=True, **kwargs +): + assert labels is not None + + def fwd_bwd(*args, **kwargs): + out = fn(*args, **kwargs) + loss = F.mse_loss(out, labels) + loss.backward() + + fwd_bwd_compiled = ( + torch.compile(fwd_bwd, fullgraph=fullgraph) if use_compile else fwd_bwd + ) + return benchmark_cuda_function_in_microseconds( + fwd_bwd_compiled, + *args, + **kwargs, + ) + + +def bench_fwd_microseconds(fn, *args, use_compile=False, fullgraph=True, **kwargs): + fn_compiled = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn + + def inference_fn(*args, **kwargs): + with torch.no_grad(): + return fn_compiled(*args, **kwargs) + + return benchmark_cuda_function_in_microseconds( + inference_fn, + *args, + **kwargs, + ) + + +def profile_fwd_bwd( + fn, + *args, + labels=None, + use_compile=False, + fullgraph=True, + profile_name="profile", + **kwargs, +): + assert labels is not None + fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn + wait, warmup, active = 1, 3, 1 + total_steps = wait + warmup + active + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=wait, warmup=warmup, active=active, repeat=0 + ), + record_shapes=True, + with_stack=True, + ) as prof: + for _ in range(total_steps): + out = fn(*args, **kwargs) + loss = F.mse_loss(out, labels) + loss.backward() + prof.step() + + # Save profiler results + prof.export_chrome_trace(f"{profile_name}.json") + print(f"Saved: {profile_name}.json") + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst new file mode 100644 index 0000000000..e0cacab667 --- /dev/null +++ b/docs/source/api_ref_qat.rst @@ -0,0 +1,64 @@ +.. _api_qat: + +======================== +torchao.quantization.qat +======================== + +.. currentmodule:: torchao.quantization.qat + +Main Config for quantize_ +--------------------------------------- +For a full example of how to use QAT with our main `quantize_` API, +please refer to the `QAT README `__. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + QATConfig + QATStep + +Custom QAT APIs +--------------- +.. autosummary:: + :toctree: generated/ + :nosignatures: + + FakeQuantizeConfigBase + IntxFakeQuantizeConfig + Float8FakeQuantizeConfig + FakeQuantizedLinear + FakeQuantizedEmbedding + FakeQuantizerBase + IntxFakeQuantizer + Float8FakeQuantizer + linear.enable_linear_fake_quant + linear.disable_linear_fake_quant + +Legacy QAT APIs +--------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + IntXQuantizationAwareTrainingConfig + FromIntXQuantizationAwareTrainingConfig + Int4WeightOnlyQATQuantizer + linear.Int4WeightOnlyQATLinear + Int8DynActInt4WeightQATQuantizer + linear.Int8DynActInt4WeightQATLinear + Int4WeightOnlyEmbeddingQATQuantizer + embedding.Int4WeightOnlyQATEmbedding + embedding.Int4WeightOnlyEmbedding + Float8ActInt4WeightQATQuantizer + ComposableQATQuantizer + +Prototype +--------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + initialize_fake_quantizers diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index f2fad00b69..c163a4b06a 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -24,6 +24,7 @@ Inference APIs for quantize\_ :nosignatures: Int4WeightOnlyConfig + Float8DynamicActivationInt4WeightConfig Float8DynamicActivationFloat8WeightConfig Float8WeightOnlyConfig Float8StaticActivationFloat8WeightConfig @@ -34,24 +35,6 @@ Inference APIs for quantize\_ UIntXWeightOnlyConfig FPXWeightOnlyConfig -.. currentmodule:: torchao.quantization.qat - -QAT APIs ----------------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - - IntXQuantizationAwareTrainingConfig - FromIntXQuantizationAwareTrainingConfig - FakeQuantizeConfig - Int4WeightOnlyQATQuantizer - Int8DynActInt4WeightQATQuantizer - Int4WeightOnlyEmbeddingQATQuantizer - ComposableQATQuantizer - initialize_fake_quantizers - .. currentmodule:: torchao.quantization Quantization Primitives diff --git a/docs/source/api_ref_utils.rst b/docs/source/api_ref_utils.rst new file mode 100644 index 0000000000..3e85bbb424 --- /dev/null +++ b/docs/source/api_ref_utils.rst @@ -0,0 +1,33 @@ +.. _api_utils: + + +============= +torchao.utils +============= + +.. currentmodule:: torchao.utils + +Tensor Subclass Utils +--------------------- +.. autosummary:: + :toctree: generated/ + :nosignatures: + + TorchAOBaseTensor + +===================================== +torchao.quantization.quantize_.common +===================================== + +.. currentmodule:: torchao.quantization.quantize_.common + +quantize_ API Common Utils +-------------------------- +.. autosummary:: + :toctree: generated/ + :nosignatures: + + KernelPreference + PackingFormat + QuantizeTensorKwargs + _choose_quant_func_and_quantize_tensor diff --git a/docs/source/benchmarking_api_guide.md b/docs/source/benchmarking_api_guide.md new file mode 100644 index 0000000000..bd81a7f65f --- /dev/null +++ b/docs/source/benchmarking_api_guide.md @@ -0,0 +1,213 @@ +# Benchmarking API Guide + +This tutorial will guide you through using the TorchAO benchmarking framework. The tutorial contains integrating new APIs with the framework and dashboard. + +1. [Add an API to benchmarking recipes](#add-an-api-to-benchmarking-recipes) +2. [Add a model architecture for benchmarking recipes](#add-a-model-to-benchmarking-recipes) +3. [Add an HF model to benchmarking recipes](#add-an-hf-model-to-benchmarking-recipes) +4. [Add an API to micro-benchmarking CI dashboard](#add-an-api-to-benchmarking-ci-dashboard) + +## Add an API to Benchmarking Recipes + +The framework currently supports quantization and sparsity recipes, which can be run using the quantize_() or sparsity_() functions: + +To add a new recipe, add the corresponding string configuration to the function `string_to_config()` in `benchmarks/microbenchmarks/utils.py`. + +```python +def string_to_config( + quantization: Optional[str], sparsity: Optional[str], **kwargs +) -> AOBaseConfig: + +# ... existing code ... + +elif quantization == "my_new_quantization": + # If additional information needs to be passed as kwargs, process it here + return MyNewQuantizationConfig(**kwargs) +elif sparsity == "my_new_sparsity": + return MyNewSparsityConfig(**kwargs) + +# ... rest of existing code ... +``` + +Now we can use this recipe throughout the benchmarking framework. + +**Note:** If the `AOBaseConfig` uses input parameters, like bit-width, group-size etc, you can pass them appended to the string config in input. For example, for `GemliteUIntXWeightOnlyConfig` we can pass bit-width and group-size as `gemlitewo--` + +## Add a Model to Benchmarking Recipes + +To add a new model architecture to the benchmarking system, you need to modify `torchao/testing/model_architectures.py`. + +1. To add a new model type, define your model class in `torchao/testing/model_architectures.py`: + +```python +class MyCustomModel(torch.nn.Module): + def __init__(self, input_dim, output_dim, dtype=torch.bfloat16): + super().__init__() + # Define your model architecture + self.layer1 = torch.nn.Linear(input_dim, 512, bias=False).to(dtype) + self.activation = torch.nn.ReLU() + self.layer2 = torch.nn.Linear(512, output_dim, bias=False).to(dtype) + + def forward(self, x): + x = self.layer1(x) + x = self.activation(x) + x = self.layer2(x) + return x +``` + +2. Update the `create_model_and_input_data` function to handle your new model type: + +```python +def create_model_and_input_data( + model_type: str, + m: int, + k: int, + n: int, + high_precision_dtype: torch.dtype = torch.bfloat16, + device: str = "cuda", + activation: str = "relu", +): + # ... existing code ... + + elif model_type == "my_custom_model": + model = MyCustomModel(k, n, high_precision_dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + + # ... rest of existing code ... +``` + +### Model Design Considerations + +When adding new models: + +- **Input/Output Dimensions**: Ensure your model handles the (m, k, n) dimension convention where: + - `m`: Batch size or sequence length + - `k`: Input feature dimension + - `n`: Output feature dimension + +- **Data Types**: Support the `high_precision_dtype` parameter (typically `torch.bfloat16`) + +- **Device Compatibility**: Ensure your model works on CUDA, CPU, and other target devices + +- **Quantization Compatibility**: Design your model to work with TorchAO quantization methods + +## Add an HF model to benchmarking recipes +(Coming soon!!!) + +## Add an API to Benchmarking CI Dashboard + +To integrate your API with the CI [dashboard](https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fao&benchmarkName=micro-benchmark+api): + +### 1. Modify Existing CI Configuration + +Add your quantization method to the existing CI configuration file at `benchmarks/dashboard/microbenchmark_quantization_config.yml`: + +```yaml +# benchmarks/dashboard/microbenchmark_quantization_config.yml +benchmark_mode: "inference" +quantization_config_recipe_names: + - "int8wo" + - "int8dq" + - "float8dq-tensor" + - "float8dq-row" + - "float8wo" + - "my_new_quantization" # Add your method here + +output_dir: "benchmarks/microbenchmarks/results" + +model_params: + - name: "small_bf16_linear" + matrix_shapes: + - name: "small_sweep" + min_power: 10 + max_power: 15 + high_precision_dtype: "torch.bfloat16" + torch_compile_mode: "max-autotune" + device: "cuda" + model_type: "linear" +``` + +### 2. Run CI Benchmarks + +Use the CI runner to generate results in PyTorch OSS benchmark database format: + +```bash +python benchmarks/dashboard/ci_microbenchmark_runner.py \ + --config benchmarks/dashboard/microbenchmark_quantization_config.yml \ + --output benchmark_results.json +``` + +### 3. CI Output Format + +The CI runner outputs results in a specific JSON format required by the PyTorch OSS benchmark database: + +```json +[ + { + "benchmark": { + "name": "micro-benchmark api", + "mode": "inference", + "dtype": "int8wo", + "extra_info": { + "device": "cuda", + "arch": "NVIDIA A100-SXM4-80GB" + } + }, + "model": { + "name": "1024-1024-1024", + "type": "micro-benchmark custom layer", + "origins": ["torchao"] + }, + "metric": { + "name": "speedup(wrt bf16)", + "benchmark_values": [1.25], + "target_value": 0.0 + }, + "runners": [], + "dependencies": {} + } +] +``` + +### 4. Integration with CI Pipeline + +To integrate with your CI pipeline, add the benchmark step to your workflow: + +```yaml +# Example GitHub Actions step +- name: Run Microbenchmarks + run: | + python benchmarks/dashboard/ci_microbenchmark_runner.py \ + --config benchmarks/dashboard/microbenchmark_quantization_config.yml \ + --output benchmark_results.json + +- name: Upload Results + # Upload benchmark_results.json to your dashboard system +``` + +## Troubleshooting + +### Running Tests + +To verify your setup and run the test suite: + +```bash +python -m unittest discover benchmarks/microbenchmarks/test +``` + +### Common Issues + +1. **CUDA Out of Memory**: Reduce batch size or matrix dimensions +2. **Missing Quantization Methods**: Ensure TorchAO is properly installed +3. **Device Not Available**: Check device availability and drivers + +### Best Practices + +1. Use `small_sweep` for basic testing, `custom shapes` for comprehensive or model specific analysis +2. Enable profiling only when needed (adds overhead) +3. Test on multiple devices when possible +4. Use consistent naming conventions for reproducibility + +For information on different use-cases for benchmarking, refer to [Benchmarking User Guide](benchmarking_user_guide.md) + +For more detailed information about the framework components, see the README files in the `benchmarks/microbenchmarks/` directory. diff --git a/docs/source/benchmarking_user_guide.md b/docs/source/benchmarking_user_guide.md new file mode 100644 index 0000000000..cff53ab8fd --- /dev/null +++ b/docs/source/benchmarking_user_guide.md @@ -0,0 +1,5 @@ +# Benchmarking User Guide + +This guide is intended to provide instructions for the most fequent benchmarking use-case. If you have any use-case that is not answered here, please create an issue here: [TorchAO Issues](https://github.com/pytorch/ao/issues) + +[Coming Soon !!!] diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index ab6d433e27..353ba754ca 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -4,16 +4,34 @@ Contributor Guide General Guide on Extending torchao ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For a new use case, for example, a training dtype (like fp4 training), it's fine to start with adding a new tensor subclass in prototype folder `torchao/prototype `__, but you could also take a look at ``AffineQuantizedTensor`` if what you want to do is mostly supported there, e.g. adding int3 kernel for the exact same affine quantization. Please feel free to open an issue and if you have questions on what to do for a specific new use case. For more details, please refer to our `quantization overview page `__. +Please start by reading our `quantization overview page `__ first. To contribute to existing code base: -* Adding features to AffineQuantizedTensor, e.g. making it trainable, add tensor parallelism support etc.: `torchao/dtypes/affine_quantized_tensor.py `__ +* Adding a new Tensor: `torchao/quantization/quantize_/workflows `__ * Adding new quantization APIs: `torchao/quantization/quant_api.py `__ +* Adding features to existing Tensor subclasses like ``Float8Tensor``, e.g. adding new operator support, making it trainable, add tensor parallelism support etc., `tensor subclasses `__, `tests `__ * Adding new quantization primitive ops, e.g. slight variations of existing quantization primitive ops: `torchao/quantization/quant_primitives.py `__ * Adding new autotuned triton kernels: `torchao/kernel `__ * Adding new custom cpu/cuda/mps kernels: `torchao/csrc `__ -* Integrating custom kernel with AffineQuantizedTensor (maybe a new layout as well): Add sparse marlin AQT layout `#621 `__ as an example. We are still not decided if we want to split ``AffineQuantizedTensor`` to more tensor subclasses or not. + +Adding New Tensor Subclasses +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +torchao Tensor subclasses are structured by ``derived dtype`` and ``packing format``, please check out the `quantization overview page `__ to understand these concepts. If a new tensor subclass is needed for your use case, i.e. a new dtype, or a new packing format that does not already exist, we could define a new Tensor. + +To understand how to use tensor subclass in the context of quantization, please also check `Writing Your Own Quantized Tensor `__. + +We have utility base class: ``torchao.utils.TorchAOBaseTensor`` that can help define common util functions and methods for you, if you specified the names of Tensor and non-Tensor attributes of the tensor subclass. for example:: + + class MyTensor(TorchAOBaseTensor): + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = ["device", "dtype"] + + +With the above, we'll have multiple methods and functions available to use for this Tensor, for more details please check the docs for `TorchAOBaseTensor `__ + +.. note:: + Many of the existing use cases in torchao still uses AffineQuantizedTensor, but we plan to move away from it to reduce the abstractions and make it easier for people to contribute to torchao. Adding Efficient Kernels ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -31,50 +49,59 @@ Custom hand written kernels ########################### Custom kernels (implementations) for cpu/cuda/mps can be implemented through `torchao/csrc `__ e.g. int4 cuda, and accessible through torch.ops.my_custom_op -Dispatches -~~~~~~~~~~ +Using hand written kernels in Tensor Subclasses +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For calling optimized kernels, we have ``implements`` from the tensor subclass, for example, if we want to call into a new custom op: ``torch.ops.torchao.my_mm_for_mps``:: + + class Float8Tensor(TorchAOBaseTensor): + ... + + implements = Float8Tensor.implements -For dispatching to optimized kernels for cpu/cuda/mps devices, we can have checks for the dispatch conditions in ``__torch_function__`` or ``__torch_dispatch__`` and dispatch to target operators, for example, condition for bfloat16 activation and uint4 weight kernel can be found `here `__. + @implements([torch.nn.functional.linear, aten.linear.default]) + def _(func, types, args, kwargs): + ... + # call into the custom op + res = torch.ops.torchao.my_mm_for_mps(input_tensor.qdata, weight_tensor.qdata, input_tensor.scale, weight_tensor.scale) + return res -Specifically for ``AffineQuantizedTensor``, we also allow people to extend the quantized linear to use a new efficient kernel or implement by defining two functions: -``dispatch_condition`` (defines the condition to dispatch to the kernel) and impl (actual implementation that takes activation, (quantized) weight, bias Tensor and runs the efficient kernel), both taking ``input_tensor``, ``weight_tensor``, ``bias`` as argument, and can be registered into dispatch of quantized linear in ``AffineQuantizedTensor`` with ``register_aqt_quantized_linear_dispatch``. `Here `__ is an example showing how it works. +KernelPreference +################ -Layout/TensorImpl -~~~~~~~~~~~~~~~~~ +For some tensor subclasses, there could be multiple kernel choices for quantize and mm etc. The recommended way to handle this in torchao tensor subclasses is through ``KernelPreference``, that represents which group of kernels we want to use for quantize, mm, group_mm etc. We can use use ``KernelPreference.AUTO`` as default option, as the option for developers to choose whatever we think is the fastest under different conditions for user, so user don't need to worry about the details, and we can have other more specific kernel options for debugging purposes. + +``Float8Tensor`` for example, has: + +* ``KernelPreference.AUTO`` that will choose the most performant quantize and mm kernel based on hardware (H100 SM89 or SM90+), availability of libraries (whether ``fbgemm_gpu_genai`` is installed), granularity (per row or per tensor) +* ``KernelPreference.TORCH`` will use torchao quantize op (``_choose_scale_float8`` and ``_quantize_affine_float8``) and ``_scaled_mm`` +* ``Kerenel.FBGEMM`` uses fbgemm quantize and mm op (``torch.ops.fbgemm.f8f8bf16_rowwise``) -Sometimes the quantized weights has to be packed in order to yield optimal performance. And this can be abstracted with ``layout``. See `here `__ for full example. Flow ~~~~ -After the tensor subclass is implemented, we can also wrap that into factory functions, e.g.:: - # convert from floating point tensor to my dtype tensor subclass - to_my_dtype = MyDTypeTensor.from_float - -For model level API, people can reuse ``torchao.quantize_`` that allows people to apply a tensor subclass conversion to weight of linear, and allows `filtering function `__ to choose which module the tensor subclass conversion should be applied to. +For model level API, people can reuse ``torchao.quantize_`` that allows people to apply a tensor subclass conversion to weight of linear, and allows `filtering function `__ to choose which module the tensor subclass conversion should be applied to. -See Quantization Algorithms/Flows section for examples of weight only/dynamic quant/static quant and other types of model level APIs based on the factory function. +See Quantization Algorithms/Flows section for examples of weight only/dynamic quant and other types of model level APIs. Using torch.compile for Performance ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Note: for pytorch 2.4 and below, we need to use the following:: - from torchao.utils import unwrap_tensor_subclass - m_unwrapped = unwrap_tensor_subclass(m) - In order to be compatible with ``torch.compile``, to aim for performance optimization, we should run through ``torch.compile`` with ``fullgraph=True`` first, and remove any unnecessary graph breaks. You can add ``TORCH_LOGS="output_code"`` when you run the script in order to see the inductor generated code. e.g. ``TORCH_LOGS="output_code" python example.py``:: + model = torch.compile(model, mode="max-autotune", fullgraph=True) Serialization ~~~~~~~~~~~~~ -Please checkout the `serialization doc `__ for more details. +To enable support for serialization (torch.save and torch.load with tensor subclasses as weights), we need to add the tensor subclass and the relevant object to safe globals (available after torch 2.5), e.g.:: + torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs]) -.. note:: - We are integrated with huggingface transformer and supports serialization/deserialization through the huggingface save_pretrained/push_to_hub/from_pretrained APIs: https://huggingface.co/docs/transformers/main/en/quantization/torchao +Please checkout the `serialization doc `__ for more details. .. note:: - Another example can be found in integration with diffuser: https://github.com/sayakpaul/diffusers-torchao/blob/main/inference/serialization_and_loading.md + We are `integrated `__ with huggingface transformer and supports serialization and deserialization through the huggingface ``save_pretrained``, ``push_to_hub`` and ``from_pretrained`` APIs. We also have `serialization examples `__ with diffuser models. Other Feature Support @@ -85,8 +112,6 @@ The above just talks about basic feature support, we also provide examples on ho * `Quantized Training `__ * `Tensor Parallel Support for Quantized Tensor `__ * `Compatibility with executorch / torchchat `__ -* [TODO] FSDP -* [TODO] QAT Tensor Subclass Functionality/Composability Testing @@ -126,11 +151,16 @@ After you have the quantization flow implemented, you can run benchmark and eval Note: llama model (llama2/llama3) is our representative model for memory bound models and sam is our representative model for compute bound models. * `llama `__ + * `benchmark `__ * `eval `__ + * `sam `__ + * `benchmark and eval `__ Please checkout the ``--help`` option for each of the script to understand the supported options, e.g. you can use ``--profile=profile_path`` to get the chrome trace of the run to understand detailed `chrome trace `__. Please let us know if there are any new important models that makes sense to be added to torchao model benchmark/eval folder. + +Please also check out `Benchmarking User Guide `__ and `Benchmarking API Guide `__ to understand how to use our benchmarking framework. diff --git a/docs/source/finetuning.rst b/docs/source/finetuning.rst index 00e2471e7f..69567af5be 100644 --- a/docs/source/finetuning.rst +++ b/docs/source/finetuning.rst @@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values. .. code:: py - from torchao.quantization import ( - quantize_, - ) - from torchao.quantization.qat import ( - FakeQuantizeConfig, - IntXQuantizationAwareTrainingConfig, - ) + from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig + from torchao.quantization.qat import QATConfig + model = get_model() - # prepare: insert fake quantization ops - # swaps `torch.nn.Linear` with `FakeQuantizedLinear` - activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - weight_config = FakeQuantizeConfig(torch.int4, group_size=32) - qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config) - quantize_(model, qat_config) + # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear` + base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) + quantize_(model, QATConfig(base_config, step="prepare")) # fine-tune train_loop(model) @@ -232,18 +225,12 @@ The next step is to actually quantize the model: .. code:: py - from torchao.quantization import ( - Int8DynamicActivationInt4WeightConfig, - ) - from torchao.quantization.qat import ( - FromIntXQuantizationAwareTrainingConfig, - ) + from torchao.quantization import Int8DynamicActivationInt4WeightConfig - # convert: transform fake quantization ops into actual quantized ops - # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts - # quantized activation and weight tensor subclasses - quantize_(model, FromIntXQuantizationAwareTrainingConfig()) - quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)) + # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config` + quantize_(model, QATConfig(base_config, step="convert")) + + # inference or generate Now our model is ready for serving, and will typically have higher quantized accuracy than if we did not apply the prepare step (fake quantization) during @@ -284,10 +271,123 @@ schemes, but these are not customizable unlike the above example. Quantized Low-Rank Adaptation (QLoRA) ##################################### -(Coming soon!) +Low-Rank Adaptation (LoRA) refers to freezing the original model, +and instead training a set of new "adapter" parameters that are a +small fraction of the original parameters, thereby significantly +reducing the memory footprint during training. QLoRA is an extension +of LoRA that additionally quantizes the frozen original model +parameters to 4-bits, thereby further reducing the memory footprint. + +TorchAO offers an implementation of the NF4 data type proposed in +the original `QLoRA paper `__. +This implementation expresses NF4 as a tensor subclass through the +`NF4Tensor `__, +which composes cleanly with other PyTorch features like `torch.compile` +and FSDP2. Users can convert a high precision tensor to NF4 simply +by calling `torchao.dtypes.to_nf4 `__. +For example: + +.. code:: + + class FrozenNF4Linear(nn.Linear): + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + **quantization_kwargs, + ): + super().__init__(in_dim, out_dim, bias=bias, device=device, dtype=dtype) + # No need to train these in QLoRA + self.weight.requires_grad_(False) + if self.bias is not None: + self.bias.requires_grad_(False) + nf4_weight = to_nf4(self.weight, **quantization_kwargs) + self.weight = torch.nn.Parameter(nf4_weight, requires_grad=False) + +QLoRA need not work with NF4 specifically, though NF4 has been +shown to achieve competitive results compared to bf16 baselines +while significantly reducing the memory required for training. +This technique can also compose with other lower bit dtypes +such as regular INT4 or even newer `MXFP4 or NVFP4 `__ +targeting Blackwell GPUs to reap similar memory benefits with +varying tradeoffs. + +Option 1: TorchTune Integration +=============================== + +TorchTune incorporates the `NF4Tensor` in its QLoRA fine-tuning +recipe through their implementation of `LoRALinear `__. +You can also try it out by running the following command, +or refer to their `QLoRA tutorial `__ +for more details. + +.. code:: + + tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device.yaml + +Option 2: HuggingFace PEFT Integration +====================================== + +`HuggingFace PEFT `__ +also has a limited version of QLoRA leveraging TorchAO's INT8 +quantization, though INT4 or NF4 are not supported yet. Users +can invoke this functionality by preparing their models as follows. +For full details, please refer to `this tutorial `__. + +.. code:: + + from peft import LoraConfig, get_peft_model + from transformers import AutoModelForCausalLM, TorchAoConfig + from torchao.quantization import Int8WeightOnlyConfig + + base_model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + quantization_config=TorchAoConfig(Int8WeightOnlyConfig()), + ) + peft_config = LoraConfig() + model = get_peft_model(base_model, peft_config) Float8 Quantized Fine-tuning ############################ -(Coming soon!) +Similar to `pre-training `__, we can also +leverage float8 in fine-tuning for higher training throughput +with no accuracy degradation and no increase in memory usage. +Float8 training is integrated into TorchTune's distributed +full fine-tuning recipe, leveraging the same APIs as our +integration with TorchTitan. Users can invoke this fine-tuning +recipe as follows: + +.. code:: + + tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full + enable_fp8_training=true \ + fp8_recipe_name=tensorwise \ + compile=True + +Initial experiments saw up to 16.5% throughput improvement +for fine-tuning Llama3.2-3B in float8: + +.. code:: + + experiment_name tok/s peak_mem_reserved + ---------------------- ------------------- ------------------- + bf16 6502.143 (+0.000%) 30.090 (+0.000%) + fp8_noname 7205.386 (+10.816%) 30.010 (-0.266%) + fp8_tensorwise 7222.198 (+11.074%) 30.010 (-0.266%) + fp8_rowwise 6387.968 (-1.756%) 29.158 (-3.096%) + fp8_rowwise_with_gw_hp 7573.698 (+16.480%) 29.516 (-1.908%) + + experiment_name hellaswag_acc wikitext_word_perplexity + ---------------------- --------------- -------------------------- + bf16 0.533 (+0.000) 12.407 (+0.000) + fp8_noname 0.533 (+0.000) 12.414 (+0.007) + fp8_tensorwise 0.533 (+0.000) 12.412 (+0.005) + fp8_rowwise 0.533 (-0.000) 12.420 (+0.013) + fp8_rowwise_with_gw_hp 0.534 (+0.001) 12.416 (+0.009) + +Please refer to the `pre-training `__ tutorial for more details. diff --git a/docs/source/index.rst b/docs/source/index.rst index aac72590fd..0a96600b70 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -18,9 +18,11 @@ for an overall introduction to the library and recent highlight and updates. :maxdepth: 1 :caption: Developer Notes - quantization - sparsity + quantization_overview contributor_guide + sparsity + benchmarking_api_guide + benchmarking_user_guide .. toctree:: :glob: @@ -29,8 +31,10 @@ for an overall introduction to the library and recent highlight and updates. api_ref_dtypes api_ref_quantization + api_ref_qat api_ref_sparsity api_ref_float8 + api_ref_utils .. toctree:: :glob: @@ -41,6 +45,7 @@ for an overall introduction to the library and recent highlight and updates. finetuning serving torchao_vllm_integration + torchao_hf_integration serialization static_quantization subclass_basic @@ -55,5 +60,5 @@ for an overall introduction to the library and recent highlight and updates. tutorials_source/pt2e_quant_qat tutorials_source/pt2e_quant_x86_inductor tutorials_source/pt2e_quant_xpu_inductor + tutorials_source/pt2e_quant_openvino_inductor tutorials_source/pt2e_quantizer - tutorials_source/openvino_quantizer diff --git a/docs/source/output.png b/docs/source/output.png new file mode 100644 index 0000000000..cf7ebfeccd Binary files /dev/null and b/docs/source/output.png differ diff --git a/docs/source/pretraining.rst b/docs/source/pretraining.rst index da9659b9a0..2f60719ec5 100644 --- a/docs/source/pretraining.rst +++ b/docs/source/pretraining.rst @@ -161,10 +161,6 @@ Below is a code snippet showing how to use it: from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_linear import Float8Linear from torchao.float8 import convert_to_float8_training - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - - if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = nn.Sequential( diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst deleted file mode 100644 index 929bc1d00c..0000000000 --- a/docs/source/quantization.rst +++ /dev/null @@ -1,243 +0,0 @@ -Quantization Overview ---------------------- - -First we want to lay out the torchao stack:: - - Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. - --------------------------------------------------------------------------------------------- - Quantized Tensors (derived dtypes): AffineQuantizedTensor, CodebookQuantizedTensor - --------------------------------------------------------------------------------------------- - Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize - --------------------------------------------------------------------------------------------- - Basic dtypes: uint1-uint7, int1-int8, float3-float8 - - -Any quantization algorithm will be using some components from the above stack, for example int4 weight-only quantization uses: -(1) weight only quantization flow -(2) `tinygemm bf16 activation + int4 weight kernel `__ and `quant primitive ops `__ -(3) `AffineQuantizedTensor `__ tensor subclass with `TensorCoreTiledLayout `__ -(4) torch.uint4 dtype (simulated with quant_min/quant_max right now) - -Note: we'll also talk about how to compose sparsity with quantization in the Quantized Tensors section - -Basic DTypes -~~~~~~~~~~~~ -`dtype `__ is a bit of overloaded term, by basic dtype, we mean the dtypes that makes sense without any extra metadata (e.g. makes sense when people call ``torch.empty(.., dtype)``), for more details please check out: dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833 - -No matter what quantization we are doing, in the end we will be using some low precision dtypes to represent the quantized data, the dtypes we aim to support in torchao are: - -* ``torch.uint1`` to ``torch.uint8`` available in pytorch 2.3 and later -* ``torch.int1`` to ``torch.int8`` available in pytorch 2.6 and later -* ``torch.float3_e2_m0``, ``torch.float4_e2_m1``, ``torch.float4_e3_m0``, ``torch.float5_e2_m2``, ``torch.float5_e3_m1``, ``torch.float6_e2_m3``, ``torch.float6_e3_m2``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` (float8 is added to torch, we also plan to add float4 and float6 to torch if they become popular) - -Note some of the above are prototype only for now. We'll consider adding then to pytorch core when they become popular and have hardware support. - -Current Support -############### -In terms of actual implementation, there are two parts: -1). In PyTorch, we need to add the dtype to torch.dtype, e.g. torch.uint2, example: pytorch/pytorch#117208, but these are just placeholders so that we can use torch.uint2. -2). Outside of PyTorch (e.g. in torchao), we implement the tensor operations for these dtypes with tensor subclasses, also a standard packing format is needed. - -Adding placeholder dtype in PyTorch -*********************************** - -As mentioned in dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833, the criteria for adding dtype in PyTorch is that it shows wide adoption. For the above mentioned fundamental dtypes, the ones that are supported in PyTorch are: - -* ``torch.uint1`` to ``torch.uint8``, ``torch.int1`` to ``torch.int8``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` - -For the other types we plan to wait until there is more evidence of wide adoption and hardware support. - -Implementing tensor operations for these dtypes with Tensor subclasses -********************************************************************** -For this, the requirement is we decide on a "standard" packing format, and hopefully one that is amenable to efficient implementation, but for both uintx and floatx we haven't integrate enough kernels to decide on this. So current `packing implementations `__ are ont final. We can revisit after there are more uintx, intx and floatx kernels being integrated into torchao. - -Integrate Tensor subclass to pytorch native factory functions -************************************************************* -After that we can connect the factory function with the tensor subclass, for example: ``torch.empty(..., dtype=torch.int4, ...)`` can create a ``Int4Tensor`` tensor subclass with the packing format decided in the previous step. - -Quantization Primitive Ops -~~~~~~~~~~~~~~~~~~~~~~~~~~ -Quantization primitive ops means the operators used to convert between low preicison quantized tensors and high precision tensors. We will mainly have the following quantization primitive operators: -choose_qparams ops: that chooses quantization parameter based on the original Tensor, typically used in dynamic quantization, e.g. scale and zero_point for affine quantization -quantize op: quantizes the original high precision tensor to the low precision tensor with the dtypes mentioned in previous section based on the quantization parameters -dequantize op: dequantizes the low precision tensor into the high precision tensor based on quantization parameters - -There could be variations of the above to accommodate specific use cases, for example for static quantization we may have ``choose_qparams_affine_with_min_max`` that will choose quantization parameters based on min/max values derived from the observation process. - -Efficient kernels -~~~~~~~~~~~~~~~~~ -We'll also have efficient kernels that works with the low precision tensors, for example - -`_weight_int4pack_mm `__ the tinygemm int4 kernel (bf16 activation + int4 weight) -`int_matmul `__ that takes two int8 tensors and outputs an int32 tensor -`int_scaled_matmul `__ that does matmul and also applies a scale to the result. - -Note: We can also rely on torch.compile to generate kernels (through triton), for example the current int8 weight only quantization `kernel `__ just relies on torch.compile to get speedup. In this case there is no specific "efficient kernel" that's corresponding to the type of quantization. - -Quantized Tensors (derived dtypes) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -On top of the basic dtypes, quantization primitive operators and efficient kernels, we can glue everything together and build out a Quantized (low precision) Tensor by subclassing torch.Tensor that can be constructed from a high precision Tensor and some parameters that can configure the specific quantization user wants, we can also call this derived dtypes since it can be represented with Tensors of basic dtypes and some extra metadata like scale. - -Existing example in torchao is ``AffineQuantizedTensor``, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: ``low_precision_val = high_precision_val / scale + zero_point``, where ``scale``/``zero_point`` are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Affine quantization is a very common type of quantization, since it's straightforward that when we try to map from higher precision values to lower precision values, we do an affine transformation (``high_preicsion_val / scale + zero_point``). Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit) is codebook / look up table based quantization. - -Layout and TensorImpl -##################### -Native tensors have a hardcoded list of selections of `layout `__, most common one is strided layout, it provides a strided, multi-dimensional view of storage, we also have some sparse and mkldnn layout. - -Take `sparse COO tensor `__ as an example, it has `torch.sparse_coo` layout, and `SparseTensorImpl `__ which changes how the tensor is stored. - -The idea of packing the tensor into different formats fits nicely with the layout concept, that’s why we want to reuse this for packing. We can use `Layout` for different type of packing format and `TensorImpl` for different storage format implementations. And new TensorImpl that stores the Tensor in a packed format can be added at python level tensor subclasses without modifying C++ pytorch core code. - -For example, for ``_weight_int4pack_mm`` we need to pack the weight to an format that is friendly for Tensor Core, we call it `TensorCoreTiledLayout `__. We add a ``tensor_impl`` for the quantized tensor to store the packed (or unpacked) weight, and we use ``layout`` to store different parameters that's relevant for packing:: - - class AffineQuantizedTensor(...): - # tensor_impl is also implemented with tensor subclass - tensor_impl: torch.Tensor - - # to not conflict with existing layout property, we use `_layout` - @property - def _layout(self) -> Layout: - return self.tensor_impl._layout - -Note that layout is an abstraction not only for custom data representation, it is also used for how the -`TensorImpl` interacts with different operators, e.g. the same data representation can have different -implementations when running the same operator, e.g. transpose, quantized_linear, but the operator semantics should stay the same. - -Quantize + Sparse Tensor can also be supported through the Layout abstraction, for example, `int4 weight only quantization + sparse `__. We also provide some common utils that helps people to add different layouts to a quantized tensor, please check out the developer guide below for code examples. - -Quantization Algorithms/Flows -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -On the top of the stack will be the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up. - -For demonstration purposes, let's say after previous step we have ``AffineQuantizedTensor`` and ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with corresponding dtype. - -Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in ``Tensor Subclass Developer Guide`` section. - -Weight Only Quantization -######################## -This is the simplest form of quantization and it's easy to apply weight only quantization to the model, especially since we have Quantized Tensor. all we need to do is:: - linear_module.weight = torch.nn.Parameter(to_affine_quantized_intx(linear_module.weight, ...), requires_grad=False)) - -apply the above to all linear modules in the model and we'll get a weight only quantized model. - -Dynamic Activation and Weight Quantization -########################################## - -This is called "dynamic quantization" before but it means we quantize activation dynamically at runtime, and also quantize the weights as well. Compared to the weight only quantization, the main question is how do we apply the quantization to activation. In torchao, the common pattern we use is by applying ``to_linear_activation_quantized`` on top of quantized weight:: - quantized_weight = to_affine_quantized(linear_module.weight) - activation_and_weight_quantized = to_linear_activation_quantized(quantized_weight) - linear_module.weight = torch.nn.Parameter(activation_and_weight_quantized, requires_grad=False)) - -``to_linear_activation_quantized`` is used to apply quantization to activation, it takes a ``input_quant_func`` that will quantize the activation and the original weight, and during runtime when it encounters a ``F.linear`` op, it will apply the stored input_qunat_func to activation and redispatch to ``F.linear`` with quantized activation and weight. - -If the above does not work, user can also do module swaps, or use ``torch.fx.symbolic_trace()`` to get a traced module that you can `modify `__. - -But using tensor subclass is preferred because it is easier for serialization/deserialization, if we use tensor subclasses to support dynamic quantization, then we can load the quantized weights directly without further preparation for the model. Otherwise, we'd need to do module swap or other modifications to the model first before loading the quantized weights. - -Static Activation Quantization and Weight Quantization -###################################################### -Static quantization means activation is statically quantized instead of dynamically quantized at runtime. In terms of flow, static quantization requires calibration with sample data in order that we can figure out the appropriate quantization parameters. - -At the high level there are three steps for static quantization: (1) insert observers (2) calibration (3) quantize the model - - -Insert Observers -**************** -In insert observers step, we need to add observer modules to input (and output) activation and weight of the operator to collect statistics of the Tensor. So there are two things we need to address, how to define observer module? how to add observer module to the model. - -How to define observer module -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Observers are specific to: (1) type of quantization (e.g. affine quantization, look up table based quantization) (2) type of stats we want to track, e.g. min max observer, moving average observer. - -Generally an observer module should define `forward `__ and `calculate_qparams `__ - -For affine quantization, we defined `AffineQuantizedMinMaxObserver `__ that records min_val/max_val based on the granularity of affine quantization, and also defines how to calculate_qparams based on the recorded stats. - -How to add observer module to the model -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -1. Use Tensor Subclasses - If the only operator you are interested in quantizing is linear, you can use `linear activation weight observer `__, we also have a corresponding `insert_observer_ `__ API that handles modifying the weight of linear. - -2. Module Swap - Alternatively, you could also define and `ObservedLinear `__ module (or other module types) and swap the non observed with the observed module - -Calibration -^^^^^^^^^^^ -Calibration step is typically straightforward, typically we just need to run the model through the calibration dataset. For more complicated calibration (e.g. where we record all inputs and do optimizations based on all inputs), we'll cover some of them in next section. - -Quantize -^^^^^^^^ -We can reuse the ``quantize_`` API but provide a different ``apply_tensor_subclass`` function that converts the observed linear module to a linear module with quantized weight and statically quantized input activation, this can be done in the same manner as the dynamic quantization (with ``to_linear_activation_quantized``), see `example `__. - -Alternatively, user can do `module swap `__ as well. - -Other Quantization Flows -######################## - -For other quantization flow/algorithms that does not fit into any of the above, we also intend to provide examples for common patterns. For example, `GPTQ like quantization flow `__ that is adopted by `Autoround `__, it uses `MultiTensor `__ and module hooks to optimize the module. - -If you are working on a new quantization algorithm/flow and not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details. - -Training -######## -The above flow are mainly focused on inference, but low bit dtype Tensors can be used in training as well. - -Quantization Aware Training -*************************** -TODO - - -Low Bit Optimizers -****************** -Today we have some prototype low bit optimizers: `main/torchao/prototype/low_bit_optim `__ that implements a specific type of 4 bit, 8 bit and float8, and is also composable with FSDP (with look up table quantization). - -Quantized Training -****************** -Similar to low bit optimizers, we have quantized training prototype in `main/torchao/prototype/quantized_training `__, and we could extend AffineQuantizedTensor to support training as well, initial enablement is in progress, but there will be a lot of follow up work needed including making it work for different kernels etc. - -You can also checkout the tutorial for `Quantized Training `__ that talks about how to make a dtype tensor subclass trainable. - -Case Study: How int4 weight only quantization works in torchao? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To connect everything together, here is a more detailed walk through for how int4 weight only quantization is implemented in torchao. - -Quantization Flow: quantize_(model, Int4WeightOnlyConfig()) - * What happens: linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight), requires_grad=False) - * quantization primitive ops: choose_qparams and quantize_affine are called to quantize the Tensor - * quantized Tensor will be `AffineQuantizedTensor`, a quantized tensor with derived dtype (e.g. int4 with scale and zero_point) - * packing op `_convert_weight_to_int4pack` to pack the quantized weight for efficient execution - -During Model Execution: model(input) - * `torch.ops.aten._weight_int4pack_mm` is called on input and the packed weight - -During Quantization -################### -First we start with the API call: ``quantize_(model, Int4WeightOnlyConfig())`` what this does is it converts the weights of nn.Linear modules in the model to int4 quantized tensor (``AffineQuantizedTensor`` that is int4 dtype, asymmetric, per group quantized), using the layout for tinygemm kernel: ``tensor_core_tiled`` layout. - -* `quantize_ `__: the model level API that quantizes the weight of linear by applying the conversion function from user (second argument) -* `Int4WeightOnlyConfig `__: the function that returns a function that converts weight of linear to int4 weight only quantized weight - * Calls quantization primitives ops like choose_qparams_affine and quantize_affine to quantize the Tensor -* `TensorCoreTiledLayout `__: the tensor core tiled layout type, storing parameters for the packing format -* `TensorCoreTiledAQTTensorImpl `__: the tensor core tiled TensorImpl, stores the packed weight for efficient int4 weight only kernel (tinygemm kernel) - -During Model Execution -###################### - -When we run the quantized model ``model(inputs)``, we'll run through the functional linear operator in nn.Linear:: - - return F.linear(input, weight, bias) - -where input is a ``bfloat16`` Tensor, weight is an int4 ``AffineQuantizedTensor``, it calls into a ``__torch_function__`` of the ``AffineQuantizedTensor`` subclass, which will end up in an implementation for ``F.linear`` when one of the input is ``AffineQuantizedTensor``, so it calls:: - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - -The ``_quantized_linear_op`` goes through the ``_AQT_QLINEAR_DISPATCH_TABLE`` and checks each dispatch conditions, if the dispatch condition passes, it will call the implementation with ``input``/``weight``/``bias``. Please check out `this doc `__ for the explanation of ``dispatch_condition`` and ``impl``. - -int4 weight only `dispatch_condition `__ checks if the input is ``bfloat16`` Tensor and weight is a uint4 ``AffineQuantizedTensor`` -wint4 weight only quantization `kernel implementation `__ takes an bfloat16 input Tensor and an int4 AffineQuantizedTensor, and call ``torch.ops.aten._weight_int4pack_mm`` with the input Tensor and the packed weight that's stored in ``weight_tensor.tensor_impl``. - -During Save/Load -################ - -Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. - - diff --git a/docs/source/quantization_overview.rst b/docs/source/quantization_overview.rst new file mode 100644 index 0000000000..f5c82bfe5f --- /dev/null +++ b/docs/source/quantization_overview.rst @@ -0,0 +1,230 @@ +Quantization Overview +--------------------- + +First we want to lay out the torchao stack:: + + Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. + --------------------------------------------------------------------------------------------- + Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor + --------------------------------------------------------------------------------------------- + Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize + --------------------------------------------------------------------------------------------- + Basic dtypes: uint1-uint7, int1-int8, float3-float8 + + +Any quantization algorithm will be using some components from the above stack, for example per row float8 dynamic activation and float8 weight quantization (with default preference) uses: + +* dynamic quantization flow +* `Float8Tensor `__ +* `float8 activation + float8 weight fbgemm kernel `__ and `triton quant primitive ops from fbgemm library `__ +* ``torch.float8_e4m3fn`` dtype + +Basic DTypes +~~~~~~~~~~~~ +`dtype `__ is a bit of overloaded term, by basic dtype, we mean the dtypes that makes sense without any extra metadata (e.g. makes sense when people call ``torch.empty(.., dtype)``), for more details please check out `this post `__. + +No matter what quantization we are doing, in the end we will be using some low precision dtypes to represent the quantized data or quantization parameters, the low precision dtypes relevant for torchao are: + +* ``torch.uint1`` to ``torch.uint7`` available in pytorch 2.3 and later +* ``torch.int1`` to ``torch.int7`` available in pytorch 2.6 and later +* ``torch.float4_e2m1fn_x2``, ``torch.float8_e4m3fn``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2``, ``torch.float8_e5m2fnuz``, ``torch.float8_e8m0fnu`` + +In terms of actual implementation, ``uint1`` to ``uint7`` and ``int1`` to ``int7`` are just placeholders that does not have real implementations (i.e. the ops does not work for the PyTorch Tensor with these dtypes). Example PR added these dtypes can be found `here `__. Floating point dtypes are what we call shell dtypes that have limited op support. + +For more details please check out the `official PyTorch dtype doc `__. + +.. note:: + Dervied dtypes like mxfp8, mxfp4, nvfp4 are implemented with these basic dtypes, e.g. mxfp4 uses ``torch.float8_e8m0fnu`` for scale and ``torch.float4_e2m1fn_x2`` for 4 bit data. + +Quantization Primitive Ops +~~~~~~~~~~~~~~~~~~~~~~~~~~ +Quantization primitive ops means the operators used to convert between low preicison quantized tensors and high precision tensors. We will mainly have the following quantization primitive operators: + +* choose_qparams ops: that chooses quantization parameter based on the original Tensor, typically used in dynamic quantization, e.g. scale and zero_point for affine quantization +* quantize op: quantizes the original high precision tensor to the low precision tensor with the dtypes mentioned in previous section based on the quantization parameters +* dequantize op: dequantizes the low precision tensor into the high precision tensor based on quantization parameters + +There could be variations of the above to accommodate specific use cases, for example for static quantization we may have ``choose_qparams_affine_with_min_max`` that will choose quantization parameters based on min/max values derived from the observation process. + +There could be multiple versions of the op that is different by different kernel libraries that we can use in torchao, for example, for quantizing a bfloat16 Tensor to a raw float8 Tensor and scale: `_choose_scale_float8 `__ and `_quantize_affine_float8 `__ for torchao implementation, and `torch.ops.triton.quantize_fp8_row `__ from fbgemm library. + +Efficient kernels +~~~~~~~~~~~~~~~~~ +We'll also have efficient kernels that works with the low precision tensors, for example: + +* `torch.ops.fbgemm.f8f8bf16_rowwise `__ (rowwise float8 activation and float8 weight matrix multiplication kernel in fbgemm library) +* `torch._scaled_mm `__ (float8 activation and float8 weight matrix multiplication kernel in PyTorch for both rowwise and tensorwise) +* `int_matmul `__ that takes two int8 tensors and outputs an int32 tensor +* `int_scaled_matmul `__ that does matmul and also applies a scale to the result. + +.. note:: + We can also rely on torch.compile to generate kernels (through triton), for example the current int8 weight only quantization `kernel `__ just relies on torch.compile to get speedup. In this case there is no custom handwritten "efficient kernel" that's corresponding to the type of quantization. + +Quantized Tensors (derived dtypes and packing format) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On top of the basic dtypes, quantization primitive operators and efficient kernels, we can glue everything together and build out a Quantized (low precision) Tensor by subclassing torch.Tensor that can be constructed from a high precision Tensor and some parameters that can configure the specific quantization user wants, we can also call this derived dtypes since it can be represented with Tensors of basic dtypes and some extra metadata like scale. + +Another dimension for quantized Tensor is packing format, meaning how the quantized raw data is laid out in memory. For example, for int4, we can pack two elements together side by side in a uint8 value, or people can do some preshuffling/swizzling to make the format more efficient for memory operations (loading from memory to register) and computation. + +So in general we structure Tensor subclasses by dervied dtpype and packing format: + +.. list-table:: Tensor Subclasses in TorchAO + :widths: 20 10 30 40 + :header-rows: 1 + + * - Tensor + - Derived Dtype + - Packing Format + - Support + * - Float8Tensor + - scaled float8 + - plain (no packing needed) + - float8 act + float8 weight dynamic quantization and float8 weight only quantization + * - Int4Tensor + - scaled int4 + - plain (pack 2 adjacent int4 to a single int8 value) + - int4 weight only quantization + * - Int4PreshuffledTensor + - scaled int4 + - preshuffled (special format to optimize for loading) + - float8 act + int4 weight dynamic quantization and int4 weight only quantization + +.. note:: + We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options. + +.. note:: + We also don't use dynamic activation in the name, since we are talking about the weight tensor object, including information about activation in the tensor subclass name will be confusing, but + we do implement both weight only and dynamic activation quantization in the same linear function implementation, without relying on additional abstractions, this keeps relevant quantization operations close + to each other (quantization of activation and weight) in the same tensor subclass. + +In terms of how we quantize a Tensor, most of Tensors are using affine quantization, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: ``low_precision_val = high_precision_val / scale + zero_point``, where ``scale`` and ``zero_point`` are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit) is codebook / look up table based quantization where the raw quantized data is the index we can use to look up a ``codebook`` that stores the values or vectors each index corresponds to. A common way to get the codebook and the raw quantized data for codebook quantization is kmeans clustering. + +Quantization Algorithms/Flows +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On the top of the stack will be the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up. + +For demonstration purposes, let's say after previous step we have ``Float8Tensor`` defined. ``Float8Tensor.from_hp`` takes a high precision floating point Tensor and a target_dtype (e.g ``torch.float8_e4m3fn``) and converts it to a ``Float8Tensor`` + +Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in `Contributor Guide `__. + +Weight Only Quantization +######################## +This is the simplest form of quantization and it's easy to apply weight only quantization to the model, especially since we have Quantized Tensor. all we need to do is:: + + linear_module.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear_module.weight, ...), requires_grad=False)) + +apply the above to all linear modules in the model and we'll get a weight only quantized model. + +Dynamic Activation and Weight Quantization +########################################## + +This is called "dynamic quantization" before but it means we quantize activation dynamically at runtime, and also quantize the weights as well. Compared to the weight only quantization, the main question is how do we apply the quantization to activation. In torchao we pass around the quantization keyword args for activation and the keyword args will be applied to activation when needed (e.g. in linear):: + + activation_dtype = torch.float8_e4m3fn + activation_granularity = PerRow() + # define kwargs for float8 activation quantization + act_quant_kwargs = QuantizeTensorToFloat8Kwargs( + activation_dtype, + activation_granularity, + ) + weight_dtype = torch.float8_e4m3fn + weight_granularity = PerRow() + quantized_weight = Float8Tensor.from_hp(linear_module.weight, float8_dtype=weight_dtype, granularity=weight_granularity, act_quant_kwargs=act_quant_kwargs) + linear_module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)) + +Static Activation Quantization and Weight Quantization +###################################################### +We'll skip the instruction for now since we haven't seen many use cases for static quantization with tensor subclass based flow, we recommend to look into the `PT2 export quantization flow `__ for static quantization. + +Other Quantization Flows +######################## + +For other quantization flow/algorithms that does not fit into any of the above, we also intend to provide examples for common patterns. For example, `GPTQ like quantization flow `__ that is adopted by `Autoround `__, it uses `MultiTensor `__ and module hooks to optimize the module. + +If you are working on a new quantization algorithm/flow and not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details. + +Training +######## +The above flow are mainly focused on inference, but low bit dtype Tensors can be used in training as well. + +User facing docs for float8 training can be found `here `__ and docs for finetuning can be found `here `__ + +Quantization Aware Training +*************************** +TorchAO supports `quantization aware training `__ through the `quantize_` API as well. + + +Low Bit Optimizers +****************** +We support `low bit optimizers `__ that implements a specific type of 4 bit, 8 bit and float8, and is also composable with FSDP (with look up table quantization). + +Quantized Training +****************** +We have quantized training prototype in `main/torchao/prototype/quantized_training `__, and we could extend existing tensor subclasses to support training as well, initial enablement is in progress, but there will be a lot of follow up work needed including making it work for different kernels etc. + +You can also checkout the tutorial for `Quantized Training `__ that talks about how to make a dtype tensor subclass trainable. + +Case Study: How float8 dynamic activation and float8 weight quantization works in torchao? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +To connect everything together, here is a more detailed walk through for float8 dynamic activation and float8 weight quantization in torchao (DEFAULT kernel preference, in H100, when fbgemm_gpu_genai library is installed): + +Quantization Flow: ``quantize_(model, Float8DynamicActivationFloat8WeightConfig())`` + * What happens: ``linear.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear.weight), requires_grad=False)`` + * quantization primitive ops: ``torch.ops.triton.quantize_fp8_row`` + * quantized Tensor will be ``Float8Tensor``, a quantized tensor with derived dtype of scaled float8 + +During Model Execution: model(input) + * ``torch.ops.fbgemm.f8f8bf16_rowwise`` is called on input, raw float8 weight and scale + +During Quantization +################### +First we start with the API call: ``quantize_(model, Float8DynamicActivationFloat8WeightConfig())`` what this does is it converts the weights of nn.Linear modules in the model to ``Float8Tensor``, with plain packing format, no packing is required, since we have ``torch.float8_e4m3fn`` that can represent quantized float8 raw data directly without additional operations. + +* `quantize_ `__: the model level API that quantizes the weight of linear by applying the config from user (second argument) +* `Float8DynamicActivationFloat8WeightConfig `__: the config for float8 dynamic activation and float8 weight quantization + * Calls quantization primitives ops ``torch.ops.triton.quantize_fp8_row`` to quantize a bfloat16 Tensor to float8 raw Tensor and get a scale + + +During Model Execution +###################### + +When we run the quantized model ``model(inputs)``, we'll run through the functional linear operator in nn.Linear:: + + return F.linear(input, weight, bias) + +where input is a ``bfloat16`` Tensor, weight is a ``Float8Tensor``, it calls into a ``__torch_function__`` of the ``Float8Tensor`` subclass, which will end up in an implementation for ``F.linear`` when one of the `input `__ is ``Float8Tensor``:: + + @implements([torch.nn.functional.linear, aten.linear.default]) + def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + # quantizing activation, if `act_quant_kwargs` is specified + if act_quant_kwargs is not None: + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, act_quant_kwargs + ) + + # omitting kernel_preference related code + # granularity checks, let's say we are doing rowwise quant + # both input_tensor and weight_tensor will now be Float8Tensor + xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1]) + wq = weight_tensor.qdata.contiguous() + x_scale = input_tensor.scale + w_scale = weight_tensor.scale + res = torch.ops.fbgemm.f8f8bf16_rowwise( + xq, + wq, + x_scale, + w_scale, + ).reshape(out_shape) + return res + +The function first quantizes the input to be ``Float8Tensor``, then get the raw float Tensor and scale from both the input and weight Tensor: ``t.qdata``, ``t.scale``, and calls the fbgemm kernel to do the matrix multiplication for float8 dynamic quantization: ``torch.ops.fbgemm.f8f8bf16_rowwise``. + +During Save/Load +################ + +Since ``Float8Tensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst index 2bd0744d0c..52947b7622 100644 --- a/docs/source/quick_start.rst +++ b/docs/source/quick_start.rst @@ -57,7 +57,7 @@ for efficient mixed dtype matrix multiplication: # torch 2.4+ only from torchao.quantization import Int4WeightOnlyConfig, quantize_ - quantize_(model, Int4WeightOnlyConfig(group_size=32)) + quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1)) The quantized model is now ready to use! Note that the quantization logic is inserted through tensor subclasses, so there is no change @@ -95,16 +95,10 @@ it is also much faster! .. code:: py from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, benchmark_model, unwrap_tensor_subclass, ) - # Temporary workaround for tensor subclass + torch.compile - # Only needed for torch version < 2.5 - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - num_runs = 100 torch._dynamo.reset() example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) @@ -191,16 +185,16 @@ Please follow these tutorials to get started on PyTorch 2 Export Quantization: Modeling Users: -- `PyTorch 2 Export Post Training Quantization `_ -- `PyTorch 2 Export Quantization Aware Training `_ -- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `_ -- `PyTorch 2 Export Post Training Quantization with XPU Backend through Inductor `_ -- `PyTorch 2 Export Quantization for OpenVINO torch.compile Backend `_ +- `PyTorch 2 Export Post Training Quantization `__ +- `PyTorch 2 Export Quantization Aware Training `__ +- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `__ +- `PyTorch 2 Export Post Training Quantization with XPU Backend through Inductor `__ +- `PyTorch 2 Export Quantization for OpenVINO torch.compile Backend `__ Backend Developers (please check out all Modeling Users docs as well): -- `How to Write a Quantizer for PyTorch 2 Export Quantization `_ +- `How to Write a Quantizer for PyTorch 2 Export Quantization `_ Next Steps @@ -210,7 +204,7 @@ In this quick start guide, we learned how to quantize a simple model with torchao. To learn more about the different workflows supported in torchao, see our main `README `__. For a more detailed overview of quantization in torchao, visit -`this page `__. +`this page `__. Finally, if you would like to contribute to torchao, don't forget to check out our `contributor guide `__ and our list of diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 5e0c42f901..64818f53ef 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -7,7 +7,7 @@ Serialization and deserialization flow ====================================== Here is the serialization and deserialization flow:: - + import copy import tempfile import torch @@ -36,7 +36,7 @@ Here is the serialization and deserialization flow:: print(f"original model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB") example_inputs = m.example_inputs(dtype=dtype, device="cuda") - quantize_(m, Int4WeightOnlyConfig()) + quantize_(m, Int4WeightOnlyConfig(version=1)) print(f"quantized model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB") ref = m(*example_inputs) @@ -62,7 +62,7 @@ What happens when serializing an optimized model? To serialize an optimized model, we just need to call ``torch.save(m.state_dict(), f)``, because in torchao, we use tensor subclass to represent different dtypes or support different optimization techniques like quantization and sparsity. So after optimization, the only thing change is the weight Tensor is changed to an optimized weight Tensor, and the model structure is not changed at all. For example: original floating point model ``state_dict``:: - + {"linear1.weight": float_weight1, "linear2.weight": float_weight2} quantized model ``state_dict``:: @@ -75,7 +75,7 @@ The size of the quantized model is typically going to be smaller to the original original model size: 4.0 MB quantized model size: 1.0625 MB - + What happens when deserializing an optimized model? =================================================== To deserialize an optimized model, we can initialize the floating point model in `meta `__ device and then load the optimized ``state_dict`` with ``assign=True`` using `model.load_state_dict `__:: @@ -97,5 +97,3 @@ We can also verify that the weight is properly loaded by checking the type of we type of weight before loading: (, ) type of weight after loading: (, ) - - diff --git a/docs/source/serving.rst b/docs/source/serving.rst index cb61b159c4..d95132ded7 100644 --- a/docs/source/serving.rst +++ b/docs/source/serving.rst @@ -1,12 +1,410 @@ (Part 3) Serving on vLLM, SGLang, ExecuTorch ------------------------------------------------- +============================================ -TorchAO provides an end-to-end pre-training, fine-tuning, and serving -model optimization flow by leveraging our quantization and sparsity -techniques integrated into our partner frameworks. This is part 3 of 3 -such tutorials showcasing this end-to-end flow, focusing on the -serving step. +TorchAO provides an end-to-end pre-training, fine-tuning, and serving model optimization flow by leveraging our quantization and sparsity techniques integrated into our partner frameworks. This is part 3 of 3 such tutorials showcasing this end-to-end flow, focusing on the serving step. .. image:: ../static/e2e_flow_part3.png +This tutorial demonstrates how to perform post-training quantization and deploy models for inference using torchao as the underlying optimization engine, seamlessly integrated through HuggingFace Transformers, vLLM, and ExecuTorch. + +.. contents:: + :local: + :depth: 2 + +Post-training Quantization with HuggingFace +------------------------------------------- + +HuggingFace Transformers provides seamless integration with torchao quantization. The ``TorchAoConfig`` automatically applies torchao's optimized quantization algorithms during model loading. +Please check out our `HF Integration Docs `_ for examples on how to use quantization and sparsity in Transformers and Diffusers. + +Serving and Inference +-------------------- + +Serving and Inference with vLLM +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +vLLM automatically leverages torchao's optimized kernels when serving quantized models, providing significant throughput improvements. + +First, install vLLM with torchao support: + +.. code-block:: bash + + pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly + pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + +To serve in vLLM, we're using the model we quantized and pushed to Hugging Face hub in the previous step :ref:`Post-training Quantization with HuggingFace`. + +.. code-block:: bash + + # Server + vllm serve pytorch/Phi-4-mini-instruct-float8dq --tokenizer microsoft/Phi-4-mini-instruct -O3 + + # Client + curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "pytorch/Phi-4-mini-instruct-float8dq", + "messages": [ + {"role": "user", "content": "Give me a short introduction to large language models."} + ], + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_tokens": 32768 + }' + +Serving a float8 dynamic quantized model with vLLM shows 36% VRAM reduction, 1.15x-1.2x inference speedup and little to no accuracy impact on H100. :ref:`Memory Benchmarking` and :ref:`Performance Benchmarking` for more details. + +.. note:: + For more information on vLLM Integration, please refer to the detailed guide :ref:`torchao_vllm_integration`. + +Serving and Inference with SGLang +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + (Coming soon!) + +Inference with Transformers +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Install the required packages: + +.. code-block:: bash + + pip install git+https://github.com/huggingface/transformers@main + pip install torchao + pip install torch + pip install accelerate + +.. code-block:: python + + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline + + torch.random.manual_seed(0) + + model_path = "pytorch/Phi-4-mini-instruct-float8dq" + + model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map="auto", + dtype="auto", + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + messages = [ + {"role": "system", "content": "You are a helpful AI assistant."}, + {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"}, + {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."}, + {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"}, + ] + + pipe = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + ) + + generation_args = { + "max_new_tokens": 500, + "return_full_text": False, + "temperature": 0.0, + "do_sample": False, + } + + output = pipe(messages, **generation_args) + print(output[0]['generated_text']) + +Mobile Deployment with ExecuTorch +-------------------------------- + +ExecuTorch enables on-device inference using torchao's mobile-optimized quantization schemes. The 8da4w (8-bit dynamic activation, 4-bit weight) configuration is specifically designed for mobile deployment. Optionally, before lowering to ExecuTorch, we can finetune a model using QAT :doc:`finetuning`, which has demonstrated some improvements in the quality of quantized models. + +[Optional] Untie Embedding Weights +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Optionally, we can quantize the embedding and lm_head differently, since those layers are tied, we first need to untie the model: + +.. code-block:: python + + from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, + ) + import torch + from transformers.modeling_utils import find_tied_parameters + + model_id = "microsoft/Phi-4-mini-instruct" + untied_model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + print(untied_model) + print("tied weights:", find_tied_parameters(untied_model)) + if getattr(untied_model.config.get_text_config(decoder=True), "tie_word_embeddings"): + setattr(untied_model.config.get_text_config(decoder=True), "tie_word_embeddings", False) + + untied_model._tied_weights_keys = [] + untied_model.lm_head.weight = torch.nn.Parameter(untied_model.lm_head.weight.clone()) + + print("tied weights:", find_tied_parameters(untied_model)) + + USER_ID = "YOUR_USER_ID" + MODEL_NAME = model_id.split("/")[-1] + save_to = f"{USER_ID}/{MODEL_NAME}-untied-weights" + + untied_model.push_to_hub(save_to) + tokenizer.push_to_hub(save_to) + + # or save locally + save_to_local_path = f"{MODEL_NAME}-untied-weights" + untied_model.save_pretrained(save_to_local_path) + tokenizer.save_pretrained(save_to) + +Step 1: Create Mobile-Optimized Quantization +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Quantizing the model for mobile deployment using TorchAO's ``Int8DynamicActivationIntxWeightConfig`` configuration. If we've untied the embedding and lm_head following the previous step, we can quantize embedding using ``IntxWeightOnlyConfig`` configuration, and lm_head using ``Int8DynamicActivationIntxWeightConfig`` configuration. + +.. code-block:: python + + from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, + TorchAoConfig, + ) + from torchao.quantization.quant_api import ( + IntxWeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + ModuleFqnToConfig, + quantize_, + ) + from torchao.quantization.granularity import PerGroup, PerAxis + import torch + + # we start from the model with untied weights + model_id = "microsoft/Phi-4-mini-instruct" + USER_ID = "YOUR_USER_ID" + MODEL_NAME = model_id.split("/")[-1] + untied_model_id = f"{USER_ID}/{MODEL_NAME}-untied-weights" + untied_model_local_path = f"{MODEL_NAME}-untied-weights" + + # embedding_config is required only if we untied the embedding and lm_head in the previous step, else we can use only linear config for quantization + embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=PerAxis(0), + ) + linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(32), + weight_scale_dtype=torch.bfloat16, + ) + quant_config = ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config}) + quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[]) + + # either use `untied_model_id` or `untied_model_local_path` + quantized_model = AutoModelForCausalLM.from_pretrained(untied_model_id, dtype=torch.float32, device_map="auto", quantization_config=quantization_config) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # Push to hub + MODEL_NAME = model_id.split("/")[-1] + save_to = f"{USER_ID}/{MODEL_NAME}-8da4w" + quantized_model.push_to_hub(save_to, safe_serialization=False) + tokenizer.push_to_hub(save_to) + + +Step 2: Export to ExecuTorch +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Convert the quantized model to .pte file, which can be run on mobile device. + +.. code-block:: bash + + # Install ExecuTorch + git clone https://github.com/pytorch/executorch.git + cd executorch + ./install_requirements.sh + + # Convert checkpoint format for ExecuTorch + python -m executorch.examples.models.phi_4_mini.convert_weights pytorch_model.bin pytorch_model_converted.bin + + # Export to PTE format with torchao optimizations preserved + PARAMS="executorch/examples/models/phi_4_mini/config.json" + python -m executorch.examples.models.llama.export_llama \ + --model "phi_4_mini" \ + --checkpoint "pytorch_model_converted.bin" \ + --params "$PARAMS" \ + -kv \ + --use_sdpa_with_kv_cache \ + -X \ + --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \ + --max_seq_length 128 \ + --max_context_length 128 \ + --output_name="phi4-mini-8da4w.pte" + +The .pte file can be run with ExecuTorch on a mobile phone. Follow the `instructions `_ for doing this on an iOS device. + +Mobile Performance Characteristics +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The torchao-optimized 8da4w model provides: + +- **Memory**: ~3.2GB on iPhone 15 Pro +- **Speed**: ~17 tokens/sec on iPhone 15 Pro +- **Accuracy**: Maintained within 5-10% of original model on most benchmarks + +.. note:: + For detailed instructions on testing the ExecuTorch model and reproducing benchmarks please refer to the `HF Phi-4-mini-instruct-8da4w model `_. + +Evaluation +--------- + +Model Quality Assessment +^^^^^^^^^^^^^^^^^^^^^^ + +Evaluate quantized models using lm-evaluation-harness: + +.. code-block:: bash + + # Install evaluation framework + # Need to install lm-eval from source: https://github.com/EleutherAI/lm-evaluation-harness#install + + # Evaluate baseline model + lm_eval --model hf --model_args pretrained=microsoft/Phi-4-mini-instruct --tasks hellaswag --device cuda:0 --batch_size 8 + + # Evaluate torchao-quantized model (float8dq) + lm_eval --model hf --model_args pretrained=pytorch/Phi-4-mini-instruct-float8dq --tasks hellaswag --device cuda:0 --batch_size 8 + +Memory Benchmarking +^^^^^^^^^^^^^^^^^ +For Phi-4-mini-instruct, when quantized with float8 dynamic quant, we can reduce the peak memory usage by 36% compared to the baseline model. + +.. code-block:: python + + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + # use "microsoft/Phi-4-mini-instruct" or "pytorch/Phi-4-mini-instruct-float8dq" + model_id = "pytorch/Phi-4-mini-instruct-float8dq" + quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + torch.cuda.reset_peak_memory_stats() + + prompt = "Hey, are you conscious? Can you talk to me?" + messages = [ + { + "role": "system", + "content": "", + }, + {"role": "user", "content": prompt}, + ] + templated_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + print("Prompt:", prompt) + print("Templated prompt:", templated_prompt) + inputs = tokenizer( + templated_prompt, + return_tensors="pt", + ).to("cuda") + generated_ids = quantized_model.generate(**inputs, max_new_tokens=128) + output_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print("Response:", output_text[0][len(prompt):]) + + mem = torch.cuda.max_memory_reserved() / 1e9 + print(f"Peak Memory Usage: {mem:.02f} GB") + +Output: + +.. code:: console + + Prompt: Hey, are you conscious? Can you talk to me? + Templated prompt: <|system|><|end|><|user|>Hey, are you conscious? Can you talk to me?<|end|><|assistant|> + Response: Hello! Yes, I am a digital assistant, and I am fully operational and ready to assist you. How can I help you today? + Peak Memory Usage: 5.70 GB + ++-------------------+---------------------+------------------------------+ +| Benchmark | Phi-4 mini-instruct | Phi-4-mini-instruct-float8dq | ++===================+=====================+==============================+ +| Peak Memory (GB) | 8.91 | 5.70 (36% reduction) | ++-------------------+---------------------+------------------------------+ + +Performance Benchmarking +^^^^^^^^^^^^^^^^^^^^^^ + +Latency Benchmarking +""""""""""""""""""" + +.. code-block:: bash + + # baseline + python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model microsoft/Phi-4-mini-instruct --batch-size 1 + + # float8dq + VLLM_DISABLE_COMPILE_CACHE=1 python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model pytorch/Phi-4-mini-instruct-float8dq --batch-size 1 + +Serving Benchmarking +""""""""""""""""""""" + +We benchmarked the throughput in a serving environment. + +.. code-block:: bash + + # Setup: Get vllm source code + git clone git@github.com:vllm-project/vllm.git + + # Install vllm + VLLM_USE_PRECOMPILED=1 pip install --editable . + + # Run the benchmarks under vllm root folder: + + # Download sharegpt dataset: + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + + # Other datasets can be found in: https://github.com/vllm-project/vllm/tree/main/benchmarks + # Note: you can change the number of prompts to be benchmarked with --num-prompts argument for benchmark_serving script. + + # For baseline + # Server: + vllm serve microsoft/Phi-4-mini-instruct --tokenizer microsoft/Phi-4-mini-instruct -O3 + # Client: + python benchmarks/benchmark_serving.py --backend vllm --dataset-name sharegpt --tokenizer microsoft/Phi-4-mini-instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --model microsoft/Phi-4-mini-instruct --num-prompts 1 + + # For float8dq + # Server: + VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Phi-4-mini-instruct-float8dq --tokenizer microsoft/Phi-4-mini-instruct -O3 + # Client: + python benchmarks/benchmark_serving.py --backend vllm --dataset-name sharegpt --tokenizer microsoft/Phi-4-mini-instruct --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --model pytorch/Phi-4-mini-instruct-float8dq --num-prompts 1 + +Results (H100 machine) +""""""""""""""""""""" + ++----------------------------+---------------------+------------------------------+ +| Benchmark | Phi-4-mini-instruct | Phi-4-mini-instruct-float8dq | ++============================+=====================+==============================+ +| latency (batch_size=1) | 1.64s | 1.41s (1.16x speedup) | ++----------------------------+---------------------+------------------------------+ +| latency (batch_size=128) | 3.1s | 2.72s (1.14x speedup) | ++----------------------------+---------------------+------------------------------+ +| serving (num_prompts=1) | 1.35 req/s | 1.57 req/s (1.16x speedup) | ++----------------------------+---------------------+------------------------------+ +| serving (num_prompts=1000) | 66.68 req/s | 80.53 req/s (1.21x speedup) | ++----------------------------+---------------------+------------------------------+ + +Conclusion +--------- + +This tutorial demonstrated how torchao's quantization and sparsity techniques integrate seamlessly across the entire ML deployment stack: + +- **HuggingFace Transformers** provides easy model loading with torchao quantization +- **vLLM** leverages torchao's optimized kernels for high-throughput serving +- **ExecuTorch** enables mobile deployment with torchao's mobile-optimized schemes +- **lm-evaluation-harness** provides model quality assessment + +All these frameworks use torchao as the underlying optimization engine, ensuring consistent performance gains and ease of integration. The quantization techniques shown provide significant memory reduction (3-4x) and performance improvements (1.5-2x) while maintaining model quality within acceptable bounds for most applications. + +For production deployments, always benchmark on your specific use case and hardware to validate the performance and accuracy trade-offs. diff --git a/docs/source/torchao_hf_integration.md b/docs/source/torchao_hf_integration.md new file mode 100644 index 0000000000..8ab5020133 --- /dev/null +++ b/docs/source/torchao_hf_integration.md @@ -0,0 +1,128 @@ +(torchao_hf_integration)= +# Hugging Face Integration + +```{contents} +:local: +:depth: 2 +``` + +(usage-examples)= +## Quick Start: Usage Example + +First, install the required packages. + +```bash +pip install git+https://github.com/huggingface/transformers@main +pip install git+https://github.com/huggingface/diffusers@main +pip install torchao +pip install torch +pip install accelerate +``` + +(quantizing-models-transformers)= +### 1. Quantizing Models with Transformers + +Below is an example of using `Float8DynamicActivationInt4WeightConfig` on the Llama-3.2-1B model. + +```python +from transformers import TorchAoConfig, AutoModelForCausalLM +from torchao.quantization import Float8DynamicActivationInt4WeightConfig + +# Create quantization configuration +quantization_config = TorchAoConfig( + quant_type=Float8DynamicActivationInt4WeightConfig(group_size=128, use_hqq=True) +) + +# Load and automatically quantize the model +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + torch_dtype="auto", + device_map="auto", + quantization_config=quantization_config +) +``` +```{seealso} +For inference examples and recommended quantization methods based on different hardwares (i.e. A100 GPU, H100 GPU, CPU), see [HF-Torchao Docs (Quantization Examples)](https://huggingface.co/docs/transformers/main/en/quantization/torchao#quantization-examples). + +For inference using vLLM, please see [(Part 3) Serving on vLLM, SGLang, ExecuTorch](https://docs.pytorch.org/ao/main/serving.html) for a full end-to-end tutorial. +``` + +(quantizing-models-diffusers)= +### 2. Quantizing Models with Diffusers + +Below is an example of how we can integrate with Diffusers. + +```python +from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig + +model_id = "black-forest-labs/Flux.1-Dev" +dtype = torch.bfloat16 + +quantization_config = TorchAoConfig("int8wo") +transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=dtype, +) +pipe = FluxPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=dtype, +) +pipe.to("cuda") + +prompt = "A cat holding a sign that says hello world" +image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] +image.save("output.png") +``` + +```{note} +Example Output: +![alt text](output.png "Model Output") +``` + +```{seealso} +Please refer to [HF-TorchAO-Diffuser Docs](https://huggingface.co/docs/diffusers/en/quantization/torchao) for more examples and benchmarking results. +``` + +(saving-models)= +## Saving the Model + +After we quantize the model, we can save it. + +```python +# Save quantized model (see below for safe_serialization enablement progress) +with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=False) + +# optional: push to hub (uncomment the following lines) +# save_to = "your-username/Llama-3.2-1B-int4" +# model.push_to_hub(save_to, safe_serialization=False) + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") +tokenizer.push_to_hub(save_to) +``` + +**Current Status of Safetensors support**: TorchAO quantized models cannot yet be serialized with safetensors due to tensor subclass limitations. When saving quantized models, you must use `safe_serialization=False`. + +```python +# don't serialize model with Safetensors +output_dir = "llama3-8b-int4wo-128" +quantized_model.save_pretrained("llama3-8b-int4wo-128", safe_serialization=False) +``` + +**Workaround**: For production use, save models with `safe_serialization=False` when pushing to Hugging Face Hub. + +**Future Work**: The TorchAO team is actively working on safetensors support for tensor subclasses. Track progress [here](https://github.com/pytorch/ao/issues/2338) and [here](https://github.com/pytorch/ao/pull/2881). + +(Supported-Quantization-Types)= +## Supported Quantization Types + +Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation. + +Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly. + +```{note} +Please refer to the [torchao docs](https://docs.pytorch.org/ao/main/api_ref_quantization.html) for supported quantization types. +``` diff --git a/docs/source/torchao_vllm_integration.md b/docs/source/torchao_vllm_integration.md index 9af8fb3885..1ca027a124 100644 --- a/docs/source/torchao_vllm_integration.md +++ b/docs/source/torchao_vllm_integration.md @@ -45,6 +45,7 @@ from torchao.quantization import Int4WeightOnlyConfig config = Int4WeightOnlyConfig( group_size=128, use_hqq=True, + version=1, ) assert isinstance(config, AOBaseConfig) ``` @@ -65,7 +66,7 @@ config = ModuleFqnToConfig({ "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64), "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64), "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(), - "_default": Int4WeightOnlyConfig(group_size=128) # Default for other modules + "_default": Int4WeightOnlyConfig(group_size=128, version=1) # Default for other modules }) ``` @@ -81,13 +82,13 @@ from torchao.quantization import Int4WeightOnlyConfig # Create quantization configuration quantization_config = TorchAoConfig( - quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True) + quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1) ) # Load and automatically quantize the model model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-3.2-1B", - torch_dtype="auto", + dtype="auto", device_map="auto", quantization_config=quantization_config ) @@ -170,7 +171,7 @@ class MyNewQuantConfig(AOBaseConfig): VERSION: ClassVar[int] = 1 class MyQuantizedTensor(TorchAOBaseTensor): - """Example based on FbgemmFp8Tensor - stores quantized data + scale""" + """Example based on Float8Tensor - stores quantized data + scale""" tensor_data_attrs = ["quantized_data", "scale"] tensor_attributes = ["dtype"] diff --git a/docs/source/tutorials_source/pt2e_quant_openvino.rst b/docs/source/tutorials_source/pt2e_quant_openvino_inductor.rst similarity index 98% rename from docs/source/tutorials_source/pt2e_quant_openvino.rst rename to docs/source/tutorials_source/pt2e_quant_openvino_inductor.rst index 827023b300..cf7f1ec896 100644 --- a/docs/source/tutorials_source/pt2e_quant_openvino.rst +++ b/docs/source/tutorials_source/pt2e_quant_openvino_inductor.rst @@ -74,7 +74,7 @@ OpenVINO and NNCF could be easily installed via `pip distribution 0: + print("SOURCES", sources) # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources sources = [ s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu" @@ -610,6 +657,35 @@ def get_extensions(): ) ) + # Add the mxfp8 casting CUDA extension + if use_cuda: + mxfp8_sources = [ + os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"), + os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"), + ] + + # Only add the extension if the source files exist AND we are building for sm100 + mxfp8_src_files_exist = all(os.path.exists(f) for f in mxfp8_sources) + if mxfp8_src_files_exist and build_for_sm100a: + print("Building mxfp8_cuda extension") + ext_modules.append( + CUDAExtension( + name="torchao.prototype.mxfp8_cuda", + sources=mxfp8_sources, + include_dirs=[ + mxfp8_extension_dir, # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu + ], + extra_compile_args={ + "cxx": ["-std=c++17", "-O3"], + "nvcc": nvcc_args + + [ + "-gencode=arch=compute_100,code=sm_100", + "-gencode=arch=compute_120,code=compute_120", + ], + }, + ), + ) + # Only build the cutlass_90a extension if sm90a is in the architecture flags if ( cutlass_90a_sources is not None @@ -652,7 +728,7 @@ def get_extensions(): ) ) - # Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND + # Build CMakeLists from /torchao/csrc/cpu - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1": build_options = BuildOptions() @@ -665,24 +741,20 @@ def bool_to_on_off(value): ext_modules.append( CMakeExtension( - "torchao.experimental", - cmake_lists_dir="torchao/experimental", + "torchao._C_cpu_shared_kernels_aten", + cmake_lists_dir="torchao/csrc/cpu", cmake_args=( [ f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}", f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}", f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}", - f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}", f"-DTORCHAO_ENABLE_ARM_NEON_DOT={bool_to_on_off(build_options.enable_arm_neon_dot)}", f"-DTORCHAO_ENABLE_ARM_I8MM={bool_to_on_off(build_options.enable_arm_i8mm)}", f"-DTORCHAO_PARALLEL_BACKEND={build_options.parallel_backend}", + "-DTORCHAO_BUILD_TESTS=OFF", + "-DTORCHAO_BUILD_BENCHMARKS=OFF", "-DTorch_DIR=" + torch_dir, ] - + ( - ["-DCMAKE_INSTALL_PREFIX=cmake-out"] - if build_options.build_experimental_mps - else [] - ) ), ) ) diff --git a/test/quantization/test_config_serialization.py b/test/core/test_config.py similarity index 79% rename from test/quantization/test_config_serialization.py rename to test/core/test_config.py index 71cf8e144d..0df31194ac 100644 --- a/test/quantization/test_config_serialization.py +++ b/test/core/test_config.py @@ -7,6 +7,7 @@ import json import os import tempfile +import warnings from dataclasses import dataclass from unittest import mock @@ -15,13 +16,16 @@ from torchao.core.config import ( AOBaseConfig, - VersionMismatchError, config_from_dict, config_to_dict, ) +from torchao.prototype.awq import ( + AWQConfig, + AWQStep, +) from torchao.quantization.quant_api import ( - FbgemmConfig, Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, GemliteUIntXWeightOnlyConfig, @@ -35,7 +39,6 @@ UIntXWeightOnlyConfig, ) from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 # Define test configurations as fixtures configs = [ @@ -46,10 +49,17 @@ weight_dtype=torch.float8_e4m3fn, ), UIntXWeightOnlyConfig(dtype=torch.uint1), + Float8DynamicActivationInt4WeightConfig(), Int4DynamicActivationInt4WeightConfig(), Int4WeightOnlyConfig( group_size=32, ), + Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", + version=2, + ), Int8DynamicActivationInt4WeightConfig( group_size=64, ), @@ -79,11 +89,10 @@ "linear2": Int8DynamicActivationInt4WeightConfig(), } ), + AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING), + AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"), ] -if TORCH_VERSION_AT_LEAST_2_6: - configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])] - # Create ids for better test naming def get_config_ids(configs): @@ -145,7 +154,9 @@ def test_reconstructable_dict_file_round_trip(config): # Define a dummy config in a non-allowed module @dataclass class DummyNonAllowedConfig(AOBaseConfig): - VERSION = 2 + # NOTE: must be `version: int` (with type annotations) to + # overload the version variable from AOBaseConfig + version: int = 2 value: int = 42 @@ -166,11 +177,11 @@ def test_disallowed_modules(): reconstructed = config_from_dict(reconstructable) assert isinstance(reconstructed, DummyNonAllowedConfig) assert reconstructed.value == 42 - assert reconstructed.VERSION == 2 + assert reconstructed.version == 2 def test_version_mismatch(): - """Test that version mismatch raises an error during reconstruction.""" + """Test that version mismatch prints a warning during reconstruction.""" # Create a config dummy_config = DummyNonAllowedConfig() reconstructable = config_to_dict(dummy_config) @@ -180,11 +191,27 @@ def test_version_mismatch(): # Patch to allow the module but should still fail due to version mismatch with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}): - with pytest.raises( - VersionMismatchError, - match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2", - ): + with warnings.catch_warnings(record=True) as caught_warnings: config_from_dict(reconstructable) + assert any( + "Stored version is not the same as current default version of the config" + in str(w.message) + for w in caught_warnings + ), "Didn't get expected warning message for version mismatch" + + +def test_default_version(): + """Making sure the default version for a new config inheriting from AOBaseConfig is always 1 + because it's the default version that all children has when they haven't explicitly + defined a version class variable + """ + + @dataclass + class DummyConfig(AOBaseConfig): + pass + + config = DummyConfig() + assert config.version == 1, "Default version must be 1" if __name__ == "__main__": diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index bd5ed0c3b5..83f32c8420 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -24,30 +24,24 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) -from torchao.float8.config import e4m3_dtype from torchao.quantization import ( - FbgemmConfig, + Float8WeightOnlyConfig, GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, - float8_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, + Int8WeightOnlyConfig, quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, check_cpu_version, check_xpu_version, is_fbcode, is_ROCM, is_sm_at_least_89, - is_sm_at_least_90, ) is_cusparselt_available = ( @@ -59,52 +53,49 @@ def get_quantization_functions( do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False ): base_functions = [ - int8_weight_only(), - int8_dynamic_activation_int4_weight(), - int8_dynamic_activation_int8_weight(), - int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC), + Int8WeightOnlyConfig(), + Int8DynamicActivationInt4WeightConfig(), + Int8DynamicActivationInt8WeightConfig(), + Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC), ] if do_int4: if check_cpu_version(device): base_functions.append( - int4_weight_only(group_size=32, layout=Int4CPULayout()) + Int4WeightOnlyConfig(group_size=32, layout=Int4CPULayout(), version=1) ) elif check_xpu_version(device): base_functions.append( - int4_weight_only(group_size=32, layout=Int4XPULayout()) + Int4WeightOnlyConfig(group_size=32, layout=Int4XPULayout(), version=1) ) if int4_zp_int: base_functions.append( - int4_weight_only( + Int4WeightOnlyConfig( group_size=32, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT, + version=1, ) ) else: - base_functions.append(int4_weight_only(group_size=32)) + base_functions.append(Int4WeightOnlyConfig(group_size=32, version=1)) if device == "cuda" and not is_ROCM(): base_functions.append( - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=None, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, layout=CutlassInt4PackedLayout(), ) ) - base_functions.append(int4_dynamic_activation_int4_weight()) + base_functions.append(Int4DynamicActivationInt4WeightConfig()) if do_sparse and device != "xpu": base_functions.append( - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()) ) if is_sm_at_least_89(): - base_functions.append(float8_weight_only()) - - if is_sm_at_least_90(): - base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16)) - base_functions.append(FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16)) + base_functions.append(Float8WeightOnlyConfig()) return base_functions @@ -119,7 +110,7 @@ def test_tensor_core_layout_transpose(self): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") t = linear.weight shape = t.shape - apply_int4_weight_only_quant = int4_weight_only(group_size=32) + apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1) quantize_(linear, apply_int4_weight_only_quant) ql = linear aqt = ql.weight @@ -151,11 +142,7 @@ def test_weights_only(self): with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) - # `weights_only=True` is enabled for torch 2.5+ - if TORCH_VERSION_AT_LEAST_2_5: - _ = torch.load(f, weights_only=True) - else: - _ = torch.load(f, weights_only=False) + _ = torch.load(f, weights_only=True) @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) @@ -358,7 +345,7 @@ def test_slice_int4wo(self, device, dtype): # out_feature not divisible by 8 # to test slice + padding for int4 weight only quantization dummy = nn.Linear(256, 321, dtype=dtype, device=device) - quantize_(dummy, Int4WeightOnlyConfig()) + quantize_(dummy, Int4WeightOnlyConfig(version=1)) # make sure these run without error _ = dummy.weight.narrow(0, 0, 64) _ = dummy.weight.narrow(1, 0, 128) @@ -472,7 +459,7 @@ def test_slice_and_copy_int4wo(self, device, dtype): l.weight = torch.nn.Parameter( torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") ) - quantize_(l, Int4WeightOnlyConfig()) + quantize_(l, Int4WeightOnlyConfig(version=1)) param = l.weight param_data = param.data param_data = param_data.narrow(0, 0, 512) @@ -488,7 +475,7 @@ def test_slice_and_copy_int4wo(self, device, dtype): # dummy_l has random input (shouldn't be 0) dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - quantize_(dummy_l, Int4WeightOnlyConfig()) + quantize_(dummy_l, Int4WeightOnlyConfig(version=1)) quantized = dummy_l.weight quantized = quantized.narrow(0, 0, 512) @@ -507,7 +494,7 @@ def test_mm_int4wo(self, device, dtype): l = torch.nn.Linear(512, 1024).to(device).to(dtype) l.weight = torch.nn.Parameter(weight) - quantize_(l, Int4WeightOnlyConfig()) + quantize_(l, Int4WeightOnlyConfig(version=1)) # weight shape: 1024 x 512 weight = l.weight diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 33a1fe66a7..35870a5e6b 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -3,15 +3,6 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import copy import io import random @@ -23,30 +14,30 @@ import pytest import torch from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_code +from torch.testing import FileCheck from torch.testing._internal import common_utils from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale from torchao.float8.float8_utils import compute_error from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, - float8_dynamic_activation_float8_weight, - float8_weight_only, + Float8StaticActivationFloat8WeightConfig, + Float8WeightOnlyConfig, quantize_, ) from torchao.quantization.granularity import ( PerRow, PerTensor, ) -from torchao.quantization.quant_api import ( - float8_static_activation_float8_weight, -) from torchao.quantization.quant_primitives import ( MappingType, - _choose_qparams_affine_float8, + _choose_scale_float8, _dequantize_affine_float8, _quantize_affine_float8, choose_qparams_affine, ) +from torchao.quantization.quantize_.common import KernelPreference from torchao.utils import ( is_sm_at_least_89, is_sm_at_least_90, @@ -118,11 +109,13 @@ def test_fp8_linear_variants( ) mode_map = { "dynamic": partial( - float8_dynamic_activation_float8_weight, granularity=granularity + Float8DynamicActivationFloat8WeightConfig, + granularity=granularity, + version=1, ), - "weight-only": float8_weight_only, + "weight-only": partial(Float8WeightOnlyConfig, version=1), "static": partial( - float8_static_activation_float8_weight, + Float8StaticActivationFloat8WeightConfig, scale=scale, granularity=granularity, ), @@ -151,7 +144,7 @@ def test_fp8_linear_variants( ) def test_invalid_granularity(self): with pytest.raises(ValueError, match="Invalid granularity specification"): - float8_dynamic_activation_float8_weight(granularity="invalid") + Float8DynamicActivationFloat8WeightConfig(granularity="invalid") @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" @@ -161,7 +154,9 @@ def test_mismatched_granularity(self): ValueError, match="Different granularities for activation and weight are not supported", ): - float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) + Float8DynamicActivationFloat8WeightConfig( + granularity=(PerTensor(), PerRow()) + ) @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" @@ -171,8 +166,8 @@ class UnsupportedGranularity: pass with pytest.raises(ValueError, match="Invalid granularity types"): - float8_dynamic_activation_float8_weight( - granularity=(UnsupportedGranularity(), UnsupportedGranularity()) + Float8DynamicActivationFloat8WeightConfig( + granularity=(UnsupportedGranularity(), UnsupportedGranularity()), ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -186,7 +181,8 @@ def test_per_row_with_float32(self): ): model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda") quantize_( - model, float8_dynamic_activation_float8_weight(granularity=PerRow()) + model, + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -200,15 +196,18 @@ def test_serialization(self, mode: str): mode_map = { "dynamic": partial( - float8_dynamic_activation_float8_weight, granularity=PerTensor() + Float8DynamicActivationFloat8WeightConfig, + granularity=PerTensor(), + version=1, ), - "weight-only": float8_weight_only, + "weight-only": partial(Float8WeightOnlyConfig, version=1), "static": partial( - float8_static_activation_float8_weight, + Float8StaticActivationFloat8WeightConfig, scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"), granularity=PerTensor(), ), } + factory = mode_map[mode]() quantize_(model, factory) @@ -274,7 +273,10 @@ def test_fp8_weight_dimension_warning(self): "torchao.quantization.quant_api", level="INFO" ) as log_context: quantize_( - model, float8_dynamic_activation_float8_weight(granularity=PerTensor()) + model, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), version=1 + ), ) print(model) @@ -319,7 +321,8 @@ def test_mm_float8dq_per_row( ) test_linear = copy.deepcopy(ref_linear) quantize_( - test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + test_linear, + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1), ) quant_weight = test_linear.weight @@ -356,7 +359,50 @@ def test_mm_float8dq_per_row( ) @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16]) - @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)]) + def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype): + block_size = () + device = "cuda" + input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) + + # testing upper bounds + input_tensor[0][0] = 2000 + scale_ref = _choose_scale_float8( + input_tensor, float8_dtype=float8_dtype, block_size=block_size + ) + + hp_value_ub = 1200 + scale_with_ub = _choose_scale_float8( + input_tensor, + float8_dtype=float8_dtype, + block_size=block_size, + hp_value_ub=hp_value_ub, + ) + # since scale = abs_max / quant_max, larger abs_max means scale is larger + self.assertTrue(scale_ref > scale_with_ub) + + # tesing lower bounds settings + # making sure that abs is on the scale of 1e-20, so hp_value_lb can take effect + input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) * 1e-20 + scale_ref = _choose_scale_float8( + input_tensor, float8_dtype=float8_dtype, block_size=block_size + ) + hp_value_lb = 1e-12 + scale_with_lb = _choose_scale_float8( + input_tensor, + float8_dtype=float8_dtype, + block_size=block_size, + hp_value_lb=hp_value_lb, + ) + # since scale = abs_max / quant_max, larger abs_max means scale is larger + self.assertTrue(scale_ref < scale_with_lb) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16]) + @common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)]) def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): """Test _dequantize_affine_float8 with various configurations""" @@ -364,7 +410,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size): input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32) # Choose quantization parameters - scale = _choose_qparams_affine_float8( + scale = _choose_scale_float8( input_tensor, float8_dtype=float8_dtype, block_size=block_size ) @@ -395,7 +441,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self): block_size = (2, 16) # 2x2 blocks in first dim, 2x16 blocks in second dim # Choose quantization parameters - scale = _choose_qparams_affine_float8( + scale = _choose_scale_float8( input_tensor, float8_dtype=torch.float8_e4m3fn, block_size=block_size ) @@ -428,7 +474,10 @@ def test_float8_tensor_slicing_basic(self, granularity): # Create and quantize a model model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) quantize_( - model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + model, + Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, version=1 + ), ) weight_impl = model.weight.original_weight_tensor.tensor_impl @@ -462,7 +511,10 @@ def test_float8_tensor_slicing_per_tensor(self): # Create and quantize with per-tensor granularity model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) quantize_( - model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) + model, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), version=1 + ), ) original_weight = model.weight @@ -493,7 +545,8 @@ def test_float8_tensor_slicing_per_row(self): # Create and quantize with per-row granularity model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) quantize_( - model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + model, + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1), ) original_weight = model.weight # Shape: (32, 64) @@ -531,7 +584,10 @@ def test_float8_tensor_slicing_edge_cases(self): # Create and quantize a model model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) quantize_( - model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) + model, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), version=1 + ), ) original_weight = model.weight @@ -569,7 +625,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity): quant_model = copy.deepcopy(ref_model) quantize_( quant_model, - Float8DynamicActivationFloat8WeightConfig(granularity=granularity), + Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, version=1 + ), ) # Create input with batch size that works well with slicing @@ -678,8 +736,10 @@ def test_preprocess_scale_3d_reshape(self): @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16]) def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype): - quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8 - dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8 + quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8_non_decomposed + dequantize_affine_float8 = ( + torch.ops.torchao.dequantize_affine_float8_non_decomposed + ) input = torch.randn(10, 10) with torch.no_grad(): torch._dynamo.reset() @@ -700,21 +760,86 @@ def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype): expected_scale, float8_dtype=float8_dtype, ) - torch.testing.FileCheck().check( - "torch.ops.torchao.quantize_affine_float8.default" - ).run(code_q) + torch.testing.FileCheck().check(f"{quantize_affine_float8}.default").run( + code_q + ) test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code( torch.compile(dequantize_affine_float8), test_q, expected_scale, hp_dtype, ) - torch.testing.FileCheck().check( - "torch.ops.torchao.dequantize_affine_float8.default" - ).run(code_dq) + torch.testing.FileCheck().check(f"{dequantize_affine_float8}.default").run( + code_dq + ) torch.testing.assert_close(expected_quantized, test_q) torch.testing.assert_close(expected_dequantized, test_dq) + @torch.no_grad() + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0" + ) + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize("float8_config_version", [1, 2]) + def test_expected_kernels_on_gpu(self, granularity, float8_config_version): + """ + Verify that float8 quantization + torch.compile results in the + expected number of kernels in the GPU trace. + """ + torch.compiler.reset() + + M, K, N = 128, 256, 512 + m = torch.nn.Sequential( + torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) + ) + if float8_config_version == 1: + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, version=1 + ) + else: + assert float8_config_version == 2 + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, + version=2, + kernel_preference=KernelPreference.TORCH, + ) + quantize_( + m, + config, + ) + + m = torch.compile(m) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + out, code = run_and_get_code(m, x) + + # triton kernel call looks like: + # triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_clone_div_expand_permute_transpose_unsqueeze_view_0.run(arg3_1, buf1, buf2, 128, 256, stream=stream0) + # scaled_mm call looks like: + # extern_kernels._scaled_mm(buf1, reinterpret_tensor(arg0_1, (256, 512), (1, 256), 0), buf2, reinterpret_tensor(arg1_1, (1, 512), (1, 1), 0), arg2_1, out_dtype=torch.bfloat16, use_fast_accum=True, out=buf3) + if granularity == PerRow(): + # one triton kernel for quantizing the activation + FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run( + code[0] + ) + # one scaled_mm call + FileCheck().check("def call(").check_count( + "._scaled_mm(", 1, exactly=True + ).run(code[0]) + else: + assert granularity == PerTensor(), "unsupported" + # three triton kernels for quantizing the activation: + # kernel 1: x_max_tmp = max(x, ...) + # kernel 2: x_max = max(x_max_tmp) + # kernel 3: x_float8 = to_float8(x, x_max) + FileCheck().check("def call(").check_count(".run(", 3, exactly=True).run( + code[0] + ) + # one scaled_mm call + FileCheck().check("def call(").check_count( + "._scaled_mm(", 1, exactly=True + ).run(code[0]) + common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 56410bab8f..983f701849 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -16,15 +16,17 @@ ) from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_weight_only, - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, ) from torchao.quantization.observer import PerRow, PerTensor from torchao.quantization.quant_api import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 + +if common_utils.SEED is None: + common_utils.SEED = 1234 try: import gemlite # noqa: F401 @@ -40,7 +42,7 @@ class TestAffineQuantizedTensorParallel(DTensorTestBase): """Basic test case for tensor subclasses""" - QUANT_METHOD_FN = staticmethod(int8_weight_only) + QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig) QUANT_METHOD_KWARGS = {} @staticmethod @@ -124,10 +126,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist(up_dist(input_dtensor)) - if not TORCH_VERSION_AT_LEAST_2_6: - # Need torch 2.6 to support compiled tensor parallelism - return - up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) @@ -135,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(int8_weight_only) + QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig) COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -146,7 +144,8 @@ def test_tp(self, dtype): class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(int4_weight_only) + QUANT_METHOD_FN = staticmethod(Int4WeightOnlyConfig) + QUANT_METHOD_KWARGS = {"version": 1} COMMON_DTYPES = [torch.bfloat16] @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -168,12 +167,12 @@ class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not has_gemlite, "gemlite not available") def test_tp_gemlite(self, dtype): - from torchao.quantization import gemlite_uintx_weight_only + from torchao.quantization import GemliteUIntXWeightOnlyConfig for packing_bitwidth in [32, 8]: for bit_width in [4, 8]: for group_size in [64, 32, None] if bit_width == 4 else [None]: - api = lambda: gemlite_uintx_weight_only( + api = lambda: GemliteUIntXWeightOnlyConfig( group_size, bit_width, packing_bitwidth ) self.QUANT_METHOD_FN = staticmethod(api) @@ -181,7 +180,7 @@ def test_tp_gemlite(self, dtype): class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight) + QUANT_METHOD_FN = staticmethod(Int8DynamicActivationInt8WeightConfig) COMMON_DTYPES = [torch.bfloat16] @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -200,7 +199,7 @@ def test_tp(self, dtype): if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): - QUANT_METHOD_FN = staticmethod(float8_weight_only) + QUANT_METHOD_FN = staticmethod(Float8WeightOnlyConfig) COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -212,7 +211,7 @@ def test_tp(self, dtype): class TestFloat8dqTensorAffineQuantizedTensorParallel( TestAffineQuantizedTensorParallel ): - QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig) QUANT_METHOD_KWARGS = {"granularity": PerTensor()} COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @@ -225,7 +224,7 @@ def test_tp(self, dtype): class TestFloat8dqRowAffineQuantizedTensorParallel( TestAffineQuantizedTensorParallel ): - QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) + QUANT_METHOD_FN = staticmethod(Float8DynamicActivationFloat8WeightConfig) QUANT_METHOD_KWARGS = {"granularity": PerRow()} COMMON_DTYPES = [torch.bfloat16] diff --git a/test/dtypes/test_fbgemm_fp8.py b/test/dtypes/test_fbgemm_fp8.py deleted file mode 100644 index 1e681d00f9..0000000000 --- a/test/dtypes/test_fbgemm_fp8.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from torch.testing._internal.common_utils import ( - TestCase, - run_tests, -) - -from torchao.float8.config import e4m3_dtype -from torchao.quantization import ( - FbgemmConfig, - quantize_, -) -from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, - is_sm_at_least_90, -) - - -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") -@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") -class TestFbgemmFp8Tensor(TestCase): - def setUp(self): - self.config = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=e4m3_dtype, - output_dtype=torch.bfloat16, - ) - self.bmm_config = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=e4m3_dtype, - output_dtype=torch.bfloat16, - transpose_input=True, - ) - self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] - - def test_linear(self): - dtype = torch.bfloat16 - device = "cuda" - input = torch.randn(1, 128, dtype=dtype, device=device) - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - original = linear(input) - quantize_(linear, self.config) - quantized = linear(input) - self.assertTrue(compute_error(original, quantized) > 20) - - def test_slice(self): - dtype = torch.bfloat16 - device = "cuda" - dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) - dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) - dummy1.weight = torch.nn.Parameter( - dummy.weight.narrow(0, 0, 64), requires_grad=False - ) - dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) - dummy2.weight = torch.nn.Parameter( - dummy.weight.narrow(1, 0, 128), requires_grad=False - ) - - quantize_(dummy, self.config) - weight1 = dummy.weight.narrow(0, 0, 64) - weight2 = dummy.weight.narrow(1, 0, 128) - self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64)) - self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64)) - self.assertEqual( - weight2.float8_data, dummy.weight.float8_data.narrow(1, 0, 128) - ) - self.assertEqual(weight2.scale, dummy.weight.scale) - - # check for sliced weight, before and after float8 quantization - # does not differ too much - input = torch.randn(2, 256, dtype=dtype, device=device) - res_ref = dummy1(input) - dummy.weight = torch.nn.Parameter(weight1, requires_grad=False) - res = dummy(input) - assert compute_error(res, res_ref) > 25 - - input = torch.randn(2, 128, dtype=dtype, device=device) - res_ref = dummy2(input) - dummy.weight = torch.nn.Parameter(weight2, requires_grad=False) - res = dummy(input) - assert compute_error(res, res_ref) > 15 - - def test_slice_and_copy_(self): - l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - l.weight = torch.nn.Parameter( - torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") - ) - quantize_(l, self.config) - param = l.weight - param_data = param.data - param_data = param_data.narrow(0, 0, 512) - assert param.data.float8_data.data_ptr() == param_data.float8_data.data_ptr() - assert param.data.scale.data_ptr() == param_data.scale.data_ptr() - orig_value = param.data.float8_data[0][0].item() - - # dummy_l has random input (shouldn't be 0) - dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - quantize_(dummy_l, self.config) - quantized = dummy_l.weight - quantized = quantized.narrow(0, 0, 512) - - param_data.copy_(quantized) - - # making sure param.data is updated - assert param.data.float8_data[0][0] != orig_value - - def test_bmm(self): - class M(torch.nn.Module): - def __init__(self, weight): - super().__init__() - self.weight = weight - - def forward(self, x): - return torch.bmm(x, self.weight) - - dtype = torch.bfloat16 - device = "cuda" - input = torch.randn(10, 32, 128, dtype=dtype, device=device) - weight = torch.randn(10, 128, 256, dtype=dtype, device=device) - m = M(weight).eval() - original = m(input) - quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) - quantized = m(input) - self.assertTrue(compute_error(original, quantized) > 20) - - def test_to_device(self): - for device in self.GPU_DEVICES: - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device) - - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device=device) - - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device) - - -if __name__ == "__main__": - run_tests() diff --git a/test/dtypes/test_fbgemm_int4.py b/test/dtypes/test_fbgemm_int4.py deleted file mode 100644 index cba9d81ae0..0000000000 --- a/test/dtypes/test_fbgemm_int4.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from torch.testing._internal.common_utils import ( - TestCase, - run_tests, -) - -from torchao.quantization import ( - FbgemmConfig, - quantize_, -) -from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, - is_sm_at_least_90, -) - - -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") -@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") -class TestFbgemmInt4Tensor(TestCase): - def setUp(self): - self.config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 128], - ) - self.bmm_config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 1, 128], - transpose_input=True, - ) - self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] - - def test_linear(self): - dtype = torch.bfloat16 - device = "cuda" - input = torch.randn(1, 128, dtype=dtype, device=device) - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - original = linear(input) - quantize_(linear, self.config) - quantized = linear(input) - self.assertTrue(compute_error(original, quantized) > 20) - - def test_slice(self): - dtype = torch.bfloat16 - device = "cuda" - dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) - dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) - dummy1.weight = torch.nn.Parameter( - dummy.weight.narrow(0, 0, 64), requires_grad=False - ) - dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) - dummy2.weight = torch.nn.Parameter( - dummy.weight.narrow(1, 0, 128), requires_grad=False - ) - - quantize_(dummy, self.config) - weight1 = dummy.weight.narrow(0, 0, 64) - weight2 = dummy.weight.narrow(1, 0, 128) - self.assertEqual( - weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64) - ) - self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64)) - self.assertEqual( - weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64) - ) - self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1)) - - # check for sliced weight, before and after float8 quantization - # does not differ too much - input = torch.randn(2, 256, dtype=dtype, device=device) - res_ref = dummy1(input) - dummy.weight = torch.nn.Parameter(weight1, requires_grad=False) - res = dummy(input) - assert compute_error(res, res_ref) > 20 - - input = torch.randn(2, 128, dtype=dtype, device=device) - res_ref = dummy2(input) - dummy.weight = torch.nn.Parameter(weight2, requires_grad=False) - res = dummy(input) - assert compute_error(res, res_ref) > 15 - - def test_slice_and_copy_(self): - l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - l.weight = torch.nn.Parameter( - torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") - ) - quantize_(l, self.config) - param = l.weight - param_data = param.data - param_data = param_data.narrow(0, 0, 512) - assert ( - param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr() - ) - assert param.data.scale.data_ptr() == param_data.scale.data_ptr() - assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr() - orig_value = param.data.packed_weight[0][0].item() - - # dummy_l has random input (shouldn't be 0) - dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) - quantize_(dummy_l, self.config) - quantized = dummy_l.weight - quantized = quantized.narrow(0, 0, 512) - - param_data.copy_(quantized) - - # making sure param.data is updated - assert param.data.packed_weight[0][0] != orig_value - - def test_bmm(self): - class M(torch.nn.Module): - def __init__(self, weight): - super().__init__() - self.weight = weight - - def forward(self, x): - return torch.bmm(x, self.weight) - - dtype = torch.bfloat16 - device = "cuda" - input = torch.randn(10, 32, 128, dtype=dtype, device=device) - weight = torch.randn(10, 128, 256, dtype=dtype, device=device) - m = M(weight).eval() - original = m(input) - quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) - quantized = m(input) - self.assertTrue(compute_error(original, quantized) > 18) - - def test_to_device(self): - for device in self.GPU_DEVICES: - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device) - - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device=device) - - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) - linear.to(device) - - -if __name__ == "__main__": - run_tests() diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 237bc2bd92..ab4a13d24c 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -29,11 +29,11 @@ _floatx_unpacked_to_f32, ) from torchao.quantization import ( - fpx_weight_only, + FPXWeightOnlyConfig, quantize_, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode +from torchao.utils import is_fbcode _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] @@ -107,10 +107,6 @@ def test_to_copy_device(self, ebits, mbits): assert floatx_tensor_impl.device.type == "cpu" @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, - reason="quantization only works with torch.compile for 2.5+", - ) @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) @parametrize("dtype", [torch.half, torch.bfloat16]) @@ -122,7 +118,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias, dtype): linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype) fpx_linear = copy.deepcopy(linear) - quantize_(fpx_linear, fpx_weight_only(ebits, mbits)) + quantize_(fpx_linear, FPXWeightOnlyConfig(ebits, mbits)) x = torch.randn(N, IC, device=device, dtype=dtype) expected = fpx_linear(x) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index f52644cdf3..2a711413f0 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -20,6 +20,7 @@ apply_activation_checkpointing, ) from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.testing._internal import common_utils from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( @@ -29,6 +30,9 @@ run_tests, ) +if common_utils.SEED is None: + common_utils.SEED = 1234 + import torchao from packaging import version from torchao.dtypes._nf4tensor_api import nf4_weight_only @@ -39,7 +43,7 @@ to_nf4, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import torch_version_at_least bnb_available = False @@ -119,7 +123,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf( - TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI" + torch_version_at_least("2.7.0"), reason="Failing in CI" ) # TODO: fix this @skip_if_rocm("ROCm enablement in progress") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @@ -146,7 +150,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @skip_if_rocm("ROCm enablement in progress") @unittest.skipIf( - TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI" + torch_version_at_least("2.7.0"), reason="Failing in CI" ) # TODO: fix this @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_nf4_bnb_linear(self, dtype: torch.dtype): @@ -435,7 +439,17 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]): inner_tensor = getattr(viewed_tensor, attr) self.assertEqual(inner_tensor.size(0), inner_tensor.numel()) - @parametrize("input_size", [(512 * 512,), (512, 512)]) + @parametrize("input_size", [(512, 512)]) + def test_tensor_2d_view_valid(self, input_size: Tuple[int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + viewed_tensor = nf4_tensor.view(input_size) + self.assertEqual(viewed_tensor.dim(), 2) + self.assertEqual(viewed_tensor.numel(), math.prod(input_size)) + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(viewed_tensor, attr) + self.assertEqual(inner_tensor.size(0), inner_tensor.numel()) + + @parametrize("input_size", [(512 * 512,)]) def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]): nf4_tensor = to_nf4(torch.randn(input_size)) if len(input_size) == 1: @@ -443,11 +457,6 @@ def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]): NotImplementedError, "aten.view\\(NF4Tensor\\) with size" ): nf4_tensor.view(input_size) - if len(input_size) == 2: - with self.assertRaisesRegex( - NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)" - ): - nf4_tensor.view(input_size) @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]): @@ -741,6 +750,42 @@ def _test_qlora_fsdp2( self.assertEqual(fsdp_loss, base_loss) +class TestComm(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @skip_if_lt_x_gpu(2) + def test_comm(self): + self.run_subtests( + {"input_size": [512, 2048]}, + self._test_comm, + ) + + def _test_comm(self, input_size: int): + from torch.distributed._composable.fsdp import fully_shard + from torch.distributed._tensor import distribute_tensor + + model = nn.Linear(input_size, input_size, device="cuda") + origin_tensor = model.weight + origin_nf4_tensor = to_nf4(origin_tensor) + model = fully_shard(model) + sharded_tensor = model.weight + sharded_origin_nf4_tensor = distribute_tensor( + origin_nf4_tensor, + sharded_tensor.device_mesh, + sharded_tensor.placements, + ) + + sharded_nf4_detach = sharded_origin_nf4_tensor.detach() + resumed_full_tensor = sharded_nf4_detach.full_tensor() + + self.assertEqual( + origin_nf4_tensor.get_original_weight(), + resumed_full_tensor.get_original_weight(), + ) + + instantiate_parametrized_tests(TestNF4Linear) instantiate_parametrized_tests(TestFSDPOps) diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index f7656ef19e..a1d87dbc91 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -34,7 +34,6 @@ _replace_with_custom_fn_if_matches_filter, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def _apply_weight_only_uint4_quant(model): @@ -243,16 +242,7 @@ def forward(self, x): # program capture m = copy.deepcopy(m_eager) - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.export.texport_for_training( - m, - example_inputs, - ).module() - else: - m = torch._export.capture_pre_autograd_graph( - m, - example_inputs, - ).module() + m = torch.export.export(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 35c722365d..cb0c88b21c 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -7,31 +7,23 @@ import torch from torchao.dtypes.uintx.uintx_layout import to_uintx -from torchao.quantization.quant_api import quantize_, uintx_weight_only +from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, dequantize_affine, quantize_affine, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, -) -# torch.uintx dtypes are introduced in 2.3 -if TORCH_VERSION_AT_LEAST_2_3: - dtypes = ( - torch.uint1, - torch.uint2, - torch.uint3, - torch.uint4, - torch.uint5, - torch.uint6, - torch.uint7, - ) -else: - dtypes = () +dtypes = ( + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, +) group_sizes = [32, 64, 128] devices = ["cpu", "cuda"] @@ -65,13 +57,10 @@ def forward(self, x): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): scale = 512 fp16_mod_on_cpu = Linear16(scale, "cpu") - quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size)) + quantize_(fp16_mod_on_cpu, UIntXWeightOnlyConfig(dtype, group_size=group_size)) test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu") output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu) fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda") @@ -86,13 +75,10 @@ def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_weight_only_model_quant(dtype, group_size, device): scale = 512 fp16 = Linear16(scale, device) - quantize_(fp16, uintx_weight_only(dtype, group_size=group_size)) + quantize_(fp16, UIntXWeightOnlyConfig(dtype, group_size=group_size)) uintx = torch.compile(fp16, fullgraph=True) test_input = torch.randn(scale * 2, dtype=torch.float16, device=device) output = uintx.forward(test_input) @@ -103,9 +89,6 @@ def test_uintx_weight_only_model_quant(dtype, group_size, device): @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" -) def test_uintx_weight_only_quant(dtype, group_size, device): input_float = torch.randn((1, 256), dtype=torch.float16, device=device) mapping_type = MappingType.SYMMETRIC @@ -140,41 +123,26 @@ def test_uintx_weight_only_quant(dtype, group_size, device): @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" -) def test_uintx_target_dtype(dtype): - from torchao.quantization.quant_api import uintx_weight_only - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - quantize_(linear, uintx_weight_only(dtype)) + quantize_(linear, UIntXWeightOnlyConfig(dtype)) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, - reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+", -) def test_uintx_target_dtype_compile(dtype): - from torchao.quantization.quant_api import uintx_weight_only - linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - quantize_(linear, uintx_weight_only(dtype)) + quantize_(linear, UIntXWeightOnlyConfig(dtype)) linear = torch.compile(linear) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" -) def test_uintx_model_size(dtype): - from torchao.quantization.quant_api import uintx_weight_only from torchao.utils import get_model_size_in_bytes # scale size = 1/64 * 2 bytes = 1/32 bytes @@ -194,6 +162,6 @@ def test_uintx_model_size(dtype): ) bf16_size = get_model_size_in_bytes(linear) # make sure it runs - quantize_(linear[0], uintx_weight_only(dtype)) + quantize_(linear[0], UIntXWeightOnlyConfig(dtype)) quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size diff --git a/test/float8/test_auto_filter.py b/test/float8/test_auto_filter.py new file mode 100644 index 0000000000..927c33d195 --- /dev/null +++ b/test/float8/test_auto_filter.py @@ -0,0 +1,86 @@ +import pytest +import torch.nn as nn + +from torchao.float8 import _auto_filter_for_recipe +from torchao.float8.float8_linear_utils import ( + _auto_filter_for_rowwise, + _auto_filter_for_tensorwise, +) + + +@pytest.mark.parametrize( + "recipe_type,module_dims,fqn,filter_fqns,expected", + [ + # Tensorwise tests + ("tensorwise", (8192, 2048), "valid.layer", [], True), + # FQN matches filter + ("tensorwise", (8192, 2048), "skip_layer.linear", ["skip_layer"], False), + # Threshold fail + ("tensorwise", (4096, 1024), "valid.layer", [], False), + # Rowwise tests + ("rowwise", (4096, 8192), "valid.layer", [], True), + ("rowwise", (4096, 8192), "skip_layer.linear", ["skip_layer"], False), + # Combined threshold fail + ( + "rowwise", + (2048, 4096), + "valid.layer", + [], + False, + ), + ], +) +def test_end_to_end_filtering(recipe_type, module_dims, fqn, filter_fqns, expected): + """Test complete filtering workflow for both recipe types.""" + in_features, out_features = module_dims + + # Get the filter function + filter_func = _auto_filter_for_recipe(recipe_type, filter_fqns) + + # Create test module + test_module = nn.Linear(in_features, out_features) + + # Test filtering + result = filter_func(test_module, fqn) + assert result is expected + + +def test_exact_boundary_dimensions_rowwise(): + """Test exact boundary dimensions for rowwise filtering.""" + # Test exact thresholds + module_n_2048 = nn.Linear(4096, 2048) # N exactly 2048 + assert _auto_filter_for_rowwise(module_n_2048, "layer", []) is False + + module_k_1024 = nn.Linear(1024, 4112) # K exactly 1024 + assert _auto_filter_for_rowwise(module_k_1024, "layer", []) is False + + +def test_exact_boundary_dimensions_tensorwise(): + """Test exact boundary dimensions for tensorwise filtering.""" + # Test exact combined threshold + module_boundary = nn.Linear(4096, 1024) # K=4096, N=1024 + assert _auto_filter_for_tensorwise(module_boundary, "layer", []) is False + + +def test_partial_fqn_matching(): + """Test partial FQN matching behavior.""" + filter_fqns = ["embed", "norm"] + large_module = nn.Linear(8192, 4096) + + # (fqn, expected result from filter func) + test_cases = [ + ("model.embeddings.linear", False), # Contains "embed" + ("layer.norm.weight", False), # Contains "norm" + ("model.transformer.layer", True), # Doesn't contain either + ("embedding_layer", False), # Contains "embed" as substring + ] + + for fqn, expected_result in test_cases: + result_tensorwise = _auto_filter_for_tensorwise(large_module, fqn, filter_fqns) + result_rowwise = _auto_filter_for_rowwise(large_module, fqn, filter_fqns) + assert result_tensorwise is expected_result, ( + f"Tensorwise result mismatch: fqn={fqn}, expected={expected_result}, actual={result_tensorwise}" + ) + assert result_rowwise is expected_result, ( + f"Rowwise result mismatch: fqn={fqn}, expected={expected_result}, actual={result_rowwise}" + ) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 15099dc2c1..1f9ae19346 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -8,23 +8,11 @@ import random import re import unittest -import warnings import pytest import torch import torch.nn as nn -from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - from torchao.float8.config import ( Float8LinearConfig, Float8LinearRecipeName, @@ -34,16 +22,14 @@ e5m2_dtype, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, -) +from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_ops import addmm_float8_unwrapped from torchao.float8.float8_scaling_utils import ( get_maybe_axiswise_dim, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, ScaledMMConfig, @@ -56,19 +42,25 @@ tensor_to_scale, ) from torchao.testing.training.test_utils import get_test_float8_linear_config -from torchao.utils import is_MI300, is_ROCM +from torchao.testing.utils import skip_if_rocm +from torchao.utils import ( + is_MI300, + is_ROCM, + is_sm_at_least_89, + is_sm_at_least_90, +) random.seed(0) torch.manual_seed(0) -def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: +def bitwise_identical(a: Float8TrainingTensor, b: Float8TrainingTensor) -> bool: assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" return True -class TestFloat8Tensor: +class TestFloat8TrainingTensor: def test_preserves_dtype(self) -> None: # hp means high precision, lp means low precision hp_dtypes = (torch.float32, torch.float16, torch.bfloat16) @@ -130,7 +122,7 @@ def test_copy_(self): with pytest.raises(RuntimeError): fp8_a.copy_(b) # Should fail - fp8_b = Float8Tensor( + fp8_b = Float8TrainingTensor( torch.empty(16, dtype=e4m3_dtype), scale_a, torch.bfloat16, @@ -379,21 +371,21 @@ def test_linear_from_config_params( ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_bias", [True, False]) + @pytest.mark.parametrize( + "linear_dtype", [torch.bfloat16, torch.float16, torch.float32] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf( + torch.cuda.is_available() and not is_sm_at_least_90(), "CUDA capability < 9.0" + ) @skip_if_rocm("ROCm enablement in progress") def test_linear_from_recipe( self, recipe_name, x_shape, + linear_dtype: torch.dtype, linear_bias: bool, ): - if torch.cuda.get_device_capability() < (9, 0): - warnings.warn( - f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)" - ) - pytest.skip() - - linear_dtype = torch.bfloat16 x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) config = Float8LinearConfig.from_recipe_name(recipe_name) @@ -409,51 +401,30 @@ def test_linear_from_recipe( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) + @pytest.mark.parametrize( + "recipe_name", + [ + Float8LinearRecipeName.TENSORWISE, + Float8LinearRecipeName.ROWWISE, + Float8LinearRecipeName.ROWWISE_WITH_GW_HP, + ], + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_autocast_outputs( self, emulate: bool, linear_dtype: torch.dtype, + recipe_name: Float8LinearRecipeName, ): m_ref = nn.Sequential( nn.Linear(32, 32, device="cuda", dtype=linear_dtype), nn.Linear(32, 32, device="cuda", dtype=linear_dtype), ) - config = Float8LinearConfig( - emulate=emulate, - ) - m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) - - # autocast off - x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - y = m(x) - assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" - - # autocast on - with torch.autocast("cuda"): - y = m(x) - assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" - - with torch.autocast("cuda", dtype=torch.bfloat16): - y = m(x) - assert y.dtype == torch.bfloat16, ( - f"y.dtype is {y.dtype}, expected {torch.bfloat16}" - ) - - @pytest.mark.parametrize( - "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] - ) - @pytest.mark.parametrize( - "emulate", [True, False] if is_sm_at_least_89() else [True] - ) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): - m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - config = Float8LinearConfig(emulate=emulate) - m = Float8Linear.from_float(copy.deepcopy(m), config) + config = Float8LinearConfig.from_recipe_name(recipe_name) + # work around config being frozen + object.__setattr__(config, "emulate", emulate) - # Cast the module to dtype - m = m.to(dtype=linear_dtype) + m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) @@ -497,10 +468,10 @@ def test_quantize(self): m = nn.Sequential(nn.Linear(32, 32)).cuda() m = convert_to_float8_training(m) assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear" - from torchao.quantization.quant_api import float8_weight_only, quantize_ + from torchao.quantization import Float8WeightOnlyConfig, quantize_ - quantize_(m, float8_weight_only()) - assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, ( + quantize_(m, Float8WeightOnlyConfig()) + assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, ( "Post quantization dtype should be torch.float8_e4m3fn" ) with torch.no_grad(): diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index aaf9d3d3f5..04f03bb0ee 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -10,16 +10,6 @@ from io import StringIO import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.nn as nn from torch._dynamo.test_case import TestCase as DynamoTestCase @@ -36,8 +26,16 @@ from torchao.float8.float8_scaling_utils import ( hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig +from torchao.float8.float8_training_tensor import ( + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) def _test_compile_base( @@ -238,7 +236,7 @@ def forward(self, x): "CUDA with capability 9.0 or greater not available", ) def test_float8_with_graph_break_in_the_middle(self): - """Test that having Float8Tensor object at the boundary of a subgraph""" + """Test that having Float8TrainingTensor object at the boundary of a subgraph""" cnts = CompileCounterWithBackend("inductor") mod = self.MockLinear(graph_break=True).cuda() compiled_mod = copy.deepcopy(mod) @@ -254,7 +252,7 @@ def test_float8_with_graph_break_in_the_middle(self): "CUDA with float8 support not available", ) def test_float8_graph_input(self): - """Test that having Float8Tensor object as a graph input""" + """Test that having Float8TrainingTensor object as a graph input""" def to_float(x): return x.to_original_precision() @@ -278,7 +276,7 @@ def to_float(x): "CUDA with float8 support not available", ) def test_float8_graph_output(self): - """Test that having Float8Tensor object as a graph output works""" + """Test that having Float8TrainingTensor object as a graph output works""" cnts = CompileCounterWithBackend("inductor") mod = self.MockLinear(graph_break=False).cuda() compiled_mod = torch.compile(mod, backend=cnts) @@ -290,14 +288,14 @@ def test_float8_graph_output(self): for tensor in tensors: assert not isinstance( getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor - ), "Float8Tensor should not contain any FakeTensors!" + ), "Float8TrainingTensor should not contain any FakeTensors!" assert isinstance(y_compiled._orig_dtype, torch.dtype), ( - "Float8Tensor._orig_dtype should be a dtype but got {}".format( + "Float8TrainingTensor._orig_dtype should be a dtype but got {}".format( type(y_compiled._orig_dtype) ) ) assert isinstance(y_compiled._linear_mm_config.output.emulate, bool), ( - "Float8Tensor._emulate should be a bool but got {}".format( + "Float8TrainingTensor._emulate should be a bool but got {}".format( type(y_compiled._linear_mm_config.output.emulate) ) ) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 5509eb1cc2..7285d4bbc0 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -12,14 +12,7 @@ import os -import pytest import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -37,8 +30,8 @@ ) from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, hp_tensor_and_scale_to_float8, @@ -94,8 +87,8 @@ def _test_scaled_mm(mesh: DeviceMesh, size=16): dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False) - assert isinstance(dist_x_fp8.to_local(), Float8Tensor) - assert isinstance(dist_y_fp8.to_local(), Float8Tensor) + assert isinstance(dist_x_fp8.to_local(), Float8TrainingTensor) + assert isinstance(dist_y_fp8.to_local(), Float8TrainingTensor) assert dist_x_fp8.to_local()._orig_dtype == torch.float32 out_fp8 = torch.mm(dist_x_fp8, dist_y_fp8) local_fp8_out = out_fp8.to_local() @@ -128,7 +121,7 @@ def _test_fp8_redistribute(mesh: DeviceMesh, size=16): if isinstance(out_local, AsyncCollectiveTensor): out_local = out_local.wait() - assert isinstance(out_local, Float8Tensor) + assert isinstance(out_local, Float8TrainingTensor) assert out_local._data.dtype == fp8_dtype @@ -183,7 +176,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): loss.backward() -def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32): tensorwise_config = Float8LinearConfig(emulate=True) _test_lowp_mlp_tensor_parallelism_base( mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True @@ -198,7 +191,7 @@ def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): ) -def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32): tensorwise_config = Float8LinearConfig(emulate=True) _test_lowp_mlp_tensor_parallelism_base( mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True diff --git a/test/float8/test_everything.sh b/test/float8/test_everything.sh index 068c75de63..6d6f835a46 100755 --- a/test/float8/test_everything.sh +++ b/test/float8/test_everything.sh @@ -12,6 +12,7 @@ IS_ROCM=$(rocm-smi --version || true) pytest test/float8/test_base.py pytest test/float8/test_compile.py pytest test/float8/test_numerics_integration.py +pytest test/float8/test_auto_filter.py # These tests do not work on ROCm yet if [ -z "$IS_ROCM" ] diff --git a/test/float8/test_everything_multi_gpu.sh b/test/float8/test_everything_multi_gpu.sh new file mode 100755 index 0000000000..6f391f9699 --- /dev/null +++ b/test/float8/test_everything_multi_gpu.sh @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +#!/bin/bash + +# terminate script on first error +set -e +IS_ROCM=$(rocm-smi --version || true) + +# These tests do not work on ROCm yet +if [ -z "$IS_ROCM" ] +then +./test/float8/test_fsdp.sh +./test/float8/test_fsdp_compile.sh +./test/float8/test_dtensor.sh +python test/float8/test_fsdp2/test_fsdp2.py +fi + +echo "all multi gpu tests successful" diff --git a/test/float8/test_everything_single_gpu.sh b/test/float8/test_everything_single_gpu.sh new file mode 100755 index 0000000000..0b72951126 --- /dev/null +++ b/test/float8/test_everything_single_gpu.sh @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +#!/bin/bash + +# terminate script on first error +set -e + +pytest test/float8/test_base.py --verbose -s +pytest test/float8/test_compile.py --verbose -s +pytest test/float8/test_numerics_integration.py --verbose -s +pytest test/float8/test_auto_filter.py --verbose -s + +echo "all float8 single gpu tests successful" diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 888c7aadb1..c253af55ea 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -10,10 +10,6 @@ from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) # source for notable single-precision cases: diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 3017c8b539..a25bd53509 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -16,13 +16,6 @@ import warnings import fire -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.distributed as dist import torch.multiprocessing as mp diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index b4c7f9fd15..e7b3b8be91 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -10,13 +10,6 @@ from typing import Any, List, Optional import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - import torch import torch._dynamo.testing import torch.distributed as dist @@ -41,12 +34,13 @@ from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_tensor import GemmInputRole +from torchao.float8.float8_training_tensor import GemmInputRole from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.training.fsdp2_utils import ( check_parity_bf16_mp, check_parity_no_mp, ) +from torchao.utils import is_sm_at_least_89 if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index 93c7735149..ea93d5949d 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -13,14 +13,7 @@ import copy import os -import pytest import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor.parallel import parallelize_module @@ -34,6 +27,8 @@ ) from torchao.testing.training.dtensor_utils import ToyModel +torch.set_float32_matmul_precision("high") + def setup_distributed(): world_size = int(os.environ.get("WORLD_SIZE", -1)) @@ -61,7 +56,7 @@ def _test_fp8_mlp_tensor_parallelism_base( enable_fsdp_float8_all_gather=True, ) - toy_model = ToyModel().to(device) + toy_model = ToyModel(size).to(device) tp_model = copy.deepcopy(toy_model) tp_model = convert_to_float8_training(tp_model, config=config) @@ -94,11 +89,11 @@ def _test_fp8_mlp_tensor_parallelism_base( # TODO(future PR): test numerics, and add more cases -def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32): _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False) -def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32): _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index a78a30925c..eb32c40aa3 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -12,13 +12,6 @@ import warnings import fire -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.distributed as dist import torch.multiprocessing as mp diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index db02444109..8da36cef8e 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -10,16 +10,6 @@ from typing import Optional import pytest - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - is_sm_at_least_89, - is_sm_at_least_90, -) - -if not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - import torch import torch.nn as nn import torch.nn.functional as F @@ -34,6 +24,10 @@ ) from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.training.test_utils import get_test_float8_linear_config +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) torch.manual_seed(0) diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index a6990549a3..09bdfa8e61 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -8,16 +8,13 @@ import torch from torchao.quantization import ( + Int4WeightOnlyConfig, MappingType, + UIntXWeightOnlyConfig, ZeroPointDomain, - int4_weight_only, quantize_, - uintx_weight_only, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, -) cuda_available = torch.cuda.is_available() @@ -58,9 +55,11 @@ def _eval_hqq(dtype): ) dummy_linear.weight.data = W if dtype == torch.uint4: - config = int4_weight_only(group_size=max(block_size), use_hqq=True) + config = Int4WeightOnlyConfig( + group_size=max(block_size), use_hqq=True, version=1 + ) else: - config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True) + config = UIntXWeightOnlyConfig(dtype, group_size=max(block_size), use_hqq=True) quantize_(dummy_linear, config) q_tensor_hqq = dummy_linear.weight @@ -78,7 +77,6 @@ def _eval_hqq(dtype): @unittest.skipIf(not cuda_available, "Need CUDA available") -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+") class TestHQQ(unittest.TestCase): def _test_hqq( self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index e6a8341f09..f99cf4a1b4 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -17,6 +17,7 @@ from parameterized import parameterized from torch._dynamo import config from torch._inductor.utils import run_and_get_code +from torch.testing import FileCheck import torchao from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout @@ -37,14 +38,12 @@ # APIs to be deprecated (used for torch 2.2.2 and 2.3) from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, _replace_with_custom_fn_if_matches_filter, - change_linear_weights_to_int4_woqtensors, - change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, quantize_, ) from torchao.quantization.quant_primitives import ( @@ -66,27 +65,24 @@ LoggingTensorMode, _apply_logging_hook, _fqn_to_op_to_shape_to_count, + _quant_int8_dynamic_per_token_linear, + _quantize_activation_per_token_absmax, compute_error, dequantize_per_channel, dynamically_quantize_per_channel, - quant_int8_dynamic_per_token_linear, - quantize_activation_per_token_absmax, ) from torchao.quantization.utils import ( compute_error as SQNR, ) from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_7, benchmark_model, check_cpu_version, check_xpu_version, is_fbcode, + is_sm_at_least_89, is_sm_at_least_90, + torch_version_at_least, unwrap_tensor_subclass, ) @@ -113,65 +109,55 @@ def _int8wo_api(mod): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5 or ( - not TORCH_VERSION_AT_LEAST_2_6 and torch._inductor.config.freezing - ): - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_woqtensors(mod) + quantize_(mod, Int8WeightOnlyConfig(set_inductor_config=False)) def _int8wo_groupwise_api(mod): group_size = 32 - quantize_(mod, int8_weight_only(group_size=group_size, set_inductor_config=False)) + quantize_( + mod, Int8WeightOnlyConfig(group_size=group_size, set_inductor_config=False) + ) def _int8da_int8w_api( mod, act_mapping_type=MappingType.SYMMETRIC, ): - if TORCH_VERSION_AT_LEAST_2_4: - quantize_( - mod, - int8_dynamic_activation_int8_weight( - act_mapping_type=act_mapping_type, - set_inductor_config=False, - ), - ) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - change_linear_weights_to_int8_dqtensors(mod) + quantize_( + mod, + Int8DynamicActivationInt8WeightConfig( + act_mapping_type=act_mapping_type, + set_inductor_config=False, + ), + ) def _int4wo_api(mod, use_hqq=False): if check_cpu_version(next(mod.parameters()).device): quantize_( mod, - int4_weight_only( - layout=Int4CPULayout(), use_hqq=use_hqq, set_inductor_config=False + Int4WeightOnlyConfig( + layout=Int4CPULayout(), + use_hqq=use_hqq, + set_inductor_config=False, + version=1, ), ) unwrap_tensor_subclass(mod) elif check_xpu_version(next(mod.parameters()).device): quantize_( - mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False + mod, + Int4WeightOnlyConfig( + layout=Int4XPULayout(), set_inductor_config=False, version=1 + ), ) unwrap_tensor_subclass(mod) - elif TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int4_weight_only(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) else: - change_linear_weights_to_int4_woqtensors(mod) + quantize_(mod, Int4WeightOnlyConfig(set_inductor_config=False, version=1)) def _int8da_int4w_api(mod): - quantize_(mod, int8_dynamic_activation_int4_weight(set_inductor_config=False)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) + quantize_(mod, Int8DynamicActivationInt4WeightConfig(set_inductor_config=False)) # TODO: use this to reduce the number of tests @@ -390,7 +376,6 @@ def test_swap(self): assert torch.allclose(y_ref, y) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") def test_weight_t_and_non_t_numerics_match(self): # verify that numerics match whether weight is stored # in transposed format (for cuBLAS) vs non-transposed format @@ -554,7 +539,7 @@ def test_dynamic_quant_per_channel_numerics_cuda(self): def _test_quantize_per_token_impl(self, device, dtype): x = torch.randn(3, 3, 3, device=device, dtype=dtype) - xq, scales = quantize_activation_per_token_absmax(x) + xq, scales = _quantize_activation_per_token_absmax(x) block_size = (1, 1, 3) x_dq = dequantize_affine( xq, block_size, scales, None, torch.int8, output_dtype=x.dtype @@ -571,6 +556,11 @@ def test_quantize_per_token_cuda(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_quantize_per_token_impl("cuda", dtype) + @unittest.skipIf(not torch.xpu.is_available(), "XPU not available") + def test_quantize_per_token_xpu(self): + for dtype in (torch.float32, torch.float16, torch.bfloat16): + self._test_quantize_per_token_impl("xpu", dtype) + def _test_per_token_linear_impl(self, device, dtype): x = torch.randn(2, 16, 8, device=device, dtype=dtype) w = torch.randn(16, 8, device=device, dtype=dtype) @@ -578,7 +568,7 @@ def _test_per_token_linear_impl(self, device, dtype): # Note: need to make the weight contiguous because we are # testing in eager mode and cuBlas will not give correct results # for a transposed weight - y = quant_int8_dynamic_per_token_linear( + y = _quant_int8_dynamic_per_token_linear( x, wq.t().contiguous(), w_scales, None, dtype ) y_ref = torch.matmul(x, w.t()) @@ -707,8 +697,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -727,8 +715,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": @@ -786,9 +772,6 @@ def _test_lin_weight_subclass_impl( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen" - ) def test_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( Int8DynamicallyQuantizedLinearWeight.from_float, @@ -805,9 +788,6 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8DynamicallyQuantizedLinearWeight.from_float, @@ -817,9 +797,6 @@ def test_aq_int8_dynamic_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skip( "This segfaults in CI cuda only, disable to unblock PR, we can investigate " "later if needed" @@ -833,9 +810,6 @@ def test_aq_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight2.from_float, @@ -845,9 +819,6 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQInt8WeightOnlyQuantizedLinearWeight3.from_float, @@ -857,9 +828,6 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( @@ -889,9 +857,6 @@ def test_autoquantizable_flatten_unflatten(self): for device, dtype in COMMON_DEVICE_DTYPE ] ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( @@ -916,9 +881,6 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass( ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch" - ) @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") @unittest.skip("TODO this is not working correctly") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): @@ -930,8 +892,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": @@ -950,8 +910,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") @unittest.skip("Skip to fix CI until we deprecate these APIs long term") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): @@ -1022,14 +980,8 @@ def _test_lin_weight_subclass_api_impl( ) ) ) + @unittest.skip("skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): - if ( - not TORCH_VERSION_AT_LEAST_2_5 - and dtype in (torch.float16, torch.bfloat16) - and act_mapping is MappingType.ASYMMETRIC - and device == "cpu" - ): - self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") api = partial( _int8da_int8w_api, act_mapping_type=act_mapping, @@ -1039,12 +991,6 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): - if ( - not TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_lin_weight_subclass_api_impl( _int8wo_api, device, 40, test_dtype=dtype @@ -1052,9 +998,7 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch._inductor.config.patch({"freezing": True}) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after." - ) + @skip_if_rocm("Test flaky on ROCm, under investigation") def test_int8_weight_only_quant_with_freeze(self, device, dtype): torch._dynamo.reset() self._test_lin_weight_subclass_api_impl( @@ -1062,8 +1006,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1075,7 +1017,6 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "int4 hqq requires torch nightly.") def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1089,14 +1030,12 @@ def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater" - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_gemlite_layout(self, device, dtype): + from torchao.quantization import GemliteUIntXWeightOnlyConfig + if dtype != torch.float16: self.skipTest("gemlite only works for fp16 dtype") - from torchao.quantization import gemlite_uintx_weight_only if device == "cpu": self.skipTest(f"gemlite is for cuda, not {device}") @@ -1105,7 +1044,7 @@ def test_gemlite_layout(self, device, dtype): for group_size in [64, 32, None] if bit_width == 4 else [None]: api = lambda mod: quantize_( mod, - gemlite_uintx_weight_only( + GemliteUIntXWeightOnlyConfig( group_size, bit_width, packing_bitwidth ), ) @@ -1127,7 +1066,7 @@ def test_gemlite_layout(self, device, dtype): # test that shapes with non divisible by 128 shapes aren't causing errors self._test_lin_weight_subclass_api_impl( - lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)), + lambda mod: quantize_(mod, GemliteUIntXWeightOnlyConfig(None, 4, 32)), device, 15, test_shape=[1, 1025, 513], @@ -1135,8 +1074,6 @@ def test_gemlite_layout(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: @@ -1154,20 +1091,13 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): ): for groupsize in [64, 32]: for layout in layout_list: - kwargs = {"groupsize": groupsize, "layout": layout} + kwargs = {"groupsize": groupsize, "layout": layout, "version": 1} def api(mod): kwargs_copy = kwargs.copy() - if TORCH_VERSION_AT_LEAST_2_4: - kwargs_copy["group_size"] = groupsize - del kwargs_copy["groupsize"] - quantize_(mod, int4_weight_only(**kwargs_copy)) - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(mod) - else: - kwargs_copy["inner_k_tiles"] = inner_k_tiles - del kwargs_copy["layout"] - change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy) + kwargs_copy["group_size"] = groupsize + del kwargs_copy["groupsize"] + quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy)) self._test_lin_weight_subclass_api_impl( api, @@ -1185,7 +1115,7 @@ def test_dynamic_quant(self): m = nn.Sequential(nn.Linear(K, N)) y_ref = m(x) - quantize_(m, int8_dynamic_activation_int8_weight()) + quantize_(m, Int8DynamicActivationInt8WeightConfig()) y_test = m(x) sqnr = compute_error(y_ref, y_test) @@ -1225,7 +1155,7 @@ def test_weight_only_groupwise_embedding_quant(self): quantize_( m, - int8_weight_only(group_size=group_size), + Int8WeightOnlyConfig(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding), ) y_q = m(input) @@ -1248,11 +1178,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): self.skipTest("test requires SM capability of at least (8, 0).") from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1285,11 +1211,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype): torch.manual_seed(0) from torch._inductor import config - mixed_mm_key, mixed_mm_val = ( - ("mixed_mm_choice", "triton") - if TORCH_VERSION_AT_LEAST_2_5 - else ("force_mixed_mm", True) - ) + mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") with config.patch( { @@ -1391,18 +1313,10 @@ def test_save_load_dqtensors(self, device, dtype): @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): - if ( - not TORCH_VERSION_AT_LEAST_2_6 - and dtype in (torch.float16, torch.bfloat16) - and device == "cpu" - ): - self.skipTest("Regression fixed after torch 2.6") undo_recommended_configs() self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch 2.3+.") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") @torch.no_grad() def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: @@ -1412,9 +1326,6 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "fullgraph requires torch nightly." - ) def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( @@ -1463,7 +1374,7 @@ def test_shape_logger(self): class SmoothquantIntegrationTest(unittest.TestCase): @torch.no_grad() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported") + @unittest.skip("Seg fault?") def test_non_dynamically_quantizable_linear(self): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): self.skipTest("test requires SM capability of at least (8, 0).") @@ -1558,7 +1469,6 @@ class TestAutoQuant(unittest.TestCase): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_one_input(self, device, dtype, m, k, n): undo_recommended_configs() print("(m, k, n): ", (m, k, n)) @@ -1600,7 +1510,6 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_compile(self, device, dtype, m1, m2, k, n): undo_recommended_configs() @@ -1622,9 +1531,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # Skip certain shapes on older PyTorch versions - if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") # TODO remove this once https://github.com/pytorch/pytorch/issues/155838 is resolved if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} is flaky, skipping") @@ -1653,7 +1559,6 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n): self.assertTrue(sqnr >= 30) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_mha(self, device, dtype): if device != "cuda" or not torch.cuda.is_available(): self.skipTest(f"autoquant currently does not support {device}") @@ -1676,12 +1581,11 @@ def forward(self, x): assert not isinstance(mod.mha.out_proj.weight, AutoQuantizableLinearWeight) assert isinstance(mod.lin.weight, AutoQuantizableLinearWeight) mod(*input) - from torchao.quantization.autoquant import AUTOQUANT_CACHE + from torchao.quantization.autoquant import _AUTOQUANT_CACHE - assert len(AUTOQUANT_CACHE) > 0 + assert len(_AUTOQUANT_CACHE) > 0 @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_manual(self, device, dtype): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1731,7 +1635,6 @@ def test_autoquant_manual(self, device, dtype): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.") def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1741,9 +1644,27 @@ def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n): self.skipTest("bfloat16 requires sm80+") if m1 == 1 or m2 == 1: self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") - # This test fails on v0.4.0 and torch 2.4, so skipping for now. - if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4") + + # Note: This test was incorrectly written before with this skip condition: + # + # m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5: + # + # This is actually equivalent to: + # + # m1 == 1 or (m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5) + # + # which means we always skips the test as long as `m1 == 1` regardless of + # the pytorch version, which was not the intended behavior. Unfortunately, + # unskipping this test now leads to the following error when calling + # `aten._int_mm`: + # + # RuntimeError: self.size(0) needs to be greater than 16, but got 1 + # + # Therefore, we keep around this skip condition for now since it doesn't + # change the test behavior from before. For more details, please see + # https://github.com/pytorch/ao/pull/2720. + if m1 == 1: + self.skipTest(f"Shape {(m1, m2, k, n)} is not supported") class NeedsKwargs(torch.nn.Module): def __init__(self): @@ -1778,7 +1699,6 @@ def forward(self, x, y): ], ) ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "autoquant requires 2.3+.") def test_autoquant_double_access(self, device, dtype, m, k, n): undo_recommended_configs() if device != "cuda" or not torch.cuda.is_available(): @@ -1831,9 +1751,6 @@ def test_autoquant_min_sqnr(self, device, dtype): self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." - ) def test_autoquant_hp_float(self): device = "cuda" dtype = torch.float32 @@ -1864,9 +1781,6 @@ def test_autoquant_hp_float(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." - ) @unittest.skipIf(not has_gemlite, "gemlite not available") def test_autoquant_int4wo(self, device, dtype): if device == "cpu": @@ -1902,9 +1816,6 @@ def test_autoquant_int4wo(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." - ) @unittest.skipIf( True, "Skipping for now, do to lowering bug in inductor" ) # TODO unblock when fixed @@ -1944,7 +1855,6 @@ def test_autoquant_float8(self, device, dtype): self.assertGreater(compute_error(ref, out), 20) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @unittest.skip( "AOTI tests are failing right now, repro by commenting out the skip and run:" @@ -1955,11 +1865,6 @@ class TestAOTI(unittest.TestCase): list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) def test_aoti(self, api, test_device, test_dtype): - if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda": - self.skipTest( - f"{api} in {test_device} is not support for aoti compilation yet" - ) - if ( test_device == "cuda" and torch.cuda.is_available() @@ -1992,7 +1897,7 @@ def forward(self, x): model(x) api(model) - if not TORCH_VERSION_AT_LEAST_2_7: + if not torch_version_at_least("2.7.0"): unwrap_tensor_subclass(model) # running model @@ -2007,7 +1912,6 @@ def forward(self, x): ) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") class TestExport(unittest.TestCase): @parameterized.expand( @@ -2052,7 +1956,7 @@ def forward(self, x): model(x) api(model) - if not TORCH_VERSION_AT_LEAST_2_7: + if not torch_version_at_least("2.7.0"): unwrap_tensor_subclass(model) # running model @@ -2063,12 +1967,7 @@ def forward(self, x): # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu # we can re-enable this after non-functional IR is enabled in export # model = torch.export.export(model, example_inputs).module() - if TORCH_VERSION_AT_LEAST_2_5: - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() - else: - model = torch._export.capture_pre_autograd_graph(model, example_inputs) + model = torch.export.export(model, example_inputs, strict=True).module() after_export = model(x) self.assertTrue(torch.equal(after_export, ref)) if api is _int8da_int4w_api: @@ -2077,12 +1976,36 @@ def forward(self, x): self.assertTrue(torch.ops.torchao.quantize_affine.default in targets) self.assertFalse(torch.ops.aten.narrow.default in targets) + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + def test_export_float8(self): + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear = torch.nn.Linear( + in_features=32, out_features=16, bias=False + ) + + def forward(self, x): + return self.linear(x) + + model = SimpleNetwork().eval().cuda() + inp = torch.randn(2, 32).cuda() + config = Float8DynamicActivationFloat8WeightConfig() + quantize_(model, config) + + ep = torch.export.export(model, (inp,)) + print(ep) + FileCheck().check_count( + "torch.ops.torchao.choose_scale_float8.default", 1, exactly=True + ).run(str(ep.graph)) + class TestUtils(unittest.TestCase): @parameterized.expand( list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)), ) - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_get_model_size_aqt(self, api, test_device, test_dtype): if test_dtype != torch.bfloat16: self.skipTest(f"{api} in {test_dtype} is not supported yet") diff --git a/test/integration/test_load_and_run_checkpoint.py b/test/integration/test_load_and_run_checkpoint.py new file mode 100644 index 0000000000..6bdee4a1b8 --- /dev/null +++ b/test/integration/test_load_and_run_checkpoint.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import re +import unittest +import warnings + +import torch +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao.utils import is_fbcode, is_sm_at_least_90 + +if is_fbcode(): + # don't import from transformer internally, since some imports might be missing + pass +else: + from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + + +# please check model card for how to generate these models + +# high precision model, used for testing config deprecation warning +_HIGH_PRECISION_MODEL = "facebook/opt-125m" + +_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [ + # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev + ( + "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev", + 1, + "Float8DynamicActivationFloat8WeightConfig", + ), + # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev + ( + "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev", + 1, + "Int4WeightOnlyConfig", + ), + # model card: https://huggingface.co/torchao-testing/single-linear-IntxWeightOnlyConfig-v1-0.14.dev + ( + "torchao-testing/single-linear-IntxWeightOnlyConfig-v1-0.14.dev", + 1, + "IntxWeightOnlyConfig", + ), + # model card: https://huggingface.co/torchao-testing/single-linear-Int8DynamicActivationIntxWeightConfig-v1-0.14.dev + ( + "torchao-testing/single-linear-Int8DynamicActivationIntxWeightConfig-v1-0.14.dev", + 1, + "Int8DynamicActivationIntxWeightConfig", + ), +] + +_DEPRECATED_MODEL_INFO = [ + # model card: https://huggingface.co/torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev + ( + "torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev", + 1, + "Float8DynamicActivationFloat8WeightConfig", + ), + # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev + ( + "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev", + 1, + "Int4WeightOnlyConfig", + ), + # https://huggingface.co/torchao-testing/opt-125m-IntxWeightOnlyConfig-v1-0.14.0.dev + ( + "torchao-testing/opt-125m-IntxWeightOnlyConfig-v1-0.14.0.dev", + 1, + "IntxWeightOnlyConfig", + ), + # https://huggingface.co/torchao-testing/opt-125m-Int8DynamicActivationIntxWeightConfig-v1-0.14.0.dev + ( + "torchao-testing/opt-125m-Int8DynamicActivationIntxWeightConfig-v1-0.14.0.dev", + 1, + "Int8DynamicActivationIntxWeightConfig", + ), +] + +_SINGLE_LINEAR_MODEL_INFO = [ + # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev + ( + "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev", + 2, + "Float8DynamicActivationFloat8WeightConfig", + ), + # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev + ( + "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev", + 2, + "Int4WeightOnlyConfig", + ), + # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev + ( + "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev", + 2, + "Int4WeightOnlyConfig", + ), + # model card: https://huggingface.co/torchao-testing/single-linear-IntxWeightOnlyConfig-v2-0.14.dev + ( + "torchao-testing/single-linear-IntxWeightOnlyConfig-v2-0.14.dev", + 2, + "IntxWeightOnlyConfig", + ), + # model card: https://huggingface.co/torchao-testing/single-linear-Int8DynamicActivationIntxWeightConfig-v2-0.14.dev + ( + "torchao-testing/single-linear-Int8DynamicActivationIntxWeightConfig-v2-0.14.dev", + 2, + "Int8DynamicActivationIntxWeightConfig", + ), +] + + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") +@unittest.skipIf( + is_fbcode(), + "Skipping the test in fbcode for now, not sure how to download from transformers", +) +class TestLoadAndRunCheckpoint(TestCase): + def _test_single_linear_helper( + self, model_name, version, config_name, is_deprecated + ): + from huggingface_hub import hf_hub_download + + downloaded_model = hf_hub_download(model_name, filename="model.pt") + # Load model weights, example inputs and reference output, + # run the loaded model and make sure the result matches reference output + + with torch.device("meta"): + # 32 and 256 are the args we used when we save the model, see + # model card: + # https://huggingface.co/torchao-testing/single-linear-FP8-v2-0.13-dev + model = torch.nn.Sequential( + torch.nn.Linear(32, 256, dtype=torch.bfloat16) # , device="cuda") + ) + + with ( + open(downloaded_model, "rb") as f, + warnings.catch_warnings(record=True) as caught_warnings, + ): + model.load_state_dict(torch.load(f), assign=True) + if is_deprecated: + pattern = re.compile( + rf"Models quantized with version {version} of .*{re.escape(config_name)}.* (is|are) deprecated" + ) + assert any(pattern.search(str(w.message)) for w in caught_warnings), ( + f"Didn't get expected warning message for deprecation for model: {model_name}" + ) + + downloaded_example_inputs = hf_hub_download( + model_name, filename="model_inputs.pt" + ) + with open(downloaded_example_inputs, "rb") as f: + example_inputs = torch.load(f) + downloaded_output = hf_hub_download(model_name, filename="model_output.pt") + with open(downloaded_output, "rb") as f: + ref_output = torch.load(f) + + output = model(*example_inputs) + self.assertTrue(torch.equal(output, ref_output)) + + @common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO) + def test_deprecated_single_linear(self, model_info): + model_name, version, config_name = model_info + self._test_single_linear_helper( + model_name, version, config_name, is_deprecated=True + ) + + @common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO) + def test_single_linear(self, model_info): + """Test that we can load and run the quantized linear checkpoint with saved sample input + and match the saved output, to make sure there is no BC breaking changes + when we make changes to tensor subclass implementations + """ + model_name, version, config_name = model_info + self._test_single_linear_helper( + model_name, version, config_name, is_deprecated=False + ) + + @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO) + def test_deprecated_hf_models(self, model_info): + """Test that we print correct warning message when loading a deprecated checkpoint + and making sure the deprecated checkpoints can still be loaded + """ + # Load and quantize model + model_name, version, config_name = model_info + with warnings.catch_warnings(record=True) as caught_warnings: + quantized_model = AutoModelForCausalLM.from_pretrained( + model_name, + dtype="bfloat16", + device_map="cuda:0", + ) + # version mismatch check in config.py + assert any( + "Stored version is not the same as current default version of the config" + in str(w.message) + for w in caught_warnings + ), ( + f"Didn't get expected warning message for version mismatch for config {config_name}, model {model_name}" + ) + + # checkpoint deprecation + pattern = re.compile( + rf"Models quantized with version {version} of .*{re.escape(config_name)}.* (is|are) deprecated" + ) + assert any(pattern.search(str(w.message)) for w in caught_warnings), ( + f"Didn't get expected warning message for deprecation for model {model_name}" + ) + assert isinstance(quantized_model.config.quantization_config, TorchAoConfig) + assert ( + quantized_model.config.quantization_config.quant_type.version == version + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + from huggingface_hub import hf_hub_download + + downloaded_example_inputs = hf_hub_download( + model_name, filename="model_prompt.pt" + ) + with open(downloaded_example_inputs, "rb") as f: + prompt = torch.load(f) + + inputs = tokenizer( + prompt, + return_tensors="pt", + ).to("cuda") + generated_ids = quantized_model.generate( + **inputs, + max_new_tokens=128, + ) + + downloaded_output = hf_hub_download(model_name, filename="model_output.pt") + with open(downloaded_output, "rb") as f: + ref_generated_ids = torch.load(f) + + self.assertTrue(torch.equal(generated_ids, ref_generated_ids)) + + # make sure can successfully decode + _ = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + # make sure we throw warning for config deprecation + with warnings.catch_warnings(record=True) as caught_warnings: + _ = AutoModelForCausalLM.from_pretrained( + _HIGH_PRECISION_MODEL, + dtype="bfloat16", + device_map="cuda:0", + quantization_config=quantized_model.config.quantization_config, + ) + # config version deprecation in quant_api.py + assert any( + f"Config Deprecation: version {version} of {config_name} is deprecated and will no longer be supported in a future release" + in str(w.message) + for w in caught_warnings + ), ( + f"Didn't get expected warning message for version deprecation for config {config_name}, model {model_name}" + ) + + +common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint) + +if __name__ == "__main__": + run_tests() diff --git a/test/integration/test_vllm.py b/test/integration/test_vllm.py index 7bb9a6defa..32a7a8b405 100644 --- a/test/integration/test_vllm.py +++ b/test/integration/test_vllm.py @@ -17,9 +17,9 @@ import torch from packaging import version -from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 +from torchao.utils import torch_version_at_least -if not TORCH_VERSION_AT_LEAST_2_8: +if not torch_version_at_least("2.8.0"): pytest.skip("Requires PyTorch 2.8 or higher", allow_module_level=True) @@ -41,6 +41,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig from vllm import LLM, SamplingParams +from torchao.prototype.mx_formats import MXFPInferenceConfig from torchao.quantization.granularity import PerRow, PerTensor from torchao.quantization.quant_api import ( CutlassInt4PackedLayout, @@ -69,9 +70,7 @@ def get_tests() -> List[TorchAoConfig]: Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout()) ) ] - SM100_TESTS = [ - # TorchAoConfig(MXFPInferenceConfig()) - ] # Failing for : https://github.com/pytorch/ao/issues/2239 + SM100_TESTS = [TorchAoConfig(MXFPInferenceConfig())] # Check CUDA availability first if not torch.cuda.is_available(): @@ -154,7 +153,7 @@ def quantize_and_save_model( # Load and quantize model quantized_model = AutoModelForCausalLM.from_pretrained( model_name, - torch_dtype="bfloat16", + dtype="bfloat16", device_map="cuda", quantization_config=quantization_config, ) diff --git a/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py new file mode 100644 index 0000000000..06beae5b34 --- /dev/null +++ b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +triton = pytest.importorskip("triton", reason="Triton required to run this test") + +from packaging import version +from torchao.float8.float8_utils import compute_error +from torchao.prototype.blockwise_fp8_training.kernels import ( + torch_blockwise_scale_act_quant_lhs, + torch_blockwise_scale_act_quant_rhs, + torch_blockwise_scale_weight_quant, + triton_fp8_blockwise_act_quant_lhs, + triton_fp8_blockwise_act_quant_rhs, + triton_fp8_blockwise_act_quant_transposed_lhs, + triton_fp8_blockwise_weight_quant_rhs, + triton_fp8_blockwise_weight_quant_transposed_rhs, + triton_fp8_gemm_1x128_128x1, + triton_fp8_gemm_1x128_128x128, +) +from torchao.testing.utils import skip_if_rocm +from torchao.utils import is_sm_at_least_90 + +BLOCKWISE_SIZE_MNK = [ + # (128, 128, 128), + (2, 512, 128), + (2, 5120, 1280), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif( + version.parse(triton.__version__) < version.parse("3.3.0"), + reason="Triton version < 3.3.0, test skipped", +) +@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +def test_triton_fp8_gemm_1x128_128x128(M, N, K, dtype): + # Simulate output = input @ weight.T + A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + C = A @ B.T + A_q, A_s = triton_fp8_blockwise_act_quant_lhs(A, dtype=dtype) + B_t_q, B_t_s = triton_fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype) + C_q = triton_fp8_gemm_1x128_128x128( + A_q, B_t_q, A_s, B_t_s, out_dtype=torch.bfloat16 + ) + assert not C_q.isnan().any(), "C_q must not contain NaNs" + + sqnr = compute_error(C, C_q) + min_sqnr = 28.0 + assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.skipif( + version.parse(triton.__version__) < version.parse("3.3.0"), + reason="Triton version < 3.3.0, test skipped", +) +@pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +def test_triton_fp8_gemm_1x128_128x1(M, N, K, dtype): + # Simulate grad_weight = grad_output_t @ input + A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda") + B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda") + C = A.T @ B + A_t_q, A_t_s = triton_fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype) + B_q, B_s = triton_fp8_blockwise_act_quant_rhs(B, dtype=dtype) + C_q = triton_fp8_gemm_1x128_128x1(A_t_q, B_q, A_t_s, B_s, out_dtype=torch.bfloat16) + + assert not C_q.isnan().any(), "C_q must not contain NaNs" + assert C.dtype == torch.bfloat16 + assert C_q.dtype == torch.bfloat16 + + sqnr = compute_error(C, C_q) + min_sqnr = 28.0 + assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}" + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("block_size", [128, 256]) +def test_triton_quantize_fp8_act_quant_lhs(block_size): + device = "cuda" + M, K = 4096, 1024 + x = torch.randn(M, K, device=device) + + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the + # quantized tensor will have NaNs due to division by 0 + x[0, :block_size] = 0.0 + + # Get the quantized tensor and reciprocal scales using triton implementation + triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_lhs( + x, + block_size=block_size, + ) + + # Get the quantized tensor and reciprocal scales using reference implementation + ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs(x, tile_size=block_size) + + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" + + # Convert both to float32 for comparison + triton_fp32 = triton_fp8.to(torch.float32) + ref_fp32 = ref_fp8.to(torch.float32) + + # Check that the quantized tensors are close + torch.testing.assert_close( + triton_fp32, + ref_fp32, + atol=0, + rtol=0, + msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", + ) + + # Compare reciprocal scales + torch.testing.assert_close( + triton_scale, + ref_scale, + atol=0, + rtol=0, + msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", + ) + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("block_size", [128, 256]) +def test_triton_quantize_fp8_act_quant_rhs(block_size: int): + device = "cuda" + M, K = 4096, 1024 + x = torch.randn(M, K, device=device) + + # Set one block to 0s, so if nan guards/EPS are not applied, the + # quantized tensor will have NaNs due to division by 0 + x[:block_size, :block_size] = 0.0 + + # Get the quantized tensor and reciprocal scales using triton implementation + triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_rhs( + x, + block_size=block_size, + ) + + # Get the quantized tensor and reciprocal scales using reference implementation + ref_fp8, ref_scale = torch_blockwise_scale_act_quant_rhs(x, block_size=block_size) + + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" + + # Convert both to float32 for comparison + triton_fp32 = triton_fp8.to(torch.float32) + ref_fp32 = ref_fp8.to(torch.float32) + + # Check that the quantized tensors are close + torch.testing.assert_close( + triton_fp32, + ref_fp32, + atol=0, + rtol=0, + msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", + ) + + # Compare reciprocal scales + torch.testing.assert_close( + triton_scale, + ref_scale, + atol=0, + rtol=0, + msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", + ) + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("block_size", [128, 256]) +@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) +def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): + device = "cuda" + x = torch.randn(M, K, device=device) + + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the + # quantized tensor will have NaNs due to division by 0 + x[0, :block_size] = 0.0 + + # Get the quantized tensor and reciprocal scales using triton implementation + triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_transposed_lhs( + x, + block_size=block_size, + ) + + # Get the quantized tensor and reciprocal scales using reference implementation + ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs( + x.t().contiguous(), tile_size=block_size + ) + + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" + + # Convert both to float32 for comparison + triton_fp32 = triton_fp8.to(torch.float32) + ref_fp32 = ref_fp8.to(torch.float32) + + # Check that the quantized tensors are close + torch.testing.assert_close( + triton_fp32, + ref_fp32, + atol=0, + rtol=0, + msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", + ) + + # Compare reciprocal scales + torch.testing.assert_close( + triton_scale, + ref_scale, + atol=0, + rtol=0, + msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", + ) + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("block_size", [128, 256]) +@pytest.mark.parametrize("M,K", [(4096, 1024), (4096, 4 * 4096)]) +def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): + device = "cuda" + x = torch.randn(M, K, device=device) + + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the + # quantized tensor will have NaNs due to division by 0 + x[:block_size, :block_size] = 0.0 + + # Get the quantized tensor and reciprocal scales using triton implementation + triton_fp8, triton_scale = triton_fp8_blockwise_weight_quant_rhs( + x, + block_size=block_size, + ) + # Get the quantized tensor and reciprocal scales using reference implementation + ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=block_size) + + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" + + # Convert both to float32 for comparison + triton_fp32 = triton_fp8.to(torch.float32) + ref_fp32 = ref_fp8.to(torch.float32) + + # Check that the quantized tensors are close + torch.testing.assert_close( + triton_fp32, + ref_fp32, + atol=0, + rtol=0, + msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", + ) + + # Compare reciprocal scales + torch.testing.assert_close( + triton_scale, + ref_scale, + atol=0, + rtol=0, + msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", + ) + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.skipif(not is_sm_at_least_90(), reason="Requires CUDA capability >= 9.0") +@pytest.mark.parametrize("block_size", [128, 256]) +def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int): + device = "cuda" + M = 512 + K = 2048 + x = torch.randn(M, K, device=device) + + # Set one scaling block to 0s, so if nan guards/EPS are not applied, the + # quantized tensor will have NaNs due to division by 0 + x[:block_size, :block_size] = 0.0 + + # Get the quantized tensor and reciprocal scales using triton implementation + triton_fp8, triton_scale = triton_fp8_blockwise_weight_quant_transposed_rhs( + x, + block_size=block_size, + ) + # Get the quantized tensor and reciprocal scales using reference implementation + ref_fp8, ref_scale = torch_blockwise_scale_weight_quant( + x.t().contiguous(), tile_size=block_size + ) + + assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" + assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" + + # Convert both to float32 for comparison + triton_fp32 = triton_fp8.to(torch.float32) + ref_fp32 = ref_fp8.to(torch.float32) + + # Check that the quantized tensors are close + torch.testing.assert_close( + triton_fp32, + ref_fp32, + atol=0, + rtol=0, + msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", + ) + + # Compare reciprocal scales + torch.testing.assert_close( + triton_scale, + ref_scale, + atol=0, + rtol=0, + msg=f"Scales differ: max diff = {(triton_scale - ref_scale).abs().max().item()}", + ) diff --git a/test/prototype/blockwise_fp8_training/test_blockwise_linear.py b/test/prototype/blockwise_fp8_training/test_blockwise_linear.py new file mode 100644 index 0000000000..fdb1ad42f5 --- /dev/null +++ b/test/prototype/blockwise_fp8_training/test_blockwise_linear.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import pytest +import torch + +from torchao.utils import is_sm_at_least_90 + +triton = pytest.importorskip("triton", reason="Triton required to run this test") +if not is_sm_at_least_90(): + pytest.skip("This test requires SM90 or higher", allow_module_level=True) + + +from torchao.float8.float8_utils import compute_error +from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear + +torch.random.manual_seed(0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("in_features", [4096]) +@pytest.mark.parametrize("out_features", [128256]) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("block_size", [128]) +def test_blockwise_quant_linear_fwd_bwd( + in_features, + out_features, + batch_size, + block_size, +): + if in_features % block_size != 0 or out_features % block_size != 0: + pytest.skip(f"Dimensions must be divisible by block_size={block_size}") + + layer_ref = torch.nn.Linear( + in_features=in_features, + out_features=out_features, + bias=False, + ).cuda() + + layer_test = Float8BlockwiseLinear.from_float(copy.deepcopy(layer_ref)) + + # Create input tensor + x_test = torch.randn(batch_size, 256, in_features).cuda().requires_grad_(True) + x_ref = x_test.clone().detach().requires_grad_(True) + + # Forward pass + y_test = layer_test(x_test) + y_ref = layer_ref(x_ref) + + # Compare outputs + sqnr = compute_error(y_ref, y_test) + assert not y_test.isnan().any(), "Output must not contain NaNs" + assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0" + assert not sqnr.isinf().any(), "SQNR must not be inf" + + # Backward pass + y_test.sum().backward() + y_ref.sum().backward() + + # Compare input grads + sqnr = compute_error(x_ref.grad, x_test.grad) + assert not x_test.grad.isnan().any(), "Input grad must not contain NaNs" + assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 25.0" + + # Compare weight grads + sqnr = compute_error(layer_ref.weight, layer_test.weight) + assert not layer_test.weight.grad.isnan().any(), "Weight grad must not contain NaNs" + assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 25.0" diff --git a/test/prototype/inductor/test_int8_sdpa_fusion.py b/test/prototype/inductor/test_qsdpa_fusion.py similarity index 91% rename from test/prototype/inductor/test_int8_sdpa_fusion.py rename to test/prototype/inductor/test_qsdpa_fusion.py index ec4f928df2..dc754d2682 100644 --- a/test/prototype/inductor/test_int8_sdpa_fusion.py +++ b/test/prototype/inductor/test_qsdpa_fusion.py @@ -11,11 +11,11 @@ from torch.testing._internal.inductor_utils import HAS_CPU import torchao -from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import ( - _int8_sdpa_init, +from torchao.prototype.inductor.fx_passes.qsdpa_fusion import ( + _qsdpa_init, custom_pass, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import torch_version_at_least class SelfAttnLikeModule(torch.nn.Module): @@ -120,7 +120,7 @@ def _check_common( ) source_code = "\n".join(source_code) if has_fuse_pattern: - self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1) + self.assertGreaterEqual(counters["inductor"]["qsdpa_fuse_attention"], 1) if contains: self.assertTrue( any( @@ -128,6 +128,7 @@ def _check_common( for op_name in [ "qscaled_dot_product", "cpp_fused_quantize_per_tensor", + "cpp_fused__unsafe_view_quantize_per_tensor", ] ) ) @@ -149,16 +150,15 @@ def _check_common( @skipIfRocm @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" + not torch_version_at_least("2.7.0"), + reason="qsdpa requires torch 2.7 or later", ) @unittest.skipIf( "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"), reason="cpp kernels not built", ) @config.patch({"freezing": True}) - def _test_sdpa_int8_rewriter(self): - from torch.export import export_for_training - + def _test_qsdpa_rewriter(self): import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( @@ -193,17 +193,13 @@ def _test_sdpa_int8_rewriter(self): ), config.patch(post_grad_custom_pre_pass=custom_pass), ): - _int8_sdpa_init() + _qsdpa_init() quantizer = X86InductorQuantizer() quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) quantizer.set_function_type_qconfig( torch.matmul, quantizer.get_global_quantization_config() ) - export_model = export_for_training( - mod, - inputs, - strict=True, - ).module() + export_model = torch.export.export(mod, inputs, strict=True).module() prepare_model = prepare_pt2e(export_model, quantizer) prepare_model(*inputs) convert_model = convert_pt2e(prepare_model) @@ -217,9 +213,7 @@ def _test_sdpa_int8_rewriter(self): class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): device = "cpu" - test_sdpa_int8_rewriter_cpu = ( - TestSDPAPatternRewriterTemplate._test_sdpa_int8_rewriter - ) + test_qsdpa_rewriter_cpu = TestSDPAPatternRewriterTemplate._test_qsdpa_rewriter if __name__ == "__main__": diff --git a/test/prototype/moe_training/test_everything.sh b/test/prototype/moe_training/test_everything.sh new file mode 100755 index 0000000000..79b5cf3c15 --- /dev/null +++ b/test/prototype/moe_training/test_everything.sh @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +#!/bin/bash + +# terminate script on first error +set -e +IS_ROCM=$(rocm-smi --version || true) + +# These tests do not work on ROCm yet +if [ -z "$IS_ROCM" ] +then +pytest test/prototype/moe_training/test_kernels.py -s +pytest test/prototype/moe_training/test_training.py -s +./test/prototype/moe_training/test_fsdp.sh +./test/prototype/moe_training/test_tp.sh +./test/prototype/moe_training/test_fsdp_tp.sh +fi + +echo "all tests successful" diff --git a/test/prototype/moe_training/test_fsdp.py b/test/prototype/moe_training/test_fsdp.py index 4994a76854..f1715fd4b1 100644 --- a/test/prototype/moe_training/test_fsdp.py +++ b/test/prototype/moe_training/test_fsdp.py @@ -1,11 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +###################################################################### +# +# To run these unit tests, use the following command: +# +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py +# +####################################################################### + import copy import os import pytest import torch + +if torch.version.hip is not None: + pytest.skip( + "ROCm support for MoE quantization is under development", + allow_module_level=True, + ) + from torch import distributed as dist from torch import nn from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.nn import functional as F # this feature requires CUDA and SM89+ @@ -15,39 +36,121 @@ ) from torchao.float8.float8_utils import compute_error -from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig -from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) from torchao.quantization.quant_api import quantize_ +from .testing_utils import _validate_model_conversion + # this test requires torchtitan try: - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m + from torchtitan.models.moe import MoE, MoEArgs except ImportError: - import warnings + pytest.skip( + "torchtitan not installed, skipping MoE tests.", allow_module_level=True + ) + + +@pytest.fixture(scope="module") +def device_mesh_1d() -> DeviceMesh: + """ + Fixture for setting up and tearing down the distributed environment + for the entire test module. + """ + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + if not dist.is_initialized(): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + device_mesh = init_device_mesh("cuda", (world_size,)) + torch.manual_seed(1) + torch.cuda.set_device(rank) + + yield device_mesh - warnings.warn("torchtitan not installed, skipping MoE tests.") - pytest.skip(allow_module_level=True) + dist.destroy_process_group() -def test_moe_float8_training_fsdp(): +@pytest.mark.parametrize( + "target_fqns", + [ + ["experts"], + ["experts,shared_experts"], + ], +) +@pytest.mark.parametrize("compile", [False, True]) +@pytest.mark.parametrize( + "recipe_config", + [ + { + "recipe": MoEScalingType.FP8_ROWWISE, + "group_alignment_size": 16, + "min_out_sqnr": 29.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 23.0, + }, + { + "recipe": MoEScalingType.MXFP8, + "group_alignment_size": 32, + "min_out_sqnr": 28.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 21.0, + }, + ], +) +def test_moe_training_fsdp( + target_fqns: list[str], + compile: bool, + recipe_config: dict, + device_mesh_1d: DeviceMesh, +): + ( + recipe, + group_alignment_size, + min_out_sqnr, + min_input_grad_sqnr, + min_param_grad_sqnr, + ) = ( + recipe_config["recipe"], + recipe_config["group_alignment_size"], + recipe_config["min_out_sqnr"], + recipe_config["min_input_grad_sqnr"], + recipe_config["min_param_grad_sqnr"], + ) assert torch.cuda.is_available() + if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != ( + 9, + 0, + ): + pytest.skip( + f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" + ) - # setup distributed for fsdp - setup_distributed() + elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != ( + 10, + 0, + ): + pytest.skip( + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" + ) + + # set token group alignment size needed for GEMM (contraction dim stride must be 16 byte aligned) + # or quantization ops (mxfp8 scaling groups are size 1x32) + set_token_group_alignment_size_m(group_alignment_size) # define model args - target_fqns = ["experts"] - model_args = TransformerModelArgs( - moe_enabled=True, + model_args = MoEArgs( num_experts=8, - dim=256, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 5120, 4 * 5120 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -66,7 +169,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig() + config = MoETrainingConfig(recipe) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # validate that only the experts were converted @@ -74,13 +177,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: model, target_fqns=target_fqns, ) + if compile: + # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it + model = torch.compile(model, fullgraph=False) + ref_model = torch.compile(ref_model, fullgraph=False) # FSDP2 fully_shard(model) fully_shard(ref_model) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) @@ -92,7 +199,9 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate output out_sqnr = compute_error(out, ref_out) - assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) # compute loss labels = torch.ones_like(ref_out) @@ -105,52 +214,13 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - assert input_grad_sqnr.item() >= 30.0, ( - f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}." + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= 25.0, ( - f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." ) - - dist.destroy_process_group() - - -def _validate_model_conversion( - root_module: nn.Module, - target_fqns: list[str], -): - def _recursive_validate( - module: nn.Module, - cur_fqn: str, - ): - is_allowed_module = cur_fqn in target_fqns - - # check current module params - for param_name, param in module.named_parameters(recurse=False): - is_converted_type = isinstance(param, ScaledGroupedMMTensor) - if is_converted_type: - assert is_allowed_module, ( - f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." - ) - if not is_allowed_module: - assert not is_converted_type, ( - f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." - ) - - # recursively check child modules - for child_name, child_module in module.named_children(): - child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name - _recursive_validate(child_module, child_fqn) - - _recursive_validate(root_module, "") - - -def setup_distributed(): - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) diff --git a/test/prototype/moe_training/test_fsdp.sh b/test/prototype/moe_training/test_fsdp.sh new file mode 100755 index 0000000000..5f858061f4 --- /dev/null +++ b/test/prototype/moe_training/test_fsdp.sh @@ -0,0 +1 @@ +torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py -s diff --git a/test/prototype/moe_training/test_fsdp_tp.py b/test/prototype/moe_training/test_fsdp_tp.py new file mode 100644 index 0000000000..2589ec1a93 --- /dev/null +++ b/test/prototype/moe_training/test_fsdp_tp.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +###################################################################### +# +# To run these unit tests, use the following command: +# +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp_tp.py +# +####################################################################### + +import copy +import os + +import pytest +import torch + +if torch.version.hip is not None: + pytest.skip( + "ROCm support for MoE quantization is under development", + allow_module_level=True, + ) + +from torch import distributed as dist +from torch import nn +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor import DTensor +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor import Partial, Replicate, Shard +from torch.nn import functional as F + +try: + from torch.distributed.tensor.parallel import ( + PrepareModuleInputOutput, + parallelize_module, + ) +except ImportError: + pytest.skip( + "torch version is too old, these tests require nightly build. Skipping MoE training tests.", + allow_module_level=True, + ) + + +# this feature requires CUDA and SM89+ +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): + pytest.skip( + "CUDA not available or compute capability < 8.9", allow_module_level=True + ) + +from torchao.float8.float8_utils import compute_error +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) +from torchao.quantization.quant_api import quantize_ + +from .testing_utils import _validate_model_conversion + +# this test requires torchtitan +try: + from torchtitan.distributed.expert_parallel import ( + ExpertParallel, + ExpertTensorParallel, + NoParallel, + TensorParallel, + set_token_group_alignment_size_m, + ) + from torchtitan.models.moe import MoE, MoEArgs +except ImportError: + pytest.skip( + "torchtitan not installed, skipping MoE tests.", allow_module_level=True + ) + + +@pytest.fixture(scope="module") +def device_mesh_2d() -> DeviceMesh: + """ + Fixture for setting up and tearing down the distributed environment + for the entire test module. + """ + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + if not dist.is_initialized(): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + device_mesh = init_device_mesh( + "cuda", + (world_size // 2, 2), + mesh_dim_names=("dp", "tp"), + ) + + torch.manual_seed(1) + torch.cuda.set_device(rank) + + yield device_mesh + + dist.destroy_process_group() + + +@pytest.mark.parametrize( + "target_fqns", + [ + ["experts"], + ["experts,shared_experts"], + ], +) +@pytest.mark.parametrize("compile", [False, True]) +@pytest.mark.parametrize( + "recipe_config", + [ + { + "recipe": MoEScalingType.FP8_ROWWISE, + "group_alignment_size": 16, + "min_out_sqnr": 29.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 22.0, + }, + { + "recipe": MoEScalingType.MXFP8, + "group_alignment_size": 32, + "min_out_sqnr": 28.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 21.0, + }, + ], +) +def test_moe_training_fsdp_tp( + target_fqns: list[str], + compile: bool, + recipe_config: dict, + device_mesh_2d: DeviceMesh, +): + ( + recipe, + group_alignment_size, + min_out_sqnr, + min_input_grad_sqnr, + min_param_grad_sqnr, + ) = ( + recipe_config["recipe"], + recipe_config["group_alignment_size"], + recipe_config["min_out_sqnr"], + recipe_config["min_input_grad_sqnr"], + recipe_config["min_param_grad_sqnr"], + ) + assert torch.cuda.is_available() + if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != ( + 9, + 0, + ): + pytest.skip( + f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" + ) + + elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != ( + 10, + 0, + ): + pytest.skip( + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" + ) + + # set token group alignment size needed for GEMM (contraction dim stride must be 16 byte aligned) + # or quantization ops (mxfp8 scaling groups are size 1x32) + set_token_group_alignment_size_m(group_alignment_size) + + # define model args + model_args = MoEArgs( + num_experts=8, + ) + dim, hidden_dim = 5120, 4 * 5120 + init_std = 0.02 + device = torch.device("cuda") + + # reference bf16 MoE + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() + torch.manual_seed(1) + ref_model.init_weights(init_std, device) + + # target MoE for testing conversion + model = copy.deepcopy(ref_model) + + # assert starting params are identical for both models + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + assert torch.equal(param1, param2) + + # convert MoE to float8 training + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in target_fqns: + if target_fqn in cur_fqn: + return True + return False + + # quantize test model + config = MoETrainingConfig(scaling_type=recipe) + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + + # validate that only the experts were converted + _validate_model_conversion( + model, + target_fqns=target_fqns, + ) + if compile: + # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it + model = torch.compile(model, fullgraph=False) + ref_model = torch.compile(ref_model, fullgraph=False) + + # apply TP + apply_moe_ep_tp(model, tp_mesh=device_mesh_2d["tp"], ep_mesh=None, ep_tp_mesh=None) + apply_moe_ep_tp( + ref_model, tp_mesh=device_mesh_2d["tp"], ep_mesh=None, ep_tp_mesh=None + ) + + # apply FSDP2 + fsdp_config = {"mesh": device_mesh_2d["dp"]} + fully_shard(model, **fsdp_config) + fully_shard(ref_model, **fsdp_config) + + # Rough validation that parallelization was applied properly. + assert isinstance(model.experts.w1.data, DTensor), ( + "test model experts.w1 is not a DTensor" + ) + assert isinstance(model.experts.w2.data, DTensor), ( + "test model experts.w2 is not a DTensor" + ) + assert isinstance(model.experts.w3.data, DTensor), ( + "test model experts.w3 is not a DTensor" + ) + assert isinstance(ref_model.experts.w1.data, DTensor), ( + "ref model experts.w1 is not a DTensor" + ) + assert isinstance(ref_model.experts.w2.data, DTensor), ( + "ref model experts.w2 is not a DTensor" + ) + assert isinstance(ref_model.experts.w3.data, DTensor), ( + "ref model experts.w3 is not a DTensor" + ) + + # inputs + batch, seq = 8, 2048 + ref_x = torch.randn( + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device + ) + x = ref_x.detach().clone().requires_grad_(True) + + # forward pass + ref_out = ref_model(ref_x) + out = model(x) + + # validate output + out_sqnr = compute_error(out, ref_out) + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) + + # compute loss + labels = torch.ones_like(ref_out) + ref_loss = F.mse_loss(ref_out, labels) + out_loss = F.mse_loss(out, labels) + + # backward pass + ref_loss.backward() + out_loss.backward() + + # validate input gradient + input_grad_sqnr = compute_error(x.grad, ref_x.grad) + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." + ) + + # validate param gradients + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + param_grad_sqnr = compute_error(param1.grad, param2.grad) + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." + ) + + +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, +): + # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/ + # that supports single MoE layer independent of a transformer. + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + # input Replicate, output Partial + "moe.shared_expert": TensorParallel(), + } + parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + # if ep_mesh is not None: + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + + parallelize_module( + module=model.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, + ) diff --git a/test/prototype/moe_training/test_fsdp_tp.sh b/test/prototype/moe_training/test_fsdp_tp.sh new file mode 100755 index 0000000000..4c00dcd853 --- /dev/null +++ b/test/prototype/moe_training/test_fsdp_tp.sh @@ -0,0 +1 @@ +torchrun --nproc_per_node=4 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp.py -s diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index ed68e8fa23..495973bf7c 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -7,53 +7,107 @@ import pytest import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - # We need to skip before doing any imports which would use triton, since -# triton won't be available on CPU builds and torch < 2.5 -if not ( - TORCH_VERSION_AT_LEAST_2_5 - and torch.cuda.is_available() - and torch.cuda.get_device_capability()[0] >= 9 -): +# triton won't be available on CPU builds +if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9): pytest.skip("Unsupported PyTorch version", allow_module_level=True) - +from torchao.prototype.moe_training.kernels.float8_rowwise import ( + triton_fp8_rowwise_3d_transpose_rhs, + triton_fp8_rowwise_3d_transpose_rhs_fused_reduction, +) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_per_group_rowwise_scales, +) +from torchao.prototype.moe_training.kernels.mxfp8 import ( + compute_blocked_scale_offsets_for_K_groups, + compute_blocked_scale_offsets_for_M_groups, + torch_to_blocked_2d_K_groups, + torch_to_blocked_2d_M_groups, + torch_to_blocked_per_group_3d, + triton_mx_block_rearrange_2d_K_groups, + triton_mx_block_rearrange_2d_M_groups, + triton_mx_block_rearrange_per_group_3d, ) from torchao.prototype.moe_training.utils import ( _is_column_major, - _to_2d_jagged_float8_tensor_colwise, - _to_2d_jagged_float8_tensor_rowwise, + generate_jagged_offs, + torch_to_3d_rowwise_float8_transpose_rhs, + torch_to_float8_per_group_colwise, + torch_to_float8_per_group_rowwise, ) +from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx from torchao.testing.utils import skip_if_rocm +from torchao.utils import ( + is_sm_at_least_100, +) @skip_if_rocm("ROCm enablement in progress") @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool): - # tests case where rowwise scales are computed for multiple distinct subtensors, + # Tests case where rowwise scales are computed for multiple distinct subtensors, # with end boundary of each group is determine by their end column indexes (offsets). device = "cuda" m, k, n_groups = 256, 256, 4 - x = torch.randn(m, k * n_groups, device=device) - colwise_offs = torch.arange(k, k * n_groups + 1, k, device=device) + x = torch.randn(k, m * n_groups, device=device) + colwise_offs = torch.arange(m, m * n_groups + 1, m, device=device) - # compute reference with torch impl - ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_rowwise( + # Torch reference impl + ref_fp8_data, ref_scales = torch_to_float8_per_group_rowwise( x, colwise_offs, target_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=round_scales_to_power_of_2, ) - kernel_fp8_data, kernel_scales = triton_fp8_row_major_jagged_rowwise_scales( + + # Triton kernel + kernel_fp8_data, kernel_scales = triton_fp8_per_group_rowwise_scales( x, colwise_offs, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=round_scales_to_power_of_2, ) + + assert torch.eq(ref_fp8_data, kernel_fp8_data).all(), "fp8 data not equal" + assert torch.eq(ref_scales, kernel_scales).all(), "scales not equal" + assert not _is_column_major(kernel_fp8_data), "fp8 data is not row major" + + +@skip_if_rocm("ROCm enablement in progress") +@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) +def test_row_major_with_jagged_rowwise_scales_transpose_method( + round_scales_to_power_of_2: bool, +): + # tests case where rowwise scales are computed for multiple distinct subtensors, + # with end boundary of each group is determine by their end column indexes (offsets). + device = "cuda" + m, k, n_groups = 256, 256, 4 + grad_out = torch.randn(m * n_groups, k, device=device) + colwise_offs = torch.arange(m, m * n_groups + 1, m, device=device) + grad_out_t = grad_out.t() + + # compute reference with torch impl + ref_fp8_data, ref_scales = torch_to_float8_per_group_rowwise( + grad_out_t, + colwise_offs, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + + # Transpose method requires grad_out to be column major, then we compute per group + # colwise scales writing to column major, then transpose outputs back to the desired + # shape and row major format. + kernel_fp8_data, kernel_scales = triton_fp8_per_group_colwise_scales( + grad_out.t().contiguous().t(), + colwise_offs, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + kernel_fp8_data = kernel_fp8_data.t() # (mg, n) -> (n, mg) + kernel_scales = kernel_scales.t() # (1, n * n_groups) -> (n * n_groups, 1) + assert torch.eq(ref_fp8_data, kernel_fp8_data).all(), "fp8 data not equal" assert torch.eq(ref_scales, kernel_scales).all(), "scales not equal" assert not _is_column_major(kernel_fp8_data), "fp8 data is not row major" @@ -70,13 +124,13 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo rowwise_offs = torch.arange(m, m * n_groups + 1, m, device=device) # compute reference with torch impl - ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_colwise( + ref_fp8_data, ref_scales = torch_to_float8_per_group_colwise( x, rowwise_offs, target_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=round_scales_to_power_of_2, ) - kernel_fp8_data, kernel_scales = triton_fp8_col_major_jagged_colwise_scales( + kernel_fp8_data, kernel_scales = triton_fp8_per_group_colwise_scales( x, rowwise_offs, output_dtype=torch.float8_e4m3fn, @@ -85,3 +139,231 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo assert torch.eq(ref_fp8_data, kernel_fp8_data).all(), "fp8 data not equal" assert torch.eq(ref_scales, kernel_scales).all(), "scales not equal" assert _is_column_major(kernel_fp8_data), "fp8 data is not column major" + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) +def test_fp8_rowwise_3d_transpose_rhs_atomic(round_scales_to_power_of_2: bool): + device = "cuda" + experts, n, k = 8, 4 * 5120, 5120 + + # Example expert weights as it comes into forward transposed + torch.manual_seed(0) + x = torch.randn((experts, n, k), dtype=torch.bfloat16, device=device).transpose( + -2, -1 + ) + + # Compute reference with torch impl + ref_fp8, ref_scales = torch_to_3d_rowwise_float8_transpose_rhs( + x, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + # Torch impl keeps empty scaled dim, so we squeeze it out to be consistent with triton impl + ref_scales = ref_scales.squeeze(1) + + triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs( + x, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + assert ref_scales.shape == triton_scales.shape, "scale shapes not equal" + assert ref_scales.stride() == triton_scales.stride(), "scale strides not equal" + assert torch.allclose(ref_scales, triton_scales, rtol=0, atol=0), "scales not equal" + + assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal" + assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal" + assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal" + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) +def test_fp8_rowwise_3d_transpose_rhs_reduction(round_scales_to_power_of_2: bool): + device = "cuda" + experts, n, k = 8, 4 * 5120, 5120 + + # Example expert weights as it comes into forward transposed + torch.manual_seed(0) + x = torch.randn((experts, n, k), dtype=torch.bfloat16, device=device).transpose( + -2, -1 + ) + + # Compute reference with torch impl + ref_fp8, ref_scales = torch_to_3d_rowwise_float8_transpose_rhs( + x, + target_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + # Torch impl keeps empty scaled dim, so we squeeze it out to be consistent with triton impl + ref_scales = ref_scales.squeeze(1) + + triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs_fused_reduction( + x, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) + assert ref_scales.shape == triton_scales.shape, "scale shapes not equal" + assert ref_scales.stride() == triton_scales.stride(), "scale strides not equal" + assert torch.allclose(ref_scales, triton_scales, rtol=0, atol=0), "scales not equal" + + assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal" + assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal" + assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal" + + +@skip_if_rocm("ROCm enablement in progress") +@pytest.mark.parametrize( + "m,k,n_groups", [(256, 256, 4), (16640, 5120, 16), (16640, 8192, 16)] +) +def test_triton_mx_block_rearrange_2d_M_groups( + m: int, + k: int, + n_groups: int, +): + device = "cuda" + block_size = 32 + input_data = torch.randn(m, k, device=device) + e8m0_scales, _ = to_mx( + input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + input_group_offsets = generate_jagged_offs( + n_groups, m, multiple_of=block_size, device=device + ) + + # torch reference + ref_out_scales, _ = torch_to_blocked_2d_M_groups( + e8m0_scales, input_group_offsets, k, block_size=block_size + ) + + # triton kernel + _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups( + input_group_offsets + ) + triton_out_scales = triton_mx_block_rearrange_2d_M_groups( + e8m0_scales, + input_group_offsets, + output_group_offsets, + ) + assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), ( + "blocked scales not equal" + ) + + +@skip_if_rocm("ROCm enablement in progress") +@pytest.mark.parametrize("e,n,k", [(1, 8192, 5120), (2, 8192, 5120), (8, 5120, 8192)]) +def test_mxfp8_per_group_blocked_scales_3d( + e: int, + n: int, + k: int, +): + device = "cuda" + block_size = 32 + weights = torch.randn(e, n, k // block_size, device=device) + weight_scales, _ = to_mx( + weights, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # torch reference + ref_out_scales = torch_to_blocked_per_group_3d(weight_scales) + + # triton kernel + triton_out_scales = triton_mx_block_rearrange_per_group_3d(weight_scales) + assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), ( + "blocked scales not equal" + ) + + +@pytest.mark.skip( + "Temporarily disable and use e2e training numerical tests instead. See: https://github.com/pytorch/ao/pull/2990#discussion_r2354167396" +) +@skip_if_rocm("ROCm enablement in progress") +@pytest.mark.parametrize("m", [256, 512, 1024, 5120]) +@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384]) +@pytest.mark.parametrize("n_groups", [1, 4, 8, 16]) +def test_triton_mx_block_rearrange_2d_K_groups( + m: int, + total_k: int, + n_groups: int, +): + device = "cuda" + block_size = 32 + input_data = torch.randn(m, total_k, device=device) + + e8m0_scales, _ = to_mx( + input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets + input_group_offsets = generate_jagged_offs( + n_groups, total_k, multiple_of=block_size, device=device + ) + scale_group_offsets = input_group_offsets // block_size + + # torch reference + ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_2d_K_groups( + e8m0_scales, + scale_group_offsets, + ) + + # triton kernel + _, output_group_offsets = compute_blocked_scale_offsets_for_K_groups( + scale_group_offsets + ) + assert torch.equal(output_group_offsets, ref_start_cols_after_padding), ( + "output scale group start offsets not equal" + ) + triton_out_scales = triton_mx_block_rearrange_2d_K_groups( + e8m0_scales, + scale_group_offsets, + output_group_offsets, + ) + assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal" + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +@pytest.mark.parametrize("E", (1, 2, 4, 8)) +@pytest.mark.parametrize("N", (32, 1536, 5120, 7168, 8192)) +@pytest.mark.parametrize("K", (32, 1536, 5120, 7168, 8192)) +@pytest.mark.parametrize("input_dtype", (torch.bfloat16,)) +@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,)) +def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode): + from torchao.prototype import mxfp8_cuda + + scaling_mode_str = ( + "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil" + ) + block_size = 32 + + # Use disinct incrementing values from 0 to E*M*K-1 to make debugging easier. + x = ( + torch.arange(0, E * N * K, dtype=input_dtype, device="cuda") + .reshape(E, N, K) + .contiguous() + ) + + # Reference implementation + s_d1_ref, y_d1_ref = to_mx( + # Transpose so N is final dim, since to_mx scales along that dim + x.transpose(-2, -1).contiguous(), + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + # Transpose tensors and scales back so we have effectively + # quantized input shape (E, N, K) along N + y_d1_ref = y_d1_ref.transpose(-2, -1) + s_d1_ref = s_d1_ref.transpose(-2, -1) + + # CUDA implementation (should work with any stride pattern) + y_d1, s_d1 = mxfp8_cuda.quantize_3d( + x, scale_dim_n=block_size, scaling_mode=scaling_mode_str + ) + # Check scales + torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0) + + # Check quantized values + torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0) + assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match" diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py index 844220c49c..1fd39451ce 100644 --- a/test/prototype/moe_training/test_scaled_grouped_mm.py +++ b/test/prototype/moe_training/test_scaled_grouped_mm.py @@ -6,33 +6,43 @@ import pytest import torch +from torch.nn import functional as F -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import torch_version_at_least # We need to skip before doing any imports which would use triton, since # triton won't be available on CPU builds and torch < 2.5 if not ( - TORCH_VERSION_AT_LEAST_2_5 + torch_version_at_least("2.7.0") and torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9 ): pytest.skip("Unsupported PyTorch version", allow_module_level=True) +pytest.importorskip("triton", reason="Triton required to run this test") from torchao.float8.config import ( Float8LinearConfig, Float8LinearRecipeName, ) from torchao.float8.float8_linear import matmul_with_hp_or_float8_args -from torchao.float8.float8_tensor import LinearMMConfig -from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.float8.float8_training_tensor import LinearMMConfig +from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated from torchao.prototype.moe_training.scaled_grouped_mm import ( + _emulated_mxfp8_scaled_grouped_mm_2d_2d, + _emulated_mxfp8_scaled_grouped_mm_2d_3d, _scaled_grouped_mm, ) +from torchao.prototype.moe_training.utils import ( + _to_mxfp8_per_group_colwise, + _to_mxfp8_per_group_rowwise, + generate_jagged_offs, +) +from torchao.prototype.mx_formats.mx_tensor import to_mx from torchao.testing.utils import skip_if_rocm -@skip_if_rocm("ROCm enablement in progress") +@skip_if_rocm("ROCm not supported") def test_valid_scaled_grouped_mm_2d_3d(): out_dtype = torch.bfloat16 device = "cuda" @@ -86,6 +96,7 @@ def test_valid_scaled_grouped_mm_2d_3d(): assert torch.equal(b_t.grad, ref_b_t.grad) +@skip_if_rocm("ROCm not supported") @pytest.mark.parametrize("m", [16, 17]) @pytest.mark.parametrize("k", [16, 18]) @pytest.mark.parametrize("n", [32, 33]) @@ -212,3 +223,139 @@ def compute_reference_forward( # Concatenate the outputs and verify the full result is correct. output_ref = torch.cat(outputs, dim=0) return output_ref + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)]) +@pytest.mark.parametrize("num_experts", (1, 8, 16)) +def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts): + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda") + offs = generate_jagged_offs(num_experts, M) + x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone() + + # Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm + block_size = 32 + x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size) + + # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose. + w_scale, w_fp8 = to_mx( + w, + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + ref_out = torch._grouped_mm( + x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16 + ) + out = _emulated_mxfp8_scaled_grouped_mm_2d_3d( + x_fp8, x_scale, w_fp8, w_scale, offs=offs, out_dtype=torch.bfloat16 + ) + + sqnr = compute_error(ref_out, out) + min_sqnr = 27.0 + assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}" + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.parametrize("M", (1024, 4096)) +@pytest.mark.parametrize("N", (1024, 4096)) +@pytest.mark.parametrize("num_experts", (8, 16)) +def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts): + # Simluate 2d-2d grouped gemm grad_weight = grad_output_t @ x + block_size = 32 + grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + grad_out_t = grad_out.t().contiguous() + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + offs = generate_jagged_offs(num_experts, M, multiple_of=block_size) + x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone() + + # bf16 reference grouped gemm + ref_out = torch._grouped_mm( + grad_out_t_ref, + x_ref, + offs=offs_ref, + out_dtype=torch.bfloat16, + ) + + # mxpf8 grouped gemm + x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size) + grad_out_t_mx, grad_out_t_scale = _to_mxfp8_per_group_rowwise( + grad_out_t, + offs=offs, + block_size=block_size, + ) + x_mx, x_scale = _to_mxfp8_per_group_colwise( + x, + offs=offs, + block_size=block_size, + ) + out = _emulated_mxfp8_scaled_grouped_mm_2d_2d( + grad_out_t_mx, + grad_out_t_scale, + x_mx, + x_scale, + offs=offs, + out_dtype=torch.bfloat16, + block_size=block_size, + ) + + sqnr = compute_error(ref_out, out) + min_sqnr = 27.0 + assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}" + + +@skip_if_rocm("ROCm not supported") +@pytest.mark.parametrize( + "M,K,N", [(1024, 5120, 8192), (2048, 5120, 8192), (16640, 5120, 8192)] +) +@pytest.mark.parametrize("num_experts", (2, 4, 8, 16)) +def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts): + from torchao.prototype.moe_training.scaled_grouped_mm import ( + _MXFP8GroupedMM, + ) + + block_size = 32 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True) + w = torch.randn( + num_experts, + N, + K, + dtype=torch.bfloat16, + device="cuda", + ) + w_t = w.transpose(-2, -1).requires_grad_(True) + offs = generate_jagged_offs(num_experts, M, multiple_of=block_size) + x_ref, w_t_ref, offs_ref = ( + x.clone().detach().requires_grad_(True), + w_t.clone().detach().requires_grad_(True), + offs.clone(), + ) + + # Forward + out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16) + ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16) + sqnr = compute_error(ref_out, out) + min_sqnr = 27.0 + assert sqnr >= min_sqnr, f"Output sqnr {sqnr} is too low, must be >= {min_sqnr}" + + # Backward + labels = torch.ones_like(ref_out) + ref_loss = F.mse_loss(ref_out, labels) + out_loss = F.mse_loss(out, labels) + ref_loss.backward() + out_loss.backward() + + # Check input grads + min_input_grad_sqnr = 26.0 + sqnr = compute_error(x_ref.grad, x.grad) + assert sqnr >= min_input_grad_sqnr, ( + f"Input grad sqnr {sqnr} is too low, must be >= {min_input_grad_sqnr}" + ) + + # Check weight grads + min_weight_grad_sqnr = 24.0 + sqnr = compute_error(w_t_ref.grad, w_t.grad) + assert sqnr >= min_weight_grad_sqnr, ( + f"Weight grad sqnr {sqnr} is too low, must be >= {min_weight_grad_sqnr}" + ) diff --git a/test/prototype/moe_training/test_tp.py b/test/prototype/moe_training/test_tp.py new file mode 100644 index 0000000000..705f5a40f9 --- /dev/null +++ b/test/prototype/moe_training/test_tp.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +###################################################################### +# +# To run these unit tests, use the following command: +# +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py +# +####################################################################### + +import copy +import os + +import pytest +import torch + +if torch.version.hip is not None: + pytest.skip( + "ROCm support for MoE quantization is under development", + allow_module_level=True, + ) + +from torch import distributed as dist +from torch import nn +from torch.distributed._tensor import DTensor +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor import Partial, Replicate, Shard +from torch.nn import functional as F + +try: + from torch.distributed.tensor.parallel import ( + PrepareModuleInputOutput, + parallelize_module, + ) +except ImportError: + pytest.skip( + "torch version is too old, these tests require nightly build. Skipping MoE training tests.", + allow_module_level=True, + ) + +# this feature requires CUDA and SM89+ +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): + pytest.skip( + "CUDA not available or compute capability < 8.9", allow_module_level=True + ) + +from torchao.float8.float8_utils import compute_error +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) +from torchao.quantization.quant_api import quantize_ + +from .testing_utils import _validate_model_conversion + +# this test requires torchtitan +try: + from torchtitan.distributed.expert_parallel import ( + ExpertParallel, + ExpertTensorParallel, + NoParallel, + TensorParallel, + set_token_group_alignment_size_m, + ) + from torchtitan.models.moe import MoE, MoEArgs +except ImportError: + pytest.skip( + "torchtitan not installed, skipping MoE tests.", allow_module_level=True + ) + + +@pytest.fixture(scope="module") +def device_mesh_1d() -> DeviceMesh: + """ + Fixture for setting up and tearing down the distributed environment + for the entire test module. + """ + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + if not dist.is_initialized(): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + device_mesh = init_device_mesh("cuda", (world_size,)) + torch.manual_seed(1) + torch.cuda.set_device(rank) + + yield device_mesh + + dist.destroy_process_group() + + +@pytest.mark.parametrize( + "target_fqns", + [ + ["experts"], + ["experts,shared_experts"], + ], +) +@pytest.mark.parametrize("compile", [False, True]) +@pytest.mark.parametrize( + "recipe_config", + [ + { + "recipe": MoEScalingType.FP8_ROWWISE, + "group_alignment_size": 16, + "min_out_sqnr": 29.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 23.0, + }, + { + "recipe": MoEScalingType.MXFP8, + "group_alignment_size": 32, + "min_out_sqnr": 28.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 21.0, + }, + ], +) +def test_moe_training_tp( + target_fqns: list[str], + compile: bool, + recipe_config: dict, + device_mesh_1d: DeviceMesh, +): + ( + recipe, + group_alignment_size, + min_out_sqnr, + min_input_grad_sqnr, + min_param_grad_sqnr, + ) = ( + recipe_config["recipe"], + recipe_config["group_alignment_size"], + recipe_config["min_out_sqnr"], + recipe_config["min_input_grad_sqnr"], + recipe_config["min_param_grad_sqnr"], + ) + assert torch.cuda.is_available() + if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != ( + 9, + 0, + ): + pytest.skip( + f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" + ) + + elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != ( + 10, + 0, + ): + pytest.skip( + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" + ) + + # set token group alignment size needed for GEMM (contraction dim stride must be 16 byte aligned) + # or quantization ops (mxfp8 scaling groups are size 1x32) + set_token_group_alignment_size_m(group_alignment_size) + + # define model args + model_args = MoEArgs( + num_experts=8, + ) + + # define model args + model_args = MoEArgs( + num_experts=8, + ) + dim, hidden_dim = 5120, 4 * 5120 + init_std = 0.02 + device = torch.device("cuda") + + # reference bf16 MoE + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() + torch.manual_seed(1) + ref_model.init_weights(init_std, device) + + # target MoE for testing conversion + model = copy.deepcopy(ref_model) + + # assert starting params are identical for both models + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + assert torch.equal(param1, param2) + + # convert MoE to float8 training + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in target_fqns: + if target_fqn in cur_fqn: + return True + return False + + # quantize test model + config = MoETrainingConfig(recipe) + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + + # validate that only the experts were converted + _validate_model_conversion( + model, + target_fqns=target_fqns, + ) + if compile: + # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it + model = torch.compile(model, fullgraph=False) + ref_model = torch.compile(ref_model, fullgraph=False) + + # apply TP + apply_moe_ep_tp(model, tp_mesh=device_mesh_1d, ep_mesh=None, ep_tp_mesh=None) + apply_moe_ep_tp(ref_model, tp_mesh=device_mesh_1d, ep_mesh=None, ep_tp_mesh=None) + + # Rough validation that parallelization was applied properly. + assert isinstance(model.experts.w1.data, DTensor), ( + "test model experts.w1 is not a DTensor" + ) + assert isinstance(model.experts.w2.data, DTensor), ( + "test model experts.w2 is not a DTensor" + ) + assert isinstance(model.experts.w3.data, DTensor), ( + "test model experts.w3 is not a DTensor" + ) + assert isinstance(ref_model.experts.w1.data, DTensor), ( + "ref model experts.w1 is not a DTensor" + ) + assert isinstance(ref_model.experts.w2.data, DTensor), ( + "ref model experts.w2 is not a DTensor" + ) + assert isinstance(ref_model.experts.w3.data, DTensor), ( + "ref model experts.w3 is not a DTensor" + ) + + # inputs + batch, seq = 8, 2048 + ref_x = torch.randn( + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device + ) + x = ref_x.detach().clone().requires_grad_(True) + + # forward pass + ref_out = ref_model(ref_x) + out = model(x) + + # validate output + out_sqnr = compute_error(out, ref_out) + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) + + # compute loss + labels = torch.ones_like(ref_out) + ref_loss = F.mse_loss(ref_out, labels) + out_loss = F.mse_loss(out, labels) + + # backward pass + ref_loss.backward() + out_loss.backward() + + # validate input gradient + input_grad_sqnr = compute_error(x.grad, ref_x.grad) + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." + ) + + # validate param gradients + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + param_grad_sqnr = compute_error(param1.grad, param2.grad) + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." + ) + + +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, +): + # Modified version of moe parallelization from https://github.com/pytorch/torchtitan/pull/1324/ + # that supports single MoE layer independent of a transformer. + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "router.gate": NoParallel(), + # input Replicate, output Partial + "shared_expert": TensorParallel(), + } + parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + # if ep_mesh is not None: + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + + parallelize_module( + module=model.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, + ) diff --git a/test/prototype/moe_training/test_tp.sh b/test/prototype/moe_training/test_tp.sh new file mode 100755 index 0000000000..2ab7636113 --- /dev/null +++ b/test/prototype/moe_training/test_tp.sh @@ -0,0 +1 @@ +torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_tp.py -s diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index 71320af83e..23cd4080ae 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -12,39 +12,95 @@ ) from torchao.float8.float8_utils import compute_error -from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig -from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) from torchao.quantization.quant_api import quantize_ +from .testing_utils import _validate_model_conversion + # this test requires torchtitan try: - from torchtitan.experiments.llama4.model.args import TransformerModelArgs - from torchtitan.experiments.llama4.model.moe import MoE + from torchtitan.distributed.expert_parallel import ( + set_token_group_alignment_size_m, + ) + from torchtitan.models.moe import MoE, MoEArgs except ImportError: - import warnings - - warnings.warn("torchtitan not installed, skipping MoE tests.") - pytest.skip(allow_module_level=True) + pytest.skip( + "torchtitan not installed, skipping MoE tests.", allow_module_level=True + ) @pytest.mark.parametrize( "target_fqns", + [["experts"]], +) +@pytest.mark.parametrize("compile", [False, True]) +@pytest.mark.parametrize( + "recipe_config", [ - ["experts"], - ["does.not.exist"], + { + "recipe": MoEScalingType.FP8_ROWWISE, + "group_alignment_size": 16, + "min_out_sqnr": 29.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 23.0, + }, + { + "recipe": MoEScalingType.MXFP8, + "group_alignment_size": 32, + "min_out_sqnr": 28.0, + "min_input_grad_sqnr": 29.0, + "min_param_grad_sqnr": 21.0, + }, ], ) -def test_moe_float8_training(target_fqns: list[str]): - model_args = TransformerModelArgs( - moe_enabled=True, +def test_moe_training(target_fqns: list[str], compile: bool, recipe_config: dict): + ( + recipe, + group_alignment_size, + min_out_sqnr, + min_input_grad_sqnr, + min_param_grad_sqnr, + ) = ( + recipe_config["recipe"], + recipe_config["group_alignment_size"], + recipe_config["min_out_sqnr"], + recipe_config["min_input_grad_sqnr"], + recipe_config["min_param_grad_sqnr"], + ) + assert torch.cuda.is_available() + if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != ( + 9, + 0, + ): + pytest.skip( + f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" + ) + + elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != ( + 10, + 0, + ): + pytest.skip( + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" + ) + + # Set token group alignment size. This is required so that + # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input` + # has the contraction dim be divisible by 16. 16 byte alignment is required + # for the slowest moving dim (stride 1). + set_token_group_alignment_size_m(group_alignment_size) + model_args = MoEArgs( num_experts=8, - dim=256, ) init_std = 0.02 device = torch.device("cuda") # reference bf16 MoE - ref_model = MoE(model_args).to(torch.bfloat16).cuda() + dim, hidden_dim = 5120, 8192 + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() torch.manual_seed(42) ref_model.init_weights(init_std, device) @@ -63,7 +119,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig() + config = MoETrainingConfig(scaling_type=recipe) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # validate that only the experts were converted @@ -71,9 +127,13 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: model, target_fqns=target_fqns, ) + if compile: + # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it + model = torch.compile(model, fullgraph=False) + ref_model = torch.compile(ref_model, fullgraph=False) # inputs - batch, seq, dim = 8, 2048, 256 + batch, seq = 8, 2048 ref_x = torch.randn( batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device ) @@ -85,7 +145,9 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate output out_sqnr = compute_error(out, ref_out) - assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) # compute loss labels = torch.ones_like(ref_out) @@ -98,43 +160,13 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - assert input_grad_sqnr.item() >= 30.0, ( - f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}." + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= 25.0, ( - f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." ) - - -def _validate_model_conversion( - root_module: nn.Module, - target_fqns: list[str], -): - def _recursive_validate( - module: nn.Module, - cur_fqn: str, - ): - is_allowed_module = cur_fqn in target_fqns - - # check current module params - for param_name, param in module.named_parameters(recurse=False): - is_converted_type = isinstance(param, ScaledGroupedMMTensor) - if is_converted_type: - assert is_allowed_module, ( - f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." - ) - if not is_allowed_module: - assert not is_converted_type, ( - f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." - ) - - # recursively check child modules - for child_name, child_module in module.named_children(): - child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name - _recursive_validate(child_module, child_fqn) - - _recursive_validate(root_module, "") diff --git a/test/prototype/moe_training/testing_utils.py b/test/prototype/moe_training/testing_utils.py new file mode 100644 index 0000000000..cf13b81ae3 --- /dev/null +++ b/test/prototype/moe_training/testing_utils.py @@ -0,0 +1,33 @@ +from torch import nn + +from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor + + +def _validate_model_conversion( + root_module: nn.Module, + target_fqns: list[str], +): + def _recursive_validate( + module: nn.Module, + cur_fqn: str, + ): + is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns]) + + # check current module params + for param_name, param in module.named_parameters(recurse=False): + is_converted_type = isinstance(param, ScaledGroupedMMTensor) + if is_converted_type: + assert is_allowed_module, ( + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." + ) + if not is_allowed_module: + assert not is_converted_type, ( + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." + ) + + # recursively check child modules + for child_name, child_module in module.named_children(): + child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name + _recursive_validate(child_module, child_fqn) + + _recursive_validate(root_module, "") diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py new file mode 100644 index 0000000000..988a879b5b --- /dev/null +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import pytest +import torch +import torch.nn as nn + +from torchao.prototype.mx_formats.config import ( + MXGemmKernelChoice, +) +from torchao.prototype.mx_formats.inference_workflow import ( + MXFPInferenceConfig, + NVFP4InferenceConfig, + NVFP4MMConfig, +) +from torchao.quantization import quantize_ +from torchao.quantization.utils import compute_error +from torchao.testing.utils import skip_if_rocm +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_100, + torch_version_at_least, +) + +torch.manual_seed(2) + +if not torch_version_at_least("2.8.0"): + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +# source: https://stackoverflow.com/a/22638709 +@pytest.fixture(autouse=True) +def run_around_tests(): + # 1. before test - set up (currently do nothing) + # 2. run test + yield + # 3. after test - teardown + torch._dynamo.reset() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" +) +@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("compile", [True, False]) +@torch.no_grad() +@skip_if_rocm( + "ROCm float4 gemm require gfx950" +) # TODO(future): deploy gfx950 in ROCM CI +@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required") +def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool): + """ + Smoke test for inference compile + """ + # TODO(future): figure out why these CUDA capability conditions are not properly + # applied when inside `pytest.mark.skipif` for this test + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + if not is_sm_at_least_89(): + pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + elif elem_dtype == torch.float4_e2m1fn_x2: + if not is_sm_at_least_100(): + pytest.skip("CUDA capability >= 10.0 required for float4 gemm") + + m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda") + m_mx = copy.deepcopy(m) + kernel_choice = ( + MXGemmKernelChoice.CUTLASS + if elem_dtype == torch.float4_e2m1fn_x2 + else MXGemmKernelChoice.CUBLAS + ) + config = MXFPInferenceConfig( + activation_dtype=elem_dtype, + weight_dtype=elem_dtype, + gemm_kernel_choice=kernel_choice, + ) + quantize_(m_mx, config=config) + if compile: + m_mx = torch.compile(m_mx, fullgraph=True) + + x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16) + y_ref = m(x) + y_mx = m_mx(x) + sqnr = compute_error(y_ref, y_mx) + SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0 + assert sqnr >= SQNR_THRESHOLD, ( + f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize( + "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] +) +@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("use_triton_kernel", [True, False]) +@pytest.mark.parametrize( + "shapes", + [ + (128, 64, 256), + (256, 128, 512), + (145, 64, 256), + (128, 96, 256), + (128, 160, 256), + (64, 64, 256), + (200, 192, 256), + ], + ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}", +) +@torch.no_grad() +@skip_if_rocm("ROCm float4 gemm require gfx950") +def test_inference_workflow_nvfp4( + bias: bool, + compile: bool, + mm_config: NVFP4MMConfig, + inpt_dtype: torch.dtype, + use_triton_kernel: bool, + shapes: tuple, +): + """ + Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16 + Tests both DYNAMIC and WEIGHT_ONLY mm_config modes + """ + # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs + if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): + pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") + + if bias and inpt_dtype == torch.float32: + pytest.xfail("Bias is not supported when module weight is in fp32") + + if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: + pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") + batch_size, in_features, out_features = shapes + + m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda") + m_mx = copy.deepcopy(m) + + config = NVFP4InferenceConfig( + mm_config=mm_config, use_triton_kernel=use_triton_kernel + ) + quantize_(m_mx, config=config) + + if compile: + m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager") + + x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype) + y_ref = m(x) + y_mx = m_mx(x) + sqnr = compute_error(y_ref, y_mx) + + if mm_config == NVFP4MMConfig.WEIGHT_ONLY: + SQNR_THRESHOLD = 18.0 + else: + SQNR_THRESHOLD = 15.0 + + assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}" + assert sqnr >= SQNR_THRESHOLD, ( + f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}" + ) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index d649b2e04a..024586419a 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -35,27 +35,35 @@ get_bits, pack_uint4, pack_uint6, - triton_f4_to_bf16, triton_f6_e2m3_to_bf16, triton_f6_e3m2_to_bf16, triton_to_mxfp8_dim1, triton_to_mxfp8_dim1_reference, unpack_uint4, ) -from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_100, + torch_version_at_least, ) torch.manual_seed(0) -if not TORCH_VERSION_AT_LEAST_2_8: +if not torch_version_at_least("2.8.0"): pytest.skip("Unsupported PyTorch version", allow_module_level=True) +# TODO: shared utils file for benchmarking and testing +def to_mx_dim1_reference(x_hp, block_size, scaling_mode): + x_hp = x_hp.t().contiguous() + scale_d1, data_d1 = to_mx( + x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode + ) + return data_d1.t(), scale_d1 + + @pytest.mark.skip( reason="TODO debug CI failure, low pri since this is not used in the MX code" # noqa: E501 ) @@ -318,37 +326,6 @@ def test_fp4_pack_unpack(): assert torch.all(orig_vals_dq == orig_vals) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -@pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0") -def test_fp4_triton_unscaled_cast(): - packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda") - f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals)) - f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float) - assert torch.all(torch.eq(f32_ref, f32_triton)) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -@pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0") -def test_fp4_triton_scaled_cast(): - size = (256,) - orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100 - mxtensor_ref = MXTensor.to_mx( - orig_vals, block_size=32, elem_dtype=torch.float4_e2m1fn_x2 - ) - mxtensor_triton = MXTensor.to_mx( - orig_vals, - block_size=32, - elem_dtype=torch.float4_e2m1fn_x2, - use_fp4_custom_triton_dequant_kernel=True, - ) - - f32_ref = mxtensor_ref.to_dtype(torch.float) - f32_triton = mxtensor_triton.to_dtype(torch.float) - assert torch.all(torch.eq(f32_ref, f32_triton)) - - @pytest.mark.parametrize("dtype_name", (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2)) def test_fp6_values(dtype_name): """ @@ -488,3 +465,99 @@ def test_rearrange(shape): eager = to_blocked(scales, False) triton = to_blocked(scales, True) torch.testing.assert_close(eager, triton, atol=0, rtol=0) + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +@pytest.mark.parametrize("M", (32, 64, 2048)) +@pytest.mark.parametrize("K", (32, 64, 2048)) +@pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16)) +@pytest.mark.parametrize( + "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL) +) +def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode): + from torchao.prototype import mxfp8_cuda + + scaling_mode_str = ( + "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil" + ) + block_size = 32 + + # Use disinct incrementing values from 0 to M*K-1 to make debugging easier. + x = ( + torch.arange(0, M * K, dtype=input_dtype, device="cuda") + .reshape(M, K) + .contiguous() + ) + + y_d1_ref, s_d1_ref = to_mx_dim1_reference( + x, + block_size=block_size, + scaling_mode=scaling_mode, + ) + + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, + rowwise=False, + colwise=True, + scaling_mode=scaling_mode_str, + scale_dim_x=1, + scale_dim_y=block_size, + ) + + # check scales + torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0) + + # check quantized values + torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0) + assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match" + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +def test_cuda_mx_dim0_not_supported(): + from torchao.prototype import mxfp8_cuda + + M, K = 64, 64 + block_size = 32 + x = ( + torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + .reshape(M, K) + .contiguous() + ) + with pytest.raises(RuntimeError): + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, + rowwise=True, + colwise=False, + scale_dim_x=block_size, + scale_dim_y=1, + ) + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +def test_cuda_mx_dim1_invalid_block_size(): + from torchao.prototype import mxfp8_cuda + + M, K = 64, 64 + x = ( + torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + .reshape(M, K) + .contiguous() + ) + invalid_block_size = 4 + with pytest.raises(RuntimeError): + _, y_d1, _, s_d1 = mxfp8_cuda.quantize( + x, + rowwise=False, + colwise=True, + scale_dim_x=1, + scale_dim_y=invalid_block_size, + ) diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py index 4aefb3874e..9dc850a872 100644 --- a/test/prototype/mx_formats/test_mx_dtensor.py +++ b/test/prototype/mx_formats/test_mx_dtensor.py @@ -15,9 +15,9 @@ import pytest import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import is_sm_at_least_100, torch_version_at_least -if not TORCH_VERSION_AT_LEAST_2_7: +if not torch_version_at_least("2.7.0"): pytest.skip("Unsupported PyTorch version", allow_module_level=True) from torch.distributed._tensor import DTensor, Shard, distribute_tensor @@ -25,6 +25,7 @@ from tqdm import tqdm from torchao.prototype.mx_formats import MXLinearConfig +from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.testing.training.dtensor_utils import ( _test_lowp_mlp_tensor_parallelism_base, @@ -68,9 +69,9 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4): ) -def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16): +def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128): config = MXLinearConfig.from_recipe_name("mxfp8_emulated") - config.block_size = 16 + config.block_size = 32 _test_lowp_mlp_tensor_parallelism_base( mesh, config, size, compile=False, allgather_in_lowp=False ) @@ -79,12 +80,38 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16): ) +def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128): + config = MXLinearConfig.from_recipe_name("mxfp8_emulated") + config.block_size = 32 + config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.TRITON + _test_lowp_mlp_tensor_parallelism_base( + mesh, config, size, compile=False, allgather_in_lowp=False + ) + # TODO(future PR): enable compile here, currently seeing + # https://www.internalfb.com/phabricator/paste/view/P1851219639 + # _test_lowp_mlp_tensor_parallelism_base( + # mesh, config, size, compile=True, allgather_in_lowp=False + # ) + + +def _test_mxfp8_mlp_tensor_parallelism_dim1_cuda(mesh: DeviceMesh, size=128): + config = MXLinearConfig.from_recipe_name("mxfp8_emulated") + config.block_size = 32 + config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA + _test_lowp_mlp_tensor_parallelism_base( + mesh, config, size, compile=False, allgather_in_lowp=False + ) + + if __name__ == "__main__": device_mesh = setup_distributed() tests = [ _test_dtensor_cast_to_mxfp8, _test_mxfp8_mlp_tensor_parallelism, + _test_mxfp8_mlp_tensor_parallelism_dim1_triton, ] + if is_sm_at_least_100(): + tests.append(_test_mxfp8_mlp_tensor_parallelism_dim1_cuda) for test in tqdm(tests, desc="Running tests"): try: diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 8a69737889..c858657af6 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -9,40 +9,31 @@ import pytest import torch import torch.nn as nn -import torch.nn.functional as F from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, - MXInferenceLinearConfig, + MXFP8Dim1CastKernelChoice, MXLinearConfig, MXLinearRecipeName, + ScaleCalculationMode, ) from torchao.prototype.mx_formats.constants import ( DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, - SUPPORTED_ELEM_DTYPES, ) from torchao.prototype.mx_formats.mx_linear import ( - MXInferenceLinear, MXLinear, ) -from torchao.prototype.mx_formats.mx_subclass import ( - MXFPInferenceConfig, - NVFP4InferenceConfig, - NVFP4MMConfig, -) from torchao.quantization import quantize_ from torchao.quantization.utils import compute_error -from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_100, + torch_version_at_least, ) torch.manual_seed(2) -if not TORCH_VERSION_AT_LEAST_2_8: +if not torch_version_at_least("2.8.0"): pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -66,7 +57,7 @@ def run_around_tests(): # only test one type of mixed-dtype overrides, to save testing time (torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2), ] - if TORCH_VERSION_AT_LEAST_2_8 + if torch_version_at_least("2.8.0") else [ # test each dtype (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn), @@ -80,16 +71,32 @@ def run_around_tests(): @pytest.mark.parametrize("elem_dtype", elem_dtypes) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)]) -@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True]) +@pytest.mark.parametrize( + "mxfp8_cast_kernel_choice", + [ + MXFP8Dim1CastKernelChoice.TORCH, + MXFP8Dim1CastKernelChoice.TRITON, + MXFP8Dim1CastKernelChoice.CUDA, + ], +) +@pytest.mark.parametrize( + "scale_calculation_mode", + [ + ScaleCalculationMode.FLOOR, + ScaleCalculationMode.CEIL, + ScaleCalculationMode.EVEN, + ScaleCalculationMode.RCEIL, + ], +) def test_linear_eager_vs_hp( - elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel + elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice, scale_calculation_mode ): """ Smoke test for training linear module with mx weight, compares the following: * baseline: float32 * experiment: emulated MX """ - if use_fp8_dim1_cast_triton_kernel: + if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: if elem_dtype != ( torch.float8_e4m3fn, torch.float8_e4m3fn, @@ -99,6 +106,18 @@ def test_linear_eager_vs_hp( elif not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON: + if scale_calculation_mode != ScaleCalculationMode.FLOOR: + pytest.skip("unsupported configuration") + elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA: + if scale_calculation_mode not in ( + ScaleCalculationMode.FLOOR, + ScaleCalculationMode.RCEIL, + ): + pytest.skip("unsupported configuration") + elif not is_sm_at_least_100(): + pytest.skip("CUDA capability >= 10.0 required for MX dim1 cast cuda kernel") + # elem_dtype is a tuple of (input, weight, gradient) dtypes. grad_shape = list(input_shape) grad_shape[-1] = 256 @@ -108,11 +127,12 @@ def test_linear_eager_vs_hp( ) m_mx = copy.deepcopy(m) config = MXLinearConfig( - block_size=4, + block_size=32, # Only 32 is supported for now elem_dtype=elem_dtype[0], elem_dtype_weight_override=elem_dtype[1], elem_dtype_grad_output_override=elem_dtype[2], - use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel, + mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice, + scale_calculation_mode=scale_calculation_mode, ) quantize_(m_mx, config) @@ -130,9 +150,9 @@ def test_linear_eager_vs_hp( y_ref.backward(g) y_mx.backward(g) - y_sqnr = compute_error(y_ref, y_mx) - w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad) - x_g_sqnr = compute_error(x_ref.grad, x.grad) + y_sqnr = compute_error(y_ref, y_mx).item() + w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad).item() + x_g_sqnr = compute_error(x_ref.grad, x.grad).item() if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn): assert y_sqnr >= 18.0 @@ -226,8 +246,28 @@ def test_activation_checkpointing(): @pytest.mark.parametrize("bias", [False, True]) # TODO(future PR): figure out why torch.compile does not match eager when # autocast is on -@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True]) -def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel): +@pytest.mark.parametrize( + "mxfp8_cast_kernel_choice", + [ + MXFP8Dim1CastKernelChoice.TORCH, + MXFP8Dim1CastKernelChoice.TRITON, + MXFP8Dim1CastKernelChoice.CUDA, + ], +) +@pytest.mark.parametrize( + "scale_calculation_mode", + [ + ScaleCalculationMode.FLOOR, + ScaleCalculationMode.CEIL, + # even + compile does not work yet: + # https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4 + # ScaleCalculationMode.EVEN, + ScaleCalculationMode.RCEIL, + ], +) +def test_linear_compile( + hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice, scale_calculation_mode +): """ Verify that compile does not change numerics of MX linear fw + bw """ @@ -236,7 +276,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke pytest.skip("CUDA capability >= 8.9 required for float8 in triton") if recipe_name in ["mxfp8_cublas", "mxfp4_cutlass"]: - if not TORCH_VERSION_AT_LEAST_2_8: + if not torch_version_at_least("2.8.0"): pytest.skip("torch.compile requires PyTorch 2.8+") if not is_sm_at_least_100(): pytest.skip("CUDA capability >= 10.0 required for MX gemms") @@ -245,7 +285,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke # TODO(future PR): fix this, things are clearly broken with bias=True pytest.skip("this test is broken for non-emulated recipes with bias=True") - if use_fp8_dim1_cast_triton_kernel: + if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"): pytest.skip("unsupported configuration") if not is_sm_at_least_89(): @@ -253,12 +293,33 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke if hp_dtype != torch.bfloat16: pytest.skip("unsupported configuration") + if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON: + if scale_calculation_mode != ScaleCalculationMode.FLOOR: + pytest.skip("unsupported configuration") + elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA: + if scale_calculation_mode not in ( + ScaleCalculationMode.FLOOR, + ScaleCalculationMode.RCEIL, + ): + pytest.skip("unsupported configuration") + if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas": # TODO(future PR): properly enable float32 + bfloat16 for every # recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even # if the underlying gemm kernel only supports bf16 output) pytest.skip("unsupported configuration") + if ( + hp_dtype == torch.float32 + and recipe_name == "mxfp8_emulated" + and mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TORCH + and not is_sm_at_least_100() + ): + # TODO(future): debug this + pytest.skip( + "there are currently accuracy issues with this configuration on H100 and below" + ) + M, K, N = 128, 256, 512 input_shape = (M, K) grad_shape = (M, N) @@ -266,7 +327,8 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype), ) config = MXLinearConfig.from_recipe_name(recipe_name) - config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel + config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice + config.scale_calculation_mode = scale_calculation_mode quantize_(m_mx, config=config) m_mx_c = copy.deepcopy(m_mx) @@ -299,65 +361,11 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke torch.testing.assert_close(x_g_ref, x_g, atol=0.02, rtol=0.02) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("input_shape", [(2, 4), (1, 2, 4), (1, 1, 2, 4)]) -def test_inference_linear(elem_dtype, bias, input_shape): - """ - Smoke test for inference linear module with mx weight - """ - m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16)) - m = m.cuda() - m_mx = copy.deepcopy(m) - config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype) - quantize_(m_mx, config=config) - - x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16) - y_ref = m(x) - y_mx = m_mx(x) - sqnr = compute_error(y_ref, y_mx) - if elem_dtype is torch.float8_e4m3fn: - assert sqnr >= 20.0 - else: - assert sqnr >= 11.0 - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" -) -@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) -def test_inference_compile_simple(elem_dtype): - """ - Smoke test for inference compile - """ - if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not is_sm_at_least_89(): - pytest.skip("CUDA capability >= 8.9 required for float8 in triton") - m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16)) - m = m.cuda() - m_mx = copy.deepcopy(m) - config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype) - quantize_(m_mx, config=config) - m_mx = torch.compile(m_mx, fullgraph="true") - - x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16) - y_ref = m(x) - y_mx = m_mx(x) - sqnr = compute_error(y_ref, y_mx) - if elem_dtype is torch.float8_e4m3fn: - assert sqnr >= 20.0 - else: - assert sqnr >= 11.5 - - def test_filter_fn(): m1 = nn.Sequential( nn.Linear(32, 32), nn.Linear(32, 32), ) - m2 = copy.deepcopy(m1) filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "1" # noqa: E731 config = MXLinearConfig(block_size=32) @@ -365,11 +373,6 @@ def test_filter_fn(): assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear - config2 = MXInferenceLinearConfig(block_size=32) - quantize_(m2, config=config2, filter_fn=filter_fn) # noqa: E501 - assert type(m2[0]) == MXInferenceLinear - assert type(m2[1]) == torch.nn.Linear - def test_training_print_str(): m = nn.Sequential(nn.Linear(32, 32)) @@ -378,202 +381,3 @@ def test_training_print_str(): s = str(m) assert "bl_sz=32" in s assert "kernel=emulated" in s - - -def test_inference_print_str(): - m = nn.Sequential(nn.Linear(32, 32)) - config = MXInferenceLinearConfig() - quantize_(m, config=config) - s = str(m) - assert "bl_sz=32" in s - assert "kernel=emulated" in s - - -test_dtypes = ( - [torch.float8_e4m3fn, torch.float4_e2m1fn_x2] - if TORCH_VERSION_AT_LEAST_2_8 - else [ - torch.float8_e4m3fn, - ] -) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" -) -@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("compile", [True, False]) -@torch.no_grad() -@skip_if_rocm( - "ROCm float4 gemm require gfx950" -) # TODO(future): deploy gfx950 in ROCM CI -@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required") -def test_inference_subclass(elem_dtype, bias: bool, compile: bool): - """ - Smoke test for inference compile - """ - # TODO(future): figure out why these CUDA capability conditions are not properly - # applied when inside `pytest.mark.skipif` for this test - if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not is_sm_at_least_89(): - pytest.skip("CUDA capability >= 8.9 required for float8 in triton") - elif elem_dtype == torch.float4_e2m1fn_x2: - if not is_sm_at_least_100(): - pytest.skip("CUDA capability >= 10.0 required for float4 gemm") - - m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda") - m_mx = copy.deepcopy(m) - kernel_choice = ( - MXGemmKernelChoice.CUTLASS - if elem_dtype == torch.float4_e2m1fn_x2 - else MXGemmKernelChoice.CUBLAS - ) - config = MXFPInferenceConfig( - activation_dtype=elem_dtype, - weight_dtype=elem_dtype, - gemm_kernel_choice=kernel_choice, - ) - quantize_(m_mx, config=config) - if compile: - m_mx = torch.compile(m_mx, fullgraph=True) - - x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16) - y_ref = m(x) - y_mx = m_mx(x) - sqnr = compute_error(y_ref, y_mx) - SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0 - assert sqnr >= SQNR_THRESHOLD, ( - f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}" - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" -) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("compile", [True, False]) -@pytest.mark.parametrize( - "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] -) -@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32]) -@torch.no_grad() -@skip_if_rocm("ROCm float4 gemm require gfx950") -def test_inference_subclass_nvfp4( - bias: bool, compile: bool, mm_config: NVFP4MMConfig, inpt_dtype: torch.dtype -): - """ - Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16 - Tests both DYNAMIC and WEIGHT_ONLY mm_config modes - """ - # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs - if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): - pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") - - if bias and inpt_dtype == torch.float32: - pytest.xfail("Bias is not supported when module weight is in fp32") - - if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: - pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") - m = nn.Linear(64, 256, bias=bias, dtype=inpt_dtype, device="cuda") - m_mx = copy.deepcopy(m) - - config = NVFP4InferenceConfig(mm_config=mm_config) - quantize_(m_mx, config=config) - - if compile: - m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager") - - x = torch.randn(128, 64, device="cuda", dtype=inpt_dtype) - y_ref = m(x) - y_mx = m_mx(x) - sqnr = compute_error(y_ref, y_mx) - - if mm_config == NVFP4MMConfig.WEIGHT_ONLY: - SQNR_THRESHOLD = 18.0 - else: - SQNR_THRESHOLD = 15.0 - - assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}" - assert sqnr >= SQNR_THRESHOLD, ( - f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}" - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" -) -@pytest.mark.parametrize("use_gelu", [True, False]) -@pytest.mark.parametrize( - "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] -) -@pytest.mark.parametrize("compile", [False]) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32]) -@torch.no_grad() -@skip_if_rocm("ROCm float4 gemm require gfx950") -def test_nvfp4_matmul_with_amax( - use_gelu: bool, - mm_config: NVFP4MMConfig, - compile: bool, - bias: bool, - inpt_dtype: torch.dtype, -): - from torchao.prototype.mx_formats.nvfp4_tensor import ( - NVFP4Tensor, - per_tensor_amax_to_scale, - ) - - # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs - if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): - pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") - - if bias and inpt_dtype == torch.float32: - pytest.xfail("Bias is not supported when module weight is in fp32") - - if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: - pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") - - m, k, n = 64, 256, 128 - - # Create activation tensor - if use_gelu: - x = torch.randn(m, k, dtype=inpt_dtype, device="cuda") - A = torch.nn.functional.gelu(x) - else: - A = torch.randn(m, k, dtype=inpt_dtype, device="cuda") - - B = torch.randn(n, k, dtype=inpt_dtype, device="cuda") - bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None - - # Compute reference - C_ref = F.linear(A, B, bias_tensor) - - a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A))) - b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B))) - A_nvfp4 = NVFP4Tensor.to_nvfp4( - A, - per_tensor_scale=a_scale, - mm_config=mm_config, - ) - B_nvfp4 = NVFP4Tensor.to_nvfp4( - B, - per_tensor_scale=b_scale, - mm_config=mm_config, - ) - - func = torch.compile(F.linear, fullgraph=True) if compile else F.linear - - C_nvfp4 = func(A_nvfp4, B_nvfp4, bias_tensor) - assert C_nvfp4.dtype == inpt_dtype, ( - f"Got {C_nvfp4.dtype} for inpt_dtype={inpt_dtype}" - ) - - sqnr = compute_error(C_ref, C_nvfp4) - SQNR_THRESHOLD = 16.0 - assert sqnr >= SQNR_THRESHOLD, ( - f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}" - ) diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py index 46380cfb55..7cc876de6b 100644 --- a/test/prototype/mx_formats/test_mx_mm.py +++ b/test/prototype/mx_formats/test_mx_mm.py @@ -13,11 +13,11 @@ from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100, + torch_version_at_least, ) -if not TORCH_VERSION_AT_LEAST_2_8: +if not torch_version_at_least("2.8.0"): pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -38,8 +38,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float: a_mx = MXTensor.to_mx(a, fmt, 32) b_mx = MXTensor.to_mx(b, fmt, 32) - a_data = a_mx._data - b_data = b_mx._data + a_data = a_mx.qdata + b_data = b_mx.qdata assert b_data.is_contiguous() b_data = b_data.transpose(-1, -2) @@ -79,7 +79,7 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float: ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}", ) @pytest.mark.parametrize( - "format", ["fp8", "fp4"] if TORCH_VERSION_AT_LEAST_2_8 else ["fp8"] + "format", ["fp8", "fp4"] if torch_version_at_least("2.8.0") else ["fp8"] ) def test_matrix_multiplication(size, format): M, K, N = size diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 7294590b57..38eefbff07 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -25,13 +25,14 @@ ) from torchao.quantization.utils import compute_error from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, + is_sm_at_least_100, + torch_version_at_least, ) torch.manual_seed(2) -if not TORCH_VERSION_AT_LEAST_2_8: +if not torch_version_at_least("2.8.0"): pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -69,6 +70,14 @@ def assert_sqnr_gt_threshold(orig, new, threshold): else: assert_sqnr_gt_threshold(data_hp, data_mx_dq, 13.0) + # verify that if data.shape is (M, K) then scale.shape is (M, K // block_size) + prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1] + if elem_dtype is torch.float4_e2m1fn_x2: + assert data_mx.qdata.shape == (*prev_dims, K // 2) + else: + assert data_mx.qdata.shape == (*prev_dims, K) + assert data_mx._scale_e8m0.shape == (*prev_dims, K // block_size) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @@ -139,8 +148,8 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - assert torch.isnan(data_mx._data[0]) - assert torch.all(data_mx._data[1:] == 0) + assert torch.isnan(data_mx.qdata[0]) + assert torch.all(data_mx.qdata[1:] == 0) # fp32 denorm # fmt: off data_hp = torch.tensor( @@ -161,7 +170,7 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - torch.testing.assert_close(data_mx._data, ground_truth_fp8) + torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) # bf16 denorm # fmt: off data_hp = torch.tensor( @@ -182,7 +191,7 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - torch.testing.assert_close(data_mx._data, ground_truth_fp8) + torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) # fp32 some denorm # fmt: off data_hp = torch.tensor( @@ -213,7 +222,7 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - torch.testing.assert_close(data_mx._data, ground_truth_fp8) + torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) # bf16 some denorm # fmt: off data_hp = torch.tensor( @@ -244,7 +253,7 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - torch.testing.assert_close(data_mx._data, ground_truth_fp8) + torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) # zero data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32) ground_truth_scale = torch.tensor([0], dtype=torch.uint8).view(torch.float8_e8m0fnu) @@ -255,7 +264,7 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - torch.testing.assert_close(data_mx._data, ground_truth_fp8) + torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) # fp32 normal # fmt: off data_hp = torch.tensor( @@ -286,7 +295,7 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - torch.testing.assert_close(data_mx._data, ground_truth_fp8) + torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) # bf16 normal # fmt: off data_hp = torch.tensor( @@ -317,7 +326,7 @@ def test_to_mx_rceil(): data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL ) torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale) - torch.testing.assert_close(data_mx._data, ground_truth_fp8) + torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -371,16 +380,15 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): else: raise AssertionError("unsupported") block_size = 4 - use_fp4_custom_triton_dequant_kernel = False tensor_mx = MXTensor( - scale_e8m0, data_bits, + scale_e8m0, elem_dtype, block_size, torch.float, - use_fp4_custom_triton_dequant_kernel, MXGemmKernelChoice.EMULATED, pack_fp6, + None, ) tensor_hp = tensor_mx.to_dtype(torch.float) assert torch.all(torch.isnan(tensor_hp.flatten()[0:4])) @@ -417,14 +425,10 @@ def test_block_sizes(elem_dtype, B): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) -@pytest.mark.parametrize("fp4_triton", [False, True]) -def test_transpose(elem_dtype, fp4_triton): +def test_transpose(elem_dtype): """ Verify that transposing an MX tensor works """ - if elem_dtype != torch.float4_e2m1fn_x2 and fp4_triton: - pytest.skip("unsupported configuration") - M, K = 128, 256 block_size = 32 tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) @@ -432,7 +436,6 @@ def test_transpose(elem_dtype, fp4_triton): tensor_hp, elem_dtype, block_size, - use_fp4_custom_triton_dequant_kernel=fp4_triton, ) tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t() @@ -464,7 +467,7 @@ def test_fp6_packing(elem_dtype, pack_fp6): else: expected_packed_shape = x.shape - assert x_mx._data.shape == expected_packed_shape + assert x_mx.qdata.shape == expected_packed_shape @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -496,28 +499,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): atol=0, rtol=0, ) - torch.testing.assert_close(x_mx._data, x_mx_c._data, atol=0, rtol=0) + torch.testing.assert_close(x_mx.qdata, x_mx_c.qdata, atol=0, rtol=0) to_dtype_c = torch.compile(to_dtype, fullgraph=True) - use_fp4_custom_triton_dequant_kernel = False pack_fp6 = False x_mx_dq = to_dtype( - x_mx._data, + x_mx.qdata, x_mx._scale_e8m0, x_mx._elem_dtype, x_mx._block_size, hp_dtype, # noqa: E501 - use_fp4_custom_triton_dequant_kernel, pack_fp6, ) x_mx_c_dq = to_dtype_c( - x_mx_c._data, + x_mx_c.qdata, x_mx_c._scale_e8m0, x_mx_c._elem_dtype, x_mx_c._block_size, hp_dtype, - use_fp4_custom_triton_dequant_kernel, pack_fp6, ) torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) @@ -605,7 +605,7 @@ def to_f8(x): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+" + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" ) def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale): from torchao.prototype.mx_formats.nvfp4_tensor import ( @@ -657,3 +657,367 @@ def assert_sqnr_gt_threshold(orig, new, threshold): assert x.t().dtype == x_reconstructed_t.dtype, ( f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}" ) + + +@pytest.mark.parametrize( + "shape", + [ + (128, 4), + (256, 8), + (100, 3), + (4, 4), + (50, 10), + (384, 12), + ], +) +@pytest.mark.parametrize( + "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False] +) +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" +) +def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): + from torchao.prototype.mx_formats.utils import from_blocked, to_blocked + + rows, cols = shape + device = "cuda" if torch.cuda.is_available() else "cpu" + + original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8) + + blocked = to_blocked(original, use_triton_kernel=use_triton_kernel) + reconstructed = from_blocked(blocked, rows, cols) + + torch.testing.assert_close( + original, + reconstructed, + atol=0.0, + rtol=0.0, + msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}", + ) + + +@pytest.mark.parametrize("is_swizzled_scales", [False, True]) +@pytest.mark.parametrize( + "shape", + [ + (32, 64), + (16, 32), + (64, 128), + (384, 128), + ], +) +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape): + """ + Test that NVFP4Tensor can be constructed with swizzled scales and + that the _is_swizzled_scales flag is set correctly. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = shape + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales) + assert tensor._is_swizzled_scales == is_swizzled_scales + reconstructed = tensor.to_dtype(torch.bfloat16) + assert reconstructed.shape == data.shape + + +@pytest.mark.parametrize( + "slice_dim,slice_spec", + [ + # Row slicing - must align with 128-row boundaries + pytest.param(0, slice(0, 128), id="slice_rows[0:128]"), + pytest.param(0, slice(128, 256), id="slice_rows[128:256]"), + # Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size) + pytest.param(1, slice(0, 64), id="slice_cols[0:64]"), + pytest.param(1, slice(64, 128), id="slice_cols[64:128]"), + pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"), + # Test tensor parallelism patterns (half splits) + pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"), + pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"), + # Test quarter splits + pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"), + pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec): + """ + Test that slicing works correctly with swizzled scales and maintains + the swizzled state in the output tensor. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + # Use larger tensor sizes that align with swizzled requirements + if slice_dim == 0: + # For row slicing, need at least 256 rows to test 128-row boundaries + M, K = 256, 4096 + else: + # For column slicing, need multiples of 64 columns for alignment + M, K = 128, 4096 + + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + assert tensor._is_swizzled_scales == True + + if slice_dim == 0: + sliced_tensor = tensor[slice_spec, :] + else: + sliced_tensor = tensor[:, slice_spec] + + # Verify sliced tensor maintains swizzled state + assert sliced_tensor._is_swizzled_scales == True + + # Verify sliced tensor can be dequantized + sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16) + + # Compare with direct slicing of original data + original_reconstructed = tensor.to_dtype(torch.bfloat16) + if slice_dim == 0: + expected = original_reconstructed[slice_spec, :] + else: + expected = original_reconstructed[:, slice_spec] + + torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6) + + +@pytest.mark.parametrize( + "slice_dim,slice_spec,expected_error", + [ + # Row slicing with misaligned boundaries + pytest.param( + 0, + slice(0, 100), + "Row slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_row_end", + ), + pytest.param( + 0, + slice(50, 150), + "Row slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_row_start", + ), + # Column slicing with misaligned boundaries + pytest.param( + 1, + slice(0, 32), + "Column slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_col_32", + ), + pytest.param( + 1, + slice(16, 80), + "Column slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_col_start", + ), + pytest.param( + 1, + slice(0, 100), + "Column slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_col_end", + ), + # Odd column boundaries (FP4 packing requirement) + pytest.param( + 1, + slice(1, 65), + "start index to be a multiple of 64, got 1", + id="odd_start", + ), + pytest.param( + 1, + slice(0, 65), + " multiple of 64 or equal to tensor size 4096, got 65", + id="odd_end", + ), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_error): + """ + Test that slicing raises appropriate errors for misaligned boundaries. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 256, 4096 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + with pytest.raises(RuntimeError, match=expected_error): + if slice_dim == 0: + _ = tensor[slice_spec, :] + else: + _ = tensor[:, slice_spec] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_view_semantics(): + """ + Test that slicing maintains proper view semantics where possible. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 256, 4096 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + # Test row slicing (should maintain views) + sliced_tensor = tensor[0:128, :] + + # Test that the sliced tensor shares storage with original for data + # (Note: scales might not share storage due to swizzled layout complexity) + assert sliced_tensor.qdata.data_ptr() == tensor.qdata.data_ptr() + + # Test full-width column slicing (should maintain views) + full_width_slice = tensor[:, 0:K] + assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr() + assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_serialization(): + """ + Test that tensor flatten/unflatten preserves the swizzled scales state. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 32, 64 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + # Create tensor with swizzled scales + original_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + # Test serialization + tensor_list, ctx = original_tensor.__tensor_flatten__() + + # Verify swizzled flag is preserved in context + assert "_is_swizzled_scales" in ctx + assert ctx["_is_swizzled_scales"] == True + + # Test deserialization + inner_tensors = {} + for name in tensor_list: + inner_tensors[name] = getattr(original_tensor, name) + + reconstructed_tensor = NVFP4Tensor.__tensor_unflatten__( + inner_tensors, ctx, None, None + ) + + # Verify the swizzled state is preserved + assert reconstructed_tensor._is_swizzled_scales == True + + # Verify functionality is preserved + original_dq = original_tensor.to_dtype(torch.bfloat16) + reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16) + + torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_get_scales_method(): + """ + Test that the get_scales() method correctly unswizzles scales when needed. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 32, 64 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + # Create tensors with both storage methods + regular_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=False) + swizzled_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + # Get scales from both tensors and verify they are equal + regular_scales = regular_tensor.get_hp_scales() + swizzled_scales = swizzled_tensor.get_hp_scales() + torch.testing.assert_close(regular_scales, swizzled_scales, atol=0.0, rtol=0.0) + + # Verify scales have the expected shape + expected_shape = (M, K // 16) + assert regular_scales.shape == expected_shape + assert swizzled_scales.shape == expected_shape + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}" +) +@pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}") +@pytest.mark.parametrize( + "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"] +) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics" +) +@torch.no_grad() +def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): + """Test that Triton and PyTorch NVFP4 quantization produce equivalent results.""" + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + per_tensor_amax_to_scale, + unpack_uint4, + ) + + torch.manual_seed(42) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + per_tensor_scale = None + if use_per_tensor_scale: + per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x))) + + nvfp4_pt = NVFP4Tensor.to_nvfp4( + x.clone(), + per_tensor_scale=per_tensor_scale, + is_swizzled_scales=True, + use_triton_kernel=False, + ) + + nvfp4_triton = NVFP4Tensor.to_nvfp4( + x.clone(), + per_tensor_scale=per_tensor_scale, + is_swizzled_scales=True, + use_triton_kernel=True, + ) + + torch.testing.assert_close( + nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten() + ) + pt_unpacked = unpack_uint4(nvfp4_pt.qdata) + triton_unpacked = unpack_uint4(nvfp4_triton.qdata) + torch.testing.assert_close( + pt_unpacked, + triton_unpacked, + atol=0, + rtol=0, + ) + + x_pt_dequant = nvfp4_pt.to_dtype(dtype) + x_triton_dequant = nvfp4_triton.to_dtype(dtype) + + sqnr = compute_error(x_pt_dequant, x_triton_dequant) + SQNR_THRESHOLD = 40.0 + + assert sqnr >= SQNR_THRESHOLD, ( + f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} for M={M}, N={N}, " + f"use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}" + ) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py new file mode 100644 index 0000000000..1eaa335c1e --- /dev/null +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -0,0 +1,547 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) 2025, NVIDIA CORPORATION. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +import torch.nn.functional as F + +from torchao.prototype.mx_formats.constants import ( + F4_E2M1_MAX, +) +from torchao.prototype.mx_formats.inference_workflow import ( + NVFP4MMConfig, +) +from torchao.prototype.mx_formats.nvfp4_tensor import ( + QuantizeTensorToNVFP4Kwargs, +) +from torchao.quantization.utils import compute_error +from torchao.testing.utils import skip_if_rocm +from torchao.utils import ( + is_sm_at_least_100, + torch_version_at_least, +) + +torch.manual_seed(2) + +if not torch_version_at_least("2.8.0"): + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +@pytest.mark.parametrize( + "dtype,shape,use_per_tensor_scale", + [ + (torch.bfloat16, (32, 64), False), + (torch.float32, (64, 128), False), + (torch.bfloat16, (128, 256), False), + (torch.bfloat16, (64, 128), True), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" +) +def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale): + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + per_tensor_amax_to_scale, + ) + + x = torch.randn(shape, dtype=dtype, device="cuda") + if use_per_tensor_scale: + tensor_amax = torch.max(torch.abs(x)) + scale = per_tensor_amax_to_scale(tensor_amax) + else: + scale = None + + x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale) + x_reconstructed = x_nvfp4.to_dtype(dtype) + + def assert_sqnr_gt_threshold(orig, new, threshold): + sqnr = compute_error(orig, new) + if torch.all(torch.isnan(sqnr)): + # if both operands are full of zeroes, sqnr is nan and this is ok + # test for this explicitly + assert torch.all(orig == 0) and torch.all(new == 0) + else: + assert sqnr >= threshold + + reconstructed_amax = x_nvfp4.get_hp_scales().view(shape[0], -1, 1) * F4_E2M1_MAX + max_abs = torch.amax( + torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1 + ).unsqueeze(-1) + + assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0) + assert_sqnr_gt_threshold(x, x_reconstructed, 8.0) + + assert x.shape == x_reconstructed.shape, ( + f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}" + ) + assert x.dtype == x_reconstructed.dtype, ( + f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}" + ) + + x_nvfp4_t = x_nvfp4.t() + x_reconstructed_t = x_nvfp4_t.to_dtype(dtype) + assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0) + + assert x.t().shape == x_reconstructed_t.shape, ( + f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}" + ) + assert x.t().dtype == x_reconstructed_t.dtype, ( + f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}" + ) + + +@pytest.mark.parametrize("is_swizzled_scales", [False, True]) +@pytest.mark.parametrize( + "shape", + [ + (32, 64), + (16, 32), + (64, 128), + (384, 128), + ], +) +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape): + """ + Test that NVFP4Tensor can be constructed with swizzled scales and + that the _is_swizzled_scales flag is set correctly. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = shape + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales) + assert tensor._is_swizzled_scales == is_swizzled_scales + reconstructed = tensor.to_dtype(torch.bfloat16) + assert reconstructed.shape == data.shape + + +@pytest.mark.parametrize( + "slice_dim,slice_spec", + [ + # Row slicing - must align with 128-row boundaries + pytest.param(0, slice(0, 128), id="slice_rows[0:128]"), + pytest.param(0, slice(128, 256), id="slice_rows[128:256]"), + # Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size) + pytest.param(1, slice(0, 64), id="slice_cols[0:64]"), + pytest.param(1, slice(64, 128), id="slice_cols[64:128]"), + pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"), + # Test tensor parallelism patterns (half splits) + pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"), + pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"), + # Test quarter splits + pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"), + pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec): + """ + Test that slicing works correctly with swizzled scales and maintains + the swizzled state in the output tensor. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + # Use larger tensor sizes that align with swizzled requirements + if slice_dim == 0: + # For row slicing, need at least 256 rows to test 128-row boundaries + M, K = 256, 4096 + else: + # For column slicing, need multiples of 64 columns for alignment + M, K = 128, 4096 + + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + assert tensor._is_swizzled_scales == True + + if slice_dim == 0: + sliced_tensor = tensor[slice_spec, :] + else: + sliced_tensor = tensor[:, slice_spec] + + # Verify sliced tensor maintains swizzled state + assert sliced_tensor._is_swizzled_scales == True + + # Verify sliced tensor can be dequantized + sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16) + + # Compare with direct slicing of original data + original_reconstructed = tensor.to_dtype(torch.bfloat16) + if slice_dim == 0: + expected = original_reconstructed[slice_spec, :] + else: + expected = original_reconstructed[:, slice_spec] + + torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6) + + +@pytest.mark.parametrize( + "slice_dim,slice_spec,expected_error", + [ + # Row slicing with misaligned boundaries + pytest.param( + 0, + slice(0, 100), + "Row slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_row_end", + ), + pytest.param( + 0, + slice(50, 150), + "Row slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_row_start", + ), + # Column slicing with misaligned boundaries + pytest.param( + 1, + slice(0, 32), + "Column slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_col_32", + ), + pytest.param( + 1, + slice(16, 80), + "Column slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_col_start", + ), + pytest.param( + 1, + slice(0, 100), + "Column slicing of NVFP4Tensor with swizzled scales requires", + id="misaligned_col_end", + ), + # Odd column boundaries (FP4 packing requirement) + pytest.param( + 1, + slice(1, 65), + "start index to be a multiple of 64, got 1", + id="odd_start", + ), + pytest.param( + 1, + slice(0, 65), + " multiple of 64 or equal to tensor size 4096, got 65", + id="odd_end", + ), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_error): + """ + Test that slicing raises appropriate errors for misaligned boundaries. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 256, 4096 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + with pytest.raises(RuntimeError, match=expected_error): + if slice_dim == 0: + _ = tensor[slice_spec, :] + else: + _ = tensor[:, slice_spec] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_view_semantics(): + """ + Test that slicing maintains proper view semantics where possible. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 256, 4096 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + # Test row slicing (should maintain views) + sliced_tensor = tensor[0:128, :] + + # Test that the sliced tensor shares storage with original for data + # (Note: scales might not share storage due to swizzled layout complexity) + assert sliced_tensor.qdata.data_ptr() == tensor.qdata.data_ptr() + + # Test full-width column slicing (should maintain views) + full_width_slice = tensor[:, 0:K] + assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr() + assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_serialization(): + """ + Test that tensor flatten/unflatten preserves the swizzled scales state. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 32, 64 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + # Create tensor with swizzled scales + original_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + # Test serialization + tensor_list, ctx = original_tensor.__tensor_flatten__() + + # Verify swizzled flag is preserved in context + assert "_is_swizzled_scales" in ctx + assert ctx["_is_swizzled_scales"] == True + + # Test deserialization + inner_tensors = {} + for name in tensor_list: + inner_tensors[name] = getattr(original_tensor, name) + + reconstructed_tensor = NVFP4Tensor.__tensor_unflatten__( + inner_tensors, ctx, None, None + ) + + # Verify the swizzled state is preserved + assert reconstructed_tensor._is_swizzled_scales == True + + # Verify functionality is preserved + original_dq = original_tensor.to_dtype(torch.bfloat16) + reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16) + + torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_swizzled_scales_get_scales_method(): + """ + Test that the get_scales() method correctly unswizzles scales when needed. + """ + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + M, K = 32, 64 + data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + # Create tensors with both storage methods + regular_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=False) + swizzled_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True) + + # Get scales from both tensors and verify they are equal + regular_scales = regular_tensor.get_hp_scales() + swizzled_scales = swizzled_tensor.get_hp_scales() + torch.testing.assert_close(regular_scales, swizzled_scales, atol=0.0, rtol=0.0) + + # Verify scales have the expected shape + expected_shape = (M, K // 16) + assert regular_scales.shape == expected_shape + assert swizzled_scales.shape == expected_shape + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}" +) +@pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}") +@pytest.mark.parametrize( + "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"] +) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics" +) +@torch.no_grad() +def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): + """Test that Triton and PyTorch NVFP4 quantization produce equivalent results.""" + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + per_tensor_amax_to_scale, + unpack_uint4, + ) + + torch.manual_seed(42) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + per_tensor_scale = None + if use_per_tensor_scale: + per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x))) + + nvfp4_pt = NVFP4Tensor.to_nvfp4( + x.clone(), + per_tensor_scale=per_tensor_scale, + is_swizzled_scales=True, + use_triton_kernel=False, + ) + + nvfp4_triton = NVFP4Tensor.to_nvfp4( + x.clone(), + per_tensor_scale=per_tensor_scale, + is_swizzled_scales=True, + use_triton_kernel=True, + ) + + torch.testing.assert_close( + nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten() + ) + pt_unpacked = unpack_uint4(nvfp4_pt.qdata) + triton_unpacked = unpack_uint4(nvfp4_triton.qdata) + torch.testing.assert_close( + pt_unpacked, + triton_unpacked, + atol=0, + rtol=0, + ) + + x_pt_dequant = nvfp4_pt.to_dtype(dtype) + x_triton_dequant = nvfp4_triton.to_dtype(dtype) + + sqnr = compute_error(x_pt_dequant, x_triton_dequant) + SQNR_THRESHOLD = 40.0 + + assert sqnr >= SQNR_THRESHOLD, ( + f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} for M={M}, N={N}, " + f"use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" +) +@pytest.mark.parametrize("use_gelu", [True, False]) +@pytest.mark.parametrize( + "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] +) +@pytest.mark.parametrize("compile", [False]) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("use_triton_kernel", [True, False]) +@pytest.mark.parametrize( + "shapes", + [ + (128, 64, 256), + (256, 128, 512), + (157, 64, 256), + (128, 96, 256), + (128, 160, 256), + (64, 64, 256), + (200, 192, 256), + ], + ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}", +) +@torch.no_grad() +@skip_if_rocm("ROCm float4 gemm require gfx950") +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for fp4" +) +def test_nvfp4_matmul_with_amax( + use_gelu: bool, + mm_config: NVFP4MMConfig, + compile: bool, + bias: bool, + inpt_dtype: torch.dtype, + use_triton_kernel: bool, + shapes: tuple, +): + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + per_tensor_amax_to_scale, + ) + + # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs + if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): + pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") + + if bias and inpt_dtype == torch.float32: + pytest.xfail("Bias is not supported when module weight is in fp32") + + if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: + pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") + + m, k, n = shapes + + # Create activation tensor + if use_gelu: + x = torch.randn(m, k, dtype=inpt_dtype, device="cuda") + A = torch.nn.functional.gelu(x) + else: + A = torch.randn(m, k, dtype=inpt_dtype, device="cuda") + + B = torch.randn(n, k, dtype=inpt_dtype, device="cuda") + bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None + + # Compute reference + C_ref = F.linear(A, B, bias_tensor) + + a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A))) + b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B))) + act_quant_kwargs = None + if mm_config == NVFP4MMConfig.DYNAMIC: + act_quant_kwargs = QuantizeTensorToNVFP4Kwargs() + A_nvfp4 = NVFP4Tensor.to_nvfp4( + A, + per_tensor_scale=a_scale, + is_swizzled_scales=True, + use_triton_kernel=use_triton_kernel, + ) + B_nvfp4 = NVFP4Tensor.to_nvfp4( + B, + per_tensor_scale=b_scale, + is_swizzled_scales=True, + use_triton_kernel=use_triton_kernel, + act_quant_kwargs=act_quant_kwargs, + ) + + func = torch.compile(F.linear, fullgraph=True) if compile else F.linear + + C_nvfp4 = func(A_nvfp4, B_nvfp4, bias_tensor) + assert C_nvfp4.dtype == inpt_dtype, ( + f"Got {C_nvfp4.dtype} for inpt_dtype={inpt_dtype}" + ) + + sqnr = compute_error(C_ref, C_nvfp4) + SQNR_THRESHOLD = 16.0 + assert sqnr >= SQNR_THRESHOLD, ( + f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) +def test_nvfp4_to_copy(): + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor + + x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda() + y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16) + assert torch.equal(x.qdata, y.qdata) + assert torch.equal(x._scale_e4m3, y._scale_e4m3) + assert x._per_tensor_scale is None + assert y._per_tensor_scale is None + assert x._act_per_tensor_scale is None + assert y._act_per_tensor_scale is None + assert x._block_size == y._block_size + assert x.use_triton_kernel == y.use_triton_kernel + assert x.act_quant_kwargs == y.act_quant_kwargs + assert x.dtype == torch.float32 + assert y.dtype == torch.bfloat16 diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py new file mode 100644 index 0000000000..b67bf2bf0c --- /dev/null +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -0,0 +1,65 @@ +import json +import tempfile +import unittest + +import torch +from safetensors.torch import load_file, save_file +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao import quantize_ +from torchao.prototype.safetensors.safetensors_support import ( + flatten_tensor_state_dict, + unflatten_tensor_state_dict, +) +from torchao.quantization.granularity import PerRow +from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig +from torchao.utils import ( + is_sm_at_least_89, +) + + +def load_data(file_path: str, device: str): + loaded_tensors = load_file(file_path, device) + with open(file_path, "rb") as f: + import struct + + header_size = struct.unpack("= 8.3.0" +) +class TestCodebookQuantization(unittest.TestCase): + def setUp(self): + torch.manual_seed(123) + self.input = torch.randn(100, 256, dtype=torch.float32) + self.code_dtype = torch.uint8 + self.block_size = [-1, 4] + self.nbits = 8 + + def test_choose_qparams_codebook(self): + codebook, wq = choose_qparams_and_quantize_codebook_coreml( + self.input, + self.code_dtype, + self.block_size, + ) + group_size = self.block_size[-1] + self.assertEqual(codebook.shape, (1, 256 // group_size, 2**self.nbits, 1)) + self.assertEqual(wq.shape, (100, 256)) + + self.assertFalse(torch.isnan(codebook).any()) + self.assertFalse(torch.isnan(wq).any()) + + def test_codebook_quantized_tensor_from_float(self): + cqt = CodebookQuantizedTensor.from_float( + self.input, + self.code_dtype, + self.block_size, + ) + + dequant = cqt.dequantize() + sqnr = compute_error(dequant, self.input) + self.assertGreater(sqnr, 30) + + def test_codebook_quantized_tensor_from_float2(self): + block_size = [-1, 16] + code_dtype = torch.uint4 + + cqt = CodebookQuantizedTensor.from_float( + self.input, + code_dtype, + block_size, + ) + + dequant = cqt.dequantize() + + sqnr = compute_error(dequant, self.input) + self.assertGreater(sqnr, 18) + + def test_quantize_api(self): + m = torch.nn.Sequential(torch.nn.Linear(64, 64)) + quantize_( + m, + CodebookWeightOnlyConfig(dtype=self.code_dtype, block_size=self.block_size), + ) + assert type(m[0].weight) == CodebookQuantizedTensor + + def test_choose_qparams_codebook_row_grouping(self): + # Test with a block_size that forces row-wise grouping: [10, 256] + # Input tensor is (100, 256) + row_grouped_block_size = [10, -1] + num_row_groups = ( + self.input.shape[0] // row_grouped_block_size[0] + ) # 100 // 10 = 10 + + codebook, wq = choose_qparams_and_quantize_codebook_coreml( + self.input, + self.code_dtype, + row_grouped_block_size, + ) + + # Expected shape for row-wise grouping is (num_row_groups, 1, 2**nbits, 1) + self.assertEqual(codebook.shape, (num_row_groups, 1, 2**self.nbits, 1)) + self.assertEqual(wq.shape, (100, 256)) + + self.assertFalse(torch.isnan(codebook).any()) + self.assertFalse(torch.isnan(wq).any()) + + def test_codebook_quantized_tensor_from_float_row_grouping(self): + # Test end-to-end quantization/dequantization with row grouping + row_grouped_block_size = [20, -1] # 100 is divisible by 20 + cqt = CodebookQuantizedTensor.from_float( + self.input, + self.code_dtype, + row_grouped_block_size, + ) + + dequant = cqt.dequantize() + # The SQNR will be different from column grouping, but should still be high + sqnr = compute_error(dequant, self.input) + self.assertGreater(sqnr, 30) + + def test_export(self): + m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32) + quantize_(m, CodebookWeightOnlyConfig(self.code_dtype, self.block_size)) + example_inputs = (torch.randn(1, 128, dtype=torch.float32),) + m = torch.export.export(m, example_inputs).module() + targets = [n.target for n in m.graph.nodes] + self.assertTrue(torch.ops.quant.dequantize_codebook.default in targets) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/test/prototype/test_embedding.py similarity index 95% rename from torchao/experimental/tests/test_embedding_xbit_quantizer.py rename to test/prototype/test_embedding.py index 442612410e..7d020920a7 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/test/prototype/test_embedding.py @@ -12,18 +12,15 @@ from parameterized import param, parameterized from torch.testing import FileCheck -from torchao.dtypes import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, -) -from torchao.experimental.quant_api import ( +from torchao.prototype.quantization.embedding.api import ( EmbeddingQuantizer, - SharedEmbeddingQuantizer, + TiedEmbeddingQuantizer, ) from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.qat import ( - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, Int4WeightOnlyEmbeddingQATQuantizer, + IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig, ) from torchao.quantization.quant_api import ( @@ -32,9 +29,13 @@ MappingType, quantize_, ) +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) from torchao.quantization.utils import compute_error +@unittest.skipIf(not _is_kernel_library_loaded(), "Need torchao lowbit kernels") class TestEmbeddingQuantizer(unittest.TestCase): def test_accuracy(self): granularity = PerGroup(128) @@ -152,16 +153,14 @@ def test_shared_embedding(self): weight_dtype=weight_dtype, weight_granularity=PerAxis(0), weight_mapping_type=weight_mapping_type, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout( - target="universal" - ), + intx_packing_format="opaque_torchao_lowbit", ), filter_fn=lambda m, fqn: fqn == "2", ) # Do shared embedding quantization quantized_model = copy.deepcopy(model) - SharedEmbeddingQuantizer( + TiedEmbeddingQuantizer( weight_dtype=weight_dtype, granularity=PerAxis(0), mapping_type=weight_mapping_type, @@ -183,10 +182,9 @@ def test_shared_embedding(self): self.assertTrue(torch.allclose(result, exported_result)) # Check the shared_embedding and linear ops use the same lifted weight - weight = "b_getattr_l__fn_____0___unembedding_packed_weights" expected_lines = [ - f"torch.ops.torchao._shared_embedding_4bit.default({weight}, 4096, 131, 4096, reshape)", - f"torch.ops.torchao._linear_8bit_act_4bit_weight.default(linear, {weight}, 4096, 131, 4096)", + "torch.ops.torchao._shared_embedding_4bit.default", + "torch.ops.torchao._linear_8bit_act_4bit_weight.default", ] for line in expected_lines: FileCheck().check_count(line, 1, exactly=True).run( @@ -282,7 +280,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( ) embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( weight_dtype, group_size=group_size, is_symmetric=is_symmetric, diff --git a/test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py b/test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py new file mode 100644 index 0000000000..25d5398c50 --- /dev/null +++ b/test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import tempfile +import unittest + +import torch +import torch.nn as nn +from parameterized import param, parameterized +from torch import uint1, uint2, uint3, uint4 + +from torchao.prototype.quantization.codebook_groupwise.api import ( + GroupwiseLutWeightConfig, +) +from torchao.prototype.quantization.codebook_utils.codebook_utils import ( + group_size_to_block_shapes, +) +from torchao.quantization.quant_api import quantize_ +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) + + +@unittest.skipIf(not _is_kernel_library_loaded(), "Need torchao lowbit kernels") +class TestGroupwiseLowbitWeightLut(unittest.TestCase): + """ + Test suite for the GroupwiseLutWeight quantization scheme, updated for the + new simplified API. + """ + + TEST_CASES = [ + param( + code_dtype=code_dtype, + lut_group_size=lut_group_size, + weight_dtype=weight_dtype, + has_bias=has_bias, + ) + for code_dtype in [uint1, uint2, uint3, uint4] + for lut_group_size in [256, 512] + for weight_dtype in [torch.float32] + for has_bias in [True, False] + ] + + # -------------------------------------------------------------------------- + # Test 1: End-to-End Model Accuracy + # -------------------------------------------------------------------------- + @parameterized.expand(TEST_CASES) + def test_e2e_accuracy_vs_reference( + self, + code_dtype, + lut_group_size, + weight_dtype, + has_bias, + ): + """ + Tests the numerical accuracy of the full quantized model against a reference. + This now uses the `use_qdq_reference` flag instead of layout objects. + """ + m, k, n = 3, 64, 32 + activations = torch.randn(m, k, dtype=weight_dtype) + model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype) + + # --- 2. Update tensor_shape to reflect the new (k, n) layout --- + lut_block_shape = group_size_to_block_shapes( + lut_group_size=lut_group_size, tensor_shape=(n, k) + ) + + # --- Quantize using C++ ops --- + quantized_model = copy.deepcopy(model) + perf_config = GroupwiseLutWeightConfig( + code_dtype=code_dtype, + weight_dtype=weight_dtype, + lut_block_shape=lut_block_shape, + use_qdq_reference=False, + ) + quantize_(quantized_model, perf_config) + with torch.no_grad(): + actual_result = quantized_model(activations) + + # --- Quantize for Reference (using Python ops) --- + reference_model = copy.deepcopy(model) + ref_config = GroupwiseLutWeightConfig( + code_dtype=code_dtype, + weight_dtype=weight_dtype, + lut_block_shape=lut_block_shape, + use_qdq_reference=True, + ) + quantize_(reference_model, ref_config) + with torch.no_grad(): + expected_result = reference_model(activations) + # Compare results + self.assertTrue( + torch.allclose(actual_result, expected_result, atol=1e-2, rtol=1e-2) + ) + + def tearDown(self): + """ + Clear the TorchDynamo cache after each test case to prevent + recompilation errors in parameterized tests. + """ + super().tearDown() + torch._dynamo.reset() + + # -------------------------------------------------------------------------- + # Test 2: Deployment Readiness (Updated for new API) + # -------------------------------------------------------------------------- + @parameterized.expand(TEST_CASES) + def test_export_compile_aoti( + self, + code_dtype, + lut_group_size, + weight_dtype, + has_bias, + ): + """ + Tests that the quantized model can be exported and compiled. + """ + k, n = 64, 32 + activations = torch.randn(2, k, dtype=weight_dtype) + model = ( + nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype).eval() + ) + lut_block_shape = group_size_to_block_shapes( + lut_group_size=lut_group_size, + tensor_shape=(n, k), + ) + + # Configure the quantization using the new API + config = GroupwiseLutWeightConfig( + code_dtype=code_dtype, + weight_dtype=weight_dtype, + lut_block_shape=lut_block_shape, + use_qdq_reference=False, + ) + quantize_(model, config) + + with torch.no_grad(): + eager_results = model(activations) + + # Export and Compile + exported_model = torch.export.export(model, (activations,)) + compiled_model = torch.compile(model, fullgraph=True) + + with tempfile.TemporaryDirectory() as tmpdir, torch.no_grad(): + # Check exported model + exported_results = exported_model.module()(activations) + self.assertTrue( + torch.allclose(eager_results, exported_results, atol=1e-3, rtol=1e-3) + ) + + # Check compiled model + compiled_results = compiled_model(activations) + self.assertTrue( + torch.allclose(eager_results, compiled_results, atol=1e-3, rtol=1e-3) + ) + + # Check AOTI compiled model using the packaging API + package_path = f"{tmpdir}/model.pt2" + torch._inductor.aoti_compile_and_package( + exported_model, package_path=package_path + ) + aoti_model = torch._inductor.aoti_load_package(package_path) + aoti_results = aoti_model(activations) + self.assertTrue( + torch.allclose(eager_results, aoti_results, atol=1e-3, rtol=1e-3) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/test_int8_lut_tensor.py b/test/prototype/test_int8_lut_tensor.py new file mode 100644 index 0000000000..b5d1a6b0a1 --- /dev/null +++ b/test/prototype/test_int8_lut_tensor.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from copy import deepcopy + +import pytest +import torch + +from torchao.prototype.parq.quant import ( + StretchedIntxWeightConfig, + StretchedUnifTorchaoQuantizer, +) +from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import ( + _is_kernel_library_loaded, +) +from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64 +from torchao.quantization import quantize_ +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.utils import compute_error + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, d1=512, d2=256, d3=128, d4=8): + super().__init__() + self.linear1 = torch.nn.Linear(d1, d2, bias=False) + self.linear2 = torch.nn.Linear(d2, d3, bias=True) + self.linear3 = torch.nn.Linear(d3, d4, bias=False) + + def example_inputs( + self, + lead_dim=(1,), + dtype=torch.bfloat16, + ): + return torch.randn( + *lead_dim, self.linear1.in_features, dtype=dtype, device="cpu" + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + yield + torch._dynamo.reset() # reset cache between tests + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) +@pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) +@pytest.mark.parametrize("lead_dim", [(5,), (2, 3)]) +@pytest.mark.skipif( + not _is_kernel_library_loaded(), reason="Kernel library is not loaded" +) +def test_parq_conversion(dtype, granularity, bit_width, lead_dim): + torch.manual_seed(0) + quantizer = StretchedUnifTorchaoQuantizer(bit_width) + config = StretchedIntxWeightConfig( + b=bit_width, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=granularity, + activation_quantization="int8_asym_per_token", + ) + + parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype) + activations = parq_model.example_inputs(lead_dim=lead_dim, dtype=dtype) + quantize_(parq_model, config) + + # Convert PARQ model to lowbit LUT model + lut_model = deepcopy(parq_model) + _convert_model_for_aarch64(lut_model, tensor_type="int8_lut_tensor") + + # Run both models and compare + parq_out = parq_model(activations) + lut_out = lut_model(activations) + + sqnr = compute_error(parq_out, lut_out).item() + if dtype == torch.float32: + assert sqnr > 40.0, f"sqnr {sqnr} is too low" + elif dtype == torch.bfloat16: + assert sqnr > 25.0, f"sqnr {sqnr} is too low" + else: + raise ValueError(f"Unsupported dtype {dtype}") + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) +@pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) +@pytest.mark.parametrize("lead_dim", [(5,), (2, 3)]) +@pytest.mark.skipif( + not _is_kernel_library_loaded(), reason="Kernel library is not loaded" +) +def test_export(dtype, granularity, bit_width, lead_dim): + quantizer = StretchedUnifTorchaoQuantizer(bit_width) + config = StretchedIntxWeightConfig( + b=bit_width, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=granularity, + activation_quantization="int8_asym_per_token", + ) + + parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype) + activations = parq_model.example_inputs(lead_dim=lead_dim) + quantize_(parq_model, config) + + _convert_model_for_aarch64(parq_model) + + ep = torch.export.export(parq_model, (activations,)) + + assert ( + f"torch.ops.torchao._linear_8bit_act_{bit_width}bit_weight.default" + in ep.graph_module.code + ) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 68c25821ee..10004a03f9 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import copy +import tempfile import unittest from typing import Optional @@ -11,7 +12,6 @@ from torch import nn from torch.testing._internal import common_utils -from torchao.core.config import AOBaseConfig from torchao.dtypes import Int4CPULayout from torchao.prototype.parq.optim import ( ProxHardQuant, @@ -21,118 +21,291 @@ from torchao.prototype.parq.quant import ( Int4UnifTorchaoQuantizer, LSBQuantizer, + Quantizer, + StretchedIntxWeightConfig, + StretchedUnifTorchaoQuantizer, TernaryUnifQuantizer, UnifQuantizer, UnifTorchaoQuantizer, ) -from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE -from torchao.quantization.granularity import PerGroup -from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, - IntXQuantizationAwareTrainingConfig, +from torchao.prototype.parq.quant.config_torchao import ( + TRANSFORMERS_AVAIL, + _attach_hf_quantization_config, + _is_hf_model, ) +from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig, - MappingType, _is_linear, - int4_weight_only, quantize_, ) +from torchao.quantization.quant_primitives import MappingType +from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_6, + _is_fbgemm_genai_gpu_available, check_cpu_version, + is_sm_at_least_90, + torch_version_at_least, ) _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -def split_param_groups(model): - params_quant, params_no_quant = [], [] +class M(nn.Module): + _tied_weights_keys: list[str] = [] + + def __init__( + self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False + ): + nn.Module.__init__(self) + self.embed_tokens = nn.Embedding(k, m) if embedding else nn.Identity() + self.linear1 = nn.Linear(m, n, bias=bias) + self.linear2 = nn.Linear(n, k, bias=bias) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + if embedding and tied_weights: + assert self.embed_tokens.weight.shape == self.linear2.weight.shape + self.tie_weights() + self._tied_weights_keys.append("linear2.weight") + + def tie_weights(self): + self.linear2.weight = self.embed_tokens.weight + + def example_inputs(self, device=None): + if isinstance(self.embed_tokens, nn.Identity): + inputs = torch.randn(1, self.linear1.in_features, device=device) + else: + k = self.embed_tokens.num_embeddings + inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device) + return inputs + + def forward(self, x): + x = self.embed_tokens(x) + x = self.relu(self.linear1(x)) + x = self.sigmoid(self.linear2(x)) + return x + + +if TRANSFORMERS_AVAIL: + from transformers import PretrainedConfig, PreTrainedModel, TorchAoConfig + + class MConfig(PretrainedConfig): + def __init__( + self, + m=256, + n=128, + k=16, + bias=False, + embedding=True, + tied_weights=False, + **kwargs, + ): + super().__init__(**kwargs) + self.m = m + self.n = n + self.k = k + self.bias = bias + self.embedding = embedding + self.tied_weights = tied_weights + + class PreTrainedM(M, PreTrainedModel): + base_model_prefix = "base" + config_class = MConfig + + def __init__(self, config: MConfig): + PreTrainedModel.__init__(self, config) + M.__init__( + self, + m=config.m, + n=config.n, + k=config.k, + bias=config.bias, + embedding=config.embedding, + tied_weights=config.tied_weights, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.embed_tokens + + +def split_param_groups(model) -> tuple[list, list, list]: + params_quant, params_embed, params_no_quant = [], [], [] def get_param_groups(model): + seen_data_ptrs = set() # avoid duplicates in case of tied weights for module in model.children(): is_linear = _is_linear(module) for n, p in module.named_parameters(): + if n == "weight": + data_ptr = p.data_ptr() + if data_ptr in seen_data_ptrs: + continue + seen_data_ptrs.add(data_ptr) + if is_linear and n == "weight": params_quant.append(p) + elif isinstance(module, nn.Embedding) and n == "weight": + params_embed.append(p) else: params_no_quant.append(p) get_param_groups(model) - return params_quant, params_no_quant - - -def build_param_groups(model, b: int = 2, group_size: Optional[int] = None): - params_quant, params_no_quant = split_param_groups(model) - quant_kwargs = {"quant_block_size": group_size} if group_size else {} - return [ + return params_quant, params_embed, params_no_quant + + +def build_param_groups( + model, + b: int = 2, + group_size: Optional[int] = None, + quantizer: Optional[Quantizer] = None, +): + params_quant, params_embed, params_no_quant = split_param_groups(model) + quant_kwargs = {} + if group_size: + quant_kwargs["quant_block_size"] = group_size + if quantizer is not None: + quant_kwargs["quantizer"] = quantizer + param_groups = [ {"params": params_quant, "quant_bits": b, **quant_kwargs}, {"params": params_no_quant}, ] - - -class M(nn.Module): - def __init__(self, m=256, n=128, k=16, bias=False, embedding=True): - super().__init__() - self.embedding = nn.Embedding(10, m) if embedding else nn.Identity() - self.linear1 = nn.Linear(m, n, bias=bias) - self.linear2 = nn.Linear(n, k, bias=bias) - self.relu = nn.ReLU() - self.sigmoid = nn.Sigmoid() - - def reset_parameters(self): - for module in (self.linear1, self.linear2): - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - - def example_inputs(self, device=None): - return ( - torch.randint(1, 10, (1, self.linear1.in_features), device=device) - if isinstance(self.embedding, nn.Embedding) - else torch.randn(1, self.linear1.in_features, device=device) + if params_embed: + param_groups.append( + { + "params": params_embed, + "quant_bits": 4, + "quantizer": UnifTorchaoQuantizer(), + } ) - - def forward(self, x): - x = self.embedding(x) - x = self.linear1(x) - x = self.relu(x) - x = self.linear2(x) - x = self.sigmoid(x) - return x + return param_groups + + +def compare_quantized_models( + model: nn.Module, + m_ref: nn.Module, + quantizer: UnifTorchaoQuantizer, + b: int, + group_size: int, +): + for n, module in model.named_children(): + if not _is_linear(module): + continue + + # simulate grouping from QuantOptimizer.step + p = module.weight + original_shape = p.shape + p = p.view(-1, group_size) + + q, Q = quantizer.quantize(p, b=b, dim=-1) + + # compare to AffineQuantizedTensor instance + q = q.view(original_shape) + ref = getattr(m_ref, n).weight.dequantize() + torch.testing.assert_close(q, ref, atol=0, rtol=0) + + +def compare_parq_convert( + model: nn.Module, + m_ref: nn.Module, + optimizer: QuantOptimizer, + weight_only: bool = False, +): + # do not update model weights, just quantize + optimizer.zero_grad() + optimizer.step() + + orig_model = copy.deepcopy(model) # save copy of PARQ quantized model + + # equivalent to torchao's convert step + optimizer.torchao_convert(model, weight_only=weight_only) + + inputs = model.example_inputs(device=_DEVICE) + torch.testing.assert_close(model(inputs), orig_model(inputs)) + + for n, module in model.named_modules(): + if not _is_linear(module): + continue + + p_orig = getattr(orig_model, n).weight # PARQ weight + p_ref = getattr(m_ref, n).weight.dequantize() # native quantize_ + torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0) + + p = module.weight.dequantize() # PARQ weight after quantize_ + torch.testing.assert_close(p, p_ref, atol=0, rtol=0) + + +def check_torchao_tensor_subclass( + test_case: common_utils.TestCase, model: nn.Module, weight_only: bool = False +): + for name, module in model.named_modules(): + if not hasattr(module, "weight") or f"{name}.weight" in getattr( + model, "_tied_weights_keys", [] + ): + continue + + if not weight_only and _is_linear(module): + test_case.assertTrue(isinstance(module.weight, IntxUnpackedToInt8Tensor)) + test_case.assertTrue( + module.weight.activation_quantization == "int8_asym_per_token" + ) + elif weight_only and _is_linear(module) or isinstance(module, nn.Embedding): + test_case.assertTrue(isinstance(module.weight, IntxUnpackedToInt8Tensor)) + test_case.assertTrue(module.weight.activation_quantization is None) + + +def apply_activation_quantization( + model: nn.Module, optimizer: torch.optim.Optimizer, model_dtype: torch.dtype +): + # apply torchao quantized activations on top + activation_config = IntxFakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype + ) + qat_config = QATConfig(activation_config=activation_config, step="prepare") + for filter_fn in optimizer.get_filter_fns(model): + try: + quantize_(model, qat_config, filter_fn=filter_fn) + except ValueError as e: + if str(e) == "Activation fake quantization is not supported for embedding": + pass class TestPARQuantization(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - self.model = M(bias=True).to(_DEVICE) @common_utils.parametrize("b", [0, 1, 2, 4]) @common_utils.parametrize("unif_quant", [True, False]) @common_utils.parametrize("hard_prox", [True, False]) - def test_parq_train_loop(self, b: int = 2, unif_quant=True, hard_prox=True): - self.model.reset_parameters() - param_groups = build_param_groups(self.model, b) - base_optimizer = torch.optim.AdamW(param_groups) - + @common_utils.parametrize("per_group_quantizer", [True, False]) + def test_parq_train_loop( + self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False + ): + model = M(bias=True).to(_DEVICE) if unif_quant: quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer() else: quantizer = LSBQuantizer() + param_groups = build_param_groups( + model, b, quantizer=quantizer if per_group_quantizer else None + ) + base_optimizer = torch.optim.AdamW(param_groups) + prox_map = ( ProxHardQuant() if hard_prox else ProxPARQ(anneal_start=0, anneal_end=2) ) optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map) for _ in range(3): - x = self.model.example_inputs(device=_DEVICE) - out = self.model(x) + x = model.example_inputs(device=_DEVICE) + out = model(x) out.sum().backward() optimizer.step() - for child in self.model.children(): + for child in model.children(): if isinstance(child, nn.Linear): self.assertEqual( child.weight.unique().numel(), quantizer.get_quant_size(b) @@ -143,82 +316,31 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - def compare_quantized_models( - self, - model: nn.Module, - m_ref: nn.Module, - quantizer: UnifTorchaoQuantizer, - b: int, - group_size: int, - ): - for n, module in model.named_children(): - if not _is_linear(module): - continue - - # simulate grouping from QuantOptimizer.step - p = module.weight - original_shape = p.shape - p = p.view(-1, group_size) - - q, Q = quantizer.quantize(p, b=b, dim=-1) - - # compare to AffineQuantizedTensor instance - q = q.view(original_shape) - ref = getattr(m_ref, n).weight.dequantize() - torch.testing.assert_close(q, ref, atol=0, rtol=0) - - def compare_parq_convert( - self, - model: nn.Module, - m_ref: nn.Module, - optimizer: QuantOptimizer, - config: AOBaseConfig, - ): - # do not update model weights, just quantize - optimizer.zero_grad() - optimizer.step() - - orig_model = copy.deepcopy(model) # save copy of PARQ quantized model - - # equivalent to torchao's convert step - model.eval() - optimizer.restore_latent_params() - quantize_(model, config, filter_fn=optimizer.get_filter_fn(model)) - - for n, module in model.named_modules(): - if not _is_linear(module): - continue - - p_orig = getattr(orig_model, n).weight # PARQ weight - p = module.weight.dequantize() # PARQ weight after quantize_ - p_ref = getattr(m_ref, n).weight.dequantize() # native quantize_ - - torch.testing.assert_true(p_orig, p_ref, atol=0, rtol=0) - torch.testing.assert_true(p, p_ref, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch >= 2.8.0") + @unittest.skipIf(not is_sm_at_least_90(), "Need sm >= 90") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) @common_utils.parametrize("group_size", [32, 256]) def test_int4_weight_only(self, group_size: int = 32): model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16) - model.reset_parameters() m_ref = copy.deepcopy(model).eval().to(_DEVICE) - config = int4_weight_only(group_size=group_size) + config = Int4WeightOnlyConfig(group_size=group_size) if check_cpu_version(_DEVICE): config.layout = Int4CPULayout() + config.version = 1 quantize_(m_ref, config) b = 4 - self.compare_quantized_models( + compare_quantized_models( model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("b", [2, 3, 4, 8]) @common_utils.parametrize("group_size", [32, 512]) def test_intx_weight_only(self, b: int = 2, group_size: int = 32): model = M(m=512, n=512).to(_DEVICE) - model.reset_parameters() m_ref = copy.deepcopy(model).eval().to(_DEVICE) quantize_( @@ -229,18 +351,18 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32): ) quantizer = UnifTorchaoQuantizer() - self.compare_quantized_models(model, m_ref, quantizer, b, group_size) + compare_quantized_models(model, m_ref, quantizer, b, group_size) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") + @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch >= 2.8.0") + @unittest.skipIf(not is_sm_at_least_90(), "Need sm >= 90") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) def test_int4_weight_only_e2e(self, group_size: int = 32): - model = M(m=512, n=512).to(torch.bfloat16).to(_DEVICE) - model.reset_parameters() + model = M(m=512, n=512, embedding=False).to(torch.bfloat16).to(_DEVICE) m_ref = copy.deepcopy(model).eval().to(_DEVICE) - config = int4_weight_only(group_size=group_size) - if check_cpu_version(_DEVICE): - config.layout = Int4CPULayout() + config = Int4WeightOnlyConfig(group_size=group_size) quantize_(m_ref, config) b = 4 @@ -251,14 +373,12 @@ def test_int4_weight_only_e2e(self, group_size: int = 32): ProxHardQuant(), quant_per_channel=True, ) - self.compare_parq_convert(model, m_ref, optimizer, config) + compare_parq_convert(model, m_ref, optimizer, weight_only=True) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") @common_utils.parametrize("b", [2, 3, 4, 8]) def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): - model = M(m=512, n=512).to(_DEVICE) - model.reset_parameters() + model = M(m=512, n=512, embedding=False).to(_DEVICE) m_ref = copy.deepcopy(model).eval().to(_DEVICE) config = IntxWeightOnlyConfig( @@ -273,16 +393,119 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): ProxHardQuant(), quant_per_channel=True, ) - self.compare_parq_convert(model, m_ref, optimizer, config) + compare_parq_convert(model, m_ref, optimizer, weight_only=True) + check_torchao_tensor_subclass(self, model, weight_only=True) + + +class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase): + def setUp(self): + torch.manual_seed(123) + + @common_utils.parametrize("b", [2, 3]) + @common_utils.parametrize("group_size", [32, 256]) + def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32): + model = M(m=512, n=512).to(_DEVICE) + + quantizer_ref = UnifQuantizer() + quantizer = StretchedUnifTorchaoQuantizer(b) + + for module in model.children(): + if not _is_linear(module): + continue + + # simulate grouping from QuantOptimizer.step + p = module.weight + p = p.view(-1, group_size) + + q_ref, Q_ref = quantizer_ref.quantize(p, b=b, dim=-1) + q, Q = quantizer.quantize(p, b=b, dim=-1) + + torch.testing.assert_close(q, q_ref, atol=0, rtol=0) + torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0) + + @common_utils.parametrize("b", [2, 3]) + @common_utils.parametrize("group_size", [32, 512]) + def test_intx_weight_only(self, b: int = 2, group_size: int = 32): + model = M(m=512, n=512).to(_DEVICE) + + quantizer = StretchedUnifTorchaoQuantizer(b) + + m_ref = copy.deepcopy(model).eval().to(_DEVICE) + quantize_( + m_ref, + StretchedIntxWeightConfig( + b=b, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=PerGroup(group_size), + activation_quantization=None, + ), + ) + + compare_quantized_models(model, m_ref, quantizer, b, group_size) + + @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") + @common_utils.parametrize("b", [2, 3]) + def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32): + model = M(m=512, n=512, embedding=False).to(_DEVICE) + + quantizer = StretchedUnifTorchaoQuantizer(b) + + m_ref = copy.deepcopy(model).eval().to(_DEVICE) + config = StretchedIntxWeightConfig( + b=b, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=PerGroup(group_size), + activation_quantization=None, + ) + quantize_(m_ref, config, filter_fn=_is_linear) + + base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size)) + optimizer = QuantOptimizer( + base_optimizer, + quantizer, + ProxHardQuant(), + quant_per_channel=True, + ) + compare_parq_convert(model, m_ref, optimizer, weight_only=True) + check_torchao_tensor_subclass(self, model, weight_only=True) + + @common_utils.parametrize("b", [2, 3]) + @common_utils.parametrize( + "model_dtype", [torch.float16, torch.float32, torch.bfloat16] + ) + def test_intx_weight_only_tied_embed_linear( + self, b: int = 2, model_dtype: torch.dtype = torch.float32 + ): + model = M(m=256, n=256, tied_weights=True).to(_DEVICE) + + quantizer = StretchedUnifTorchaoQuantizer(b) + base_optimizer = torch.optim.SGD(build_param_groups(model, b)) + optimizer = QuantOptimizer( + base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True + ) + optimizer.zero_grad() + optimizer.step() + + apply_activation_quantization(model, optimizer, model_dtype) + optimizer.torchao_convert(model) + check_torchao_tensor_subclass(self, model) + self.assertTrue( + torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata) + ) class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase): def setUp(self): torch.manual_seed(123) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @unittest.skipIf(_DEVICE == "cpu", "Need GPU available") + @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers") @common_utils.parametrize("b", [2, 3, 4, 8]) - @common_utils.parametrize("model_dtype", [torch.float16, torch.float32]) + @common_utils.parametrize( + "model_dtype", [torch.float16, torch.float32, torch.bfloat16] + ) @common_utils.parametrize("group_size", [32, 128]) def test_int8_dynamic_activation_intx_e2e( self, @@ -290,7 +513,8 @@ def test_int8_dynamic_activation_intx_e2e( model_dtype: torch.dtype = torch.float32, group_size: int = 32, ): - model = M(embedding=False).to(_DEVICE, dtype=model_dtype) + config = MConfig(embedding=False, bias=True) + model = PreTrainedM(config).to(_DEVICE, dtype=model_dtype) x = model.example_inputs(device=_DEVICE).to(model_dtype) # reference model using native quantization @@ -310,31 +534,79 @@ def test_int8_dynamic_activation_intx_e2e( optimizer = QuantOptimizer( base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True ) + optimizer.zero_grad() optimizer.step() - # apply torchao quantized activations on top - activation_config = FakeQuantizeConfig( - torch.int8, - granularity="per_token", - mapping_type=config.act_mapping_type, - ) - filter_fn = optimizer.get_filter_fn(model) - quantize_( - model, - IntXQuantizationAwareTrainingConfig(activation_config=activation_config), - filter_fn=filter_fn, - ) + apply_activation_quantization(model, optimizer, model_dtype) + out = model(x) torch.testing.assert_close(out, ref_out, atol=0, rtol=0) - # equivalent to torchao's convert step - model.eval() - optimizer.restore_latent_params() - quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn) - quantize_(model, config, filter_fn=filter_fn) + attach_hf_config = False + if TRANSFORMERS_AVAIL: + attach_hf_config = _is_hf_model(model) + self.assertTrue(attach_hf_config) + + optimizer.torchao_convert(model) converted_out = model(x) - torch.testing.assert_close(converted_out, ref_out, atol=0, rtol=0) + torch.testing.assert_close(converted_out, ref_out) + check_torchao_tensor_subclass(self, model) + + if attach_hf_config: + reg_param_names = { + n for n, m in model.named_modules() if isinstance(m, nn.Embedding) + } + reg_param_names.add("_default") + module_fqn_to_config = ( + model.config.quantization_config.quant_type.module_fqn_to_config + ) + self.assertEqual(set(module_fqn_to_config.keys()), reg_param_names) + for torchao_config in module_fqn_to_config.values(): + self.assertTrue(isinstance(torchao_config, config.__class__)) + + +class TestTorchAoConfigIntegration(common_utils.TestCase): + @unittest.skipIf(torch.backends.mps.is_available(), "MPS not supported") + @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers") + def test_tied_weights_quantization(self, b: int = 4): + config = MConfig(m=128, n=128, tied_weights=True) + model = PreTrainedM(config).to(_DEVICE) + + quantizer = StretchedUnifTorchaoQuantizer(b) + linear_config = StretchedIntxWeightConfig( + b=b, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=PerAxis(0), + ) + embed_config = IntxWeightOnlyConfig( + weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(32) + ) + module_to_config = {"_default": linear_config} + configs = [embed_config] + filter_fns = [lambda m: isinstance(m, nn.Embedding)] + _attach_hf_quantization_config(model, filter_fns, configs, module_to_config) + + quantization_config = getattr(model.config, "quantization_config", None) + self.assertTrue(isinstance(quantization_config, TorchAoConfig)) + self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"]) + + # Let HF apply quantize_ given quantization_config + del model.config.quantization_config + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=False) + model = PreTrainedM.from_pretrained( + tmp_dir, quantization_config=quantization_config + ) + + check_torchao_tensor_subclass(self, model.linear1) + check_torchao_tensor_subclass(self, model.linear2, weight_only=True) + check_torchao_tensor_subclass(self, model.embed_tokens, weight_only=True) + + self.assertTrue( + model.linear2.weight.data_ptr() == model.embed_tokens.weight.data_ptr() + ) common_utils.instantiate_parametrized_tests(TestPARQuantization) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 264c70abb6..fa0edd694b 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -3,18 +3,13 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Requires torch>=2.4", allow_module_level=True) - import copy +import pytest import torch import torch.distributed as dist import torch.nn.functional as F +import torch.testing._internal.common_utils as common_utils from torch import nn from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -40,6 +35,9 @@ ) from torchao.quantization.quant_api import quantize_ +if common_utils.SEED is None: + common_utils.SEED = 1234 + _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -213,7 +211,9 @@ def test_int8_mixed_precision_training(self, compile, config, module_swap): def snr(ref, actual): error = actual - ref - return 20 * torch.log10(ref.norm() / error.norm()) + return 20 * torch.log10( + torch.linalg.vector_norm(ref) / torch.linalg.vector_norm(error) + ) assert snr(outputs_ref, outputs_int8mp) > 20 assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20 @@ -308,21 +308,19 @@ def test_fsdp2_correctness(self): (bitnet_training(), mp_policy, 1e-5), ] - # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129 - if TORCH_VERSION_AT_LEAST_2_6: - # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. - # We would need to cast all params to BF16 in forward and backward pass, while keeping - # the params in FP32 for optim step. - # torch.autocast() will only do this for F.linear() layer (and its backward). - # To keep it simple, we just use a larger tolerance here. - bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) - - extra_args = [ - (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), - (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), - (bitnet_training(), bf16_mp_policy, 1e-2), - ] - test_args.extend(extra_args) + # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model. + # We would need to cast all params to BF16 in forward and backward pass, while keeping + # the params in FP32 for optim step. + # torch.autocast() will only do this for F.linear() layer (and its backward). + # To keep it simple, we just use a larger tolerance here. + bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) + + extra_args = [ + (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), + (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), + (bitnet_training(), bf16_mp_policy, 1e-2), + ] + test_args.extend(extra_args) self.run_subtests({"args": test_args}, self._run_subtest) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index a5265f7b1f..581f75b925 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -3,30 +3,21 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import tempfile +import unittest from copy import deepcopy -import pytest import torch +from torch.testing._internal import common_utils from torchao.prototype.smoothquant import ( SmoothQuantConfig, SmoothQuantObservedLinear, - insert_smooth_quant_observer_, - load_smooth_quant_recipe, - save_smooth_quant_recipe, ) +from torchao.prototype.smoothquant.core import SmoothQuantStep from torchao.quantization import quantize_ -from torchao.quantization.utils import ( - dequantize_per_channel, - dynamically_quantize_per_channel, +from torchao.quantization.quant_api import ( + Int8DynamicActivationInt8WeightConfig, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, -) - -if torch.version.hip is not None: - pytest.skip("Skipping the test in ROCm", allow_module_level=True) class ToyLinearModel(torch.nn.Module): @@ -34,14 +25,22 @@ def __init__(self, m=512, n=256, k=128): super().__init__() self.linear1 = torch.nn.Linear(m, n, bias=False) self.linear2 = torch.nn.Linear(n, k, bias=False) - self.linear3 = torch.nn.Linear(k, 1, bias=False) + self.linear3 = torch.nn.Linear(k, 64, bias=False) def example_inputs( - self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda" + self, + batch_size, + sequence_length=10, + dtype=torch.bfloat16, + device="cuda", ): return [ torch.randn( - 1, sequence_length, self.linear1.in_features, dtype=dtype, device=device + 1, + sequence_length, + self.linear1.in_features, + dtype=dtype, + device=device, ) for j in range(batch_size) ] @@ -53,143 +52,163 @@ def forward(self, x): return x -bias_list = [True, False] -alpha_list = [None, 0.5, 0.75] -quant_mode_list = ["static", "dynamic"] -devices = ["cpu"] -if torch.cuda.is_available(): - devices.append("cuda") -idtypes = (torch.float, torch.bfloat16, torch.half) - -if TORCH_VERSION_AT_LEAST_2_5: - # This test case will trigger recompilation many times, so set a large cache_size_limit here - torch._dynamo.config.cache_size_limit = 128 - - -@pytest.mark.parametrize("bias", bias_list) -@pytest.mark.parametrize("alpha", alpha_list) -@pytest.mark.parametrize("quant_mode", quant_mode_list) -@pytest.mark.parametrize("device", devices) -@pytest.mark.parametrize("idtype", idtypes) -@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it") -def test_compute(bias, alpha, quant_mode, device, idtype): - class Linear(torch.nn.Module): - def __init__(self, bias: bool): - super().__init__() - self.fc = torch.nn.Linear(32, 32, bias) - self.fc.weight.data = torch.randn_like(self.fc.weight.data) - - def forward(self, x): - return self.fc(x) - - m = Linear(bias).eval().to(idtype).to(device) - m_ref = deepcopy(m) - data = torch.randn(2, 32, dtype=idtype, device=device) - - # calibrate - insert_smooth_quant_observer_(m, alpha, quant_mode) - m(data) - # quantize - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, SmoothQuantConfig(), is_observed_linear) - with torch.inference_mode(): - if TORCH_VERSION_AT_LEAST_2_5: - m = torch.compile(m, fullgraph=True) - out = m(data) - - # reference - weight = m_ref.fc.weight.data.float() - b = m_ref.fc.bias if bias else None - x_abs_max_per_ic = torch.abs(data).max(dim=0).values - w_abs_max_per_ic = torch.abs(weight).max(dim=0).values - smoothing_factor = ( - 1 - if alpha is None - else ( - torch.pow(x_abs_max_per_ic, alpha) - / torch.pow(w_abs_max_per_ic, 1 - alpha) - ) +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm") +class TestSmoothQuant(unittest.TestCase): + """SmoothQuant tests using only supported quantization configs.""" + + @classmethod + def setUpClass(cls): + """Set up class-level configuration for tests.""" + # This test case will trigger recompilation many times, so set a large cache_size_limit here + torch._dynamo.config.cache_size_limit = 128 + + @common_utils.parametrize("alpha", [0.5, 0.75]) + @common_utils.parametrize( + "base_config", + [ + Int8DynamicActivationInt8WeightConfig(), + # Note: float8_static_activation_float8_weight is broken after recent PyTorch update. + # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py + ], + ) + @common_utils.parametrize("device", ["cpu", "cuda"]) + @common_utils.parametrize("input_dtype", [torch.bfloat16]) + def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype): + """Test if SmoothQuant achieves lower loss than basic quantization.""" + in_features = 64 + out_features = 128 + + # Note: This is sanity check. For real run, consider Transformer model to reproduce. + X = torch.randn(16, in_features, dtype=input_dtype, device=device) + W = torch.randn(out_features, in_features, dtype=input_dtype, device=device) + + # Create linear layer + linear = ( + torch.nn.Linear(in_features, out_features, bias=False) + .to(device) + .to(input_dtype) + ) + with torch.no_grad(): + linear.weight.copy_(W) + + # Reference output + out_ref = linear(X) + + # Step 1. Basic quantization + basic_model = deepcopy(linear) + quantize_(basic_model, base_config) + out_basic = basic_model(X) + loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item() + + # SmoothQuant quantization + model = deepcopy(linear) + config = SmoothQuantConfig( + base_config=base_config, + step=SmoothQuantStep.PREPARE, + alpha=alpha, + ) + quantize_(model, config) + + # Perform calibration with test data + model(X) + + # Step 2. SmoothQuant + config.step = SmoothQuantStep.CONVERT + quantize_(model, config) + + out_smoothquant = model(X) + loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item() + + assert loss_smoothquant < loss_base, ( + f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})" + ) + + @common_utils.parametrize( + "base_config", + [ + Int8DynamicActivationInt8WeightConfig(), + # TODO: Check more quantization APIs + ], + ) + def test_observer_insertion(self, base_config): + """Test that PREPARE step correctly inserts SmoothQuantObservedLinear.""" + + m = ToyLinearModel().eval() + + # Before quantization - should be regular Linear + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) + + # PREPARE step - should insert observers + config = SmoothQuantConfig( + base_config=base_config, + step=SmoothQuantStep.PREPARE, ) - act = data / smoothing_factor - wei = weight * smoothing_factor - qw, w_scales, w_zps = dynamically_quantize_per_channel( - wei, -127, 127, torch.int8 + quantize_(m, config) + + # After PREPARE - should be SmoothQuantObservedLinear + self.assertIsInstance(m.linear1, SmoothQuantObservedLinear) + self.assertTrue(hasattr(m.linear1, "obs")) + + # Test calibration + test_data = torch.randn(2, 512) + m(test_data) + + # CONVERT step - should produce regular Linear with quantized weights + config.step = SmoothQuantStep.CONVERT + quantize_(m, config) + + # After CONVERT - should be regular Linear again (but quantized) + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) + + @common_utils.parametrize( + "base_config", + [ + Int8DynamicActivationInt8WeightConfig(), + # TODO: Check more quantization APIs + ], + ) + def test_prepare_for_loading(self, base_config): + """Test PREPARE_FOR_LOADING step for loading pre-quantized checkpoints.""" + + m = ToyLinearModel().eval() + + # Before quantization - should be regular Linear + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) + + # PREPARE_FOR_LOADING step - should create quantized model ready for loading + config = SmoothQuantConfig( + base_config=base_config, + step=SmoothQuantStep.PREPARE_FOR_LOADING, + alpha=0.5, ) - fq_wei = dequantize_per_channel(qw, w_scales, w_zps, idtype) - if quant_mode == "static": - # activation is quantized per-tensor - act_min, act_max = torch.aminmax(act.float()) - max_val_pos = torch.max(-act_min, act_max) - act_scale = max_val_pos / 127.0 - fq_act = ( - torch.quantize_per_tensor( - act.float(), scale=act_scale.item(), zero_point=0, dtype=torch.qint8 - ) - .dequantize() - .to(idtype) + quantize_(m, config) + + # After PREPARE_FOR_LOADING - should be regular Linear with quantized weights + self.assertIsInstance(m.linear1, torch.nn.Linear) + self.assertNotIsInstance(m.linear1, SmoothQuantObservedLinear) + + # Test that model can run inference + test_data = torch.randn(2, 512) + with torch.inference_mode(): + output = m(test_data) + + # Validate output + self.assertIsNotNone( + output, "PREPARE_FOR_LOADING model output should not be None" ) - out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) - else: - # activation is quantized per-row (batch * sequence_length) - qx, x_scales, x_zps = dynamically_quantize_per_channel( - act.float(), -127, 127, torch.int8 + self.assertFalse( + torch.isnan(output).any(), "Model should not produce NaN values" ) - fq_act = dequantize_per_channel(qx, x_scales, x_zps, idtype) - out_ref = torch.nn.functional.linear(fq_act, fq_wei, b) - - # BFloat16 and Float16 have larger errors - atol = 0.1 if idtype == torch.float else (0.2 if idtype == torch.half else 0.3) - assert torch.allclose(out, out_ref.to(idtype), atol=atol) - - -@pytest.mark.parametrize("alpha", alpha_list) -@pytest.mark.parametrize("quant_mode", quant_mode_list) -@pytest.mark.parametrize("device", devices) -@pytest.mark.parametrize("idtype", idtypes) -@pytest.mark.skip("this test is broken on recent PyTorch, TODO(#1639): fix it") -def test_save_load_recipe(alpha, quant_mode, device, idtype): - dataset_size = 20 - l1, l2, l3 = 512, 256, 128 - original_dtype = idtype - n_calib_examples = 10 - sequence_length = 5 - - m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device) - m_save_load = deepcopy(m) - - dataset = m.example_inputs( - dataset_size, - sequence_length=sequence_length, - dtype=original_dtype, - device=device, - ) - calibration_data = dataset[:n_calib_examples] - - # calibrate - insert_smooth_quant_observer_(m, alpha, quant_mode) - insert_smooth_quant_observer_(m_save_load, alpha, quant_mode) - - for example in calibration_data: - m(example.to(device)) - m_save_load(example.to(device)) - - with tempfile.NamedTemporaryFile() as fp: - save_path = fp.name - save_smooth_quant_recipe(m_save_load, save_path) - load_smooth_quant_recipe(m_save_load, save_path) - - # quantize - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, SmoothQuantConfig(), is_observed_linear) - if TORCH_VERSION_AT_LEAST_2_5: - # earlier versions are not compatible - m = torch.compile(m, fullgraph=True) - m_save_load = torch.compile(m_save_load, fullgraph=True) - out_list = [m(data.squeeze(0)) for data in dataset] - out = torch.cat(out_list) - save_load_out_list = [m_save_load(data.squeeze(0)) for data in dataset] - save_load_out = torch.cat(save_load_out_list) - - assert out is not None - assert save_load_out is not None - assert torch.allclose(out, save_load_out) + self.assertEqual( + output.shape, (2, 64), "Output shape should match expected dimensions" + ) + + +common_utils.instantiate_parametrized_tests(TestSmoothQuant) + +if __name__ == "__main__": + unittest.main() diff --git a/test/prototype/test_tensor_conversion.py b/test/prototype/test_tensor_conversion.py new file mode 100644 index 0000000000..1647a13693 --- /dev/null +++ b/test/prototype/test_tensor_conversion.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +import pytest +import torch + +from torchao.prototype.parq.quant import ( + StretchedIntxWeightConfig, + StretchedUnifTorchaoQuantizer, +) +from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor +from torchao.prototype.tensor_conversion.api import ( + _convert_model_for_aarch64, + convert_to_packed_tensor_based_on_current_hardware, +) +from torchao.quantization import ( + Int4PreshuffledTensor, + Int4Tensor, + MappingType, +) +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + quantize_, +) +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + IntxOpaqueTensor, + _is_kernel_library_loaded, +) +from torchao.quantization.utils import compute_error +from torchao.utils import _is_fbgemm_genai_gpu_available + + +class ToyLinearModelWithTiedEmbedding(torch.nn.Module): + def __init__(self, d0=512, d1=512, d2=256, d3=128, d4=32): + super().__init__() + self.embedding1 = torch.nn.Embedding(d0, d1) + self.embedding2 = torch.nn.Embedding(d0, d1) + self.embedding3 = torch.nn.Embedding(d0, d1) + + self.linear1 = torch.nn.Linear(d1, d2, bias=False) + self.linear2 = torch.nn.Linear(d2, d3, bias=True) + self.linear3 = torch.nn.Linear(d3, d4, bias=False) + self.linear4 = torch.nn.Linear(d4, d1, bias=False) + + self.lm_head1 = torch.nn.Linear(d1, d0, bias=False) + self.lm_head2 = torch.nn.Linear(d1, d0, bias=False) + self.lm_head3 = torch.nn.Linear(d1, d0, bias=False) + + # Tie weights + # lm_head1 / lm_head2 form one tied weight group + self.embedding2.weight = self.embedding1.weight + self.lm_head1.weight = self.embedding1.weight + self.lm_head2.weight = self.embedding1.weight + + # lm_head3 forms a separate tied weight group + self.lm_head3.weight = self.embedding3.weight + + def example_inputs( + self, + lead_dim=(1,), + dtype=torch.bfloat16, + ): + return ( + torch.randint( + 0, + self.embedding1.num_embeddings, + size=lead_dim, + dtype=torch.int64, + device="cpu", + ), + ) + + def forward(self, x): + x = self.embedding1(x) + self.embedding2(x) + self.embedding3(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + x = self.lm_head1(x) + self.lm_head2(x) + self.lm_head3(x) + return x + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + yield + torch._dynamo.reset() # reset cache between tests + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) +@pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) +@pytest.mark.parametrize( + "lead_dim", + [ + (1,), + (5,), + (7, 2), + ], +) +@pytest.mark.skipif( + not _is_kernel_library_loaded(), reason="Kernel library is not loaded" +) +def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim): + torch.manual_seed(0) + + model = ToyLinearModelWithTiedEmbedding() + model = model.to(dtype) + example_inputs = model.example_inputs(lead_dim, dtype) + + # Quantize linear 2 and 3 with PARQ + quantizer = StretchedUnifTorchaoQuantizer(bit_width) + config = StretchedIntxWeightConfig( + b=bit_width, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=granularity, + activation_quantization="int8_asym_per_token", + ) + quantize_(model, config, filter_fn=lambda m, fqn: fqn in ["linear2", "linear3"]) + + # Quantize linear 1 and 4 with int8 dynamic activation + config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=granularity, + weight_mapping_type=MappingType.SYMMETRIC, + ) + quantize_( + model, + config, + filter_fn=lambda m, fqn: fqn + in ["linear1", "linear4", "lm_head1", "lm_head2", "lm_head3"], + ) + + # Quantize embedding 1, 2, and 3 with weight only + config = IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=granularity, + mapping_type=MappingType.SYMMETRIC, + ) + quantize_( + model, + config, + filter_fn=lambda m, fqn: fqn in ["embedding1", "embedding2", "embedding3"], + ) + model_out = model(*example_inputs) + + # Convert to optimized model + _convert_model_for_aarch64(model) + + # Check expected tensor subclass + assert isinstance(model.linear2.weight, Int8LutTensor) + assert isinstance(model.linear3.weight, Int8LutTensor) + assert isinstance(model.linear1.weight, IntxOpaqueTensor) + assert isinstance(model.linear4.weight, IntxOpaqueTensor) + + # Assert tied params + tied_group1_id = id(model.embedding1.weight) + assert id(model.embedding2.weight) == tied_group1_id + assert id(model.lm_head1.weight) == tied_group1_id + assert id(model.lm_head2.weight) == tied_group1_id + + assert id(model.lm_head3.weight) == id(model.embedding3.weight) + assert id(model.lm_head3.weight) != tied_group1_id + + # Compare converted out with original out + converted_out = model(*example_inputs) + sqnr = compute_error(model_out, converted_out) + sqnr_threshold = 30 + assert sqnr > sqnr_threshold, f"sqnr: {sqnr}" + + # Check exported graph for correct ops + ep = torch.export.export(model, example_inputs) + expected_counts = { + "torch.ops.torchao._shared_embedding_": 3, + "torch.ops.torchao._linear_8bit_act_": 7, + "torch.ops.aten.linear.default": 0, + "torch.ops.aten.embedding.default": 0, + } + for line, cnt in expected_counts.items(): + assert ep.graph_module.code.count(line) == cnt, ( + f"expected {cnt} {line} in {ep.graph_module.code}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA") +@pytest.mark.skipif( + not _is_fbgemm_genai_gpu_available(), reason="Requires fbgemm-gpu-genai >= 1.2.0" +) +def test_int4_tensor_conversion(): + m = torch.nn.Sequential( + torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda") + ) + quantize_(m, Int4WeightOnlyConfig(group_size=128)) + weight = m[0].weight + assert isinstance(weight, Int4Tensor) + example_inputs = (torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"),) + before_conversion = m(*example_inputs) + m[0].weight = torch.nn.Parameter( + convert_to_packed_tensor_based_on_current_hardware(weight), requires_grad=False + ) + after_conversion = m(*example_inputs) + assert isinstance(m[0].weight, Int4PreshuffledTensor) + assert torch.equal(before_conversion, after_conversion) diff --git a/test/quantization/pt2e/test_arm_inductor_quantizer.py b/test/quantization/pt2e/test_arm_inductor_quantizer.py index 750e88d451..f74b6620db 100644 --- a/test/quantization/pt2e/test_arm_inductor_quantizer.py +++ b/test/quantization/pt2e/test_arm_inductor_quantizer.py @@ -6,12 +6,22 @@ # Owner(s): ["oncall: quantization"] import copy +import functools import itertools +import platform import unittest from enum import Enum import torch import torch.nn as nn +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + skipIfNoInductorSupport, +) +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as armiq from torchao.quantization.pt2e import ObserverBase @@ -26,22 +36,7 @@ from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( QUANT_ANNOTATION_KEY, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - -import functools -import platform - -from torch.testing._internal.common_quantization import ( - NodeSpec as ns, -) -from torch.testing._internal.common_quantization import ( - QuantizationTestCase, - skipIfNoInductorSupport, -) -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo +from torchao.utils import torch_version_at_least def skipIfNoArm(fn): @@ -319,10 +314,7 @@ def _test_quantizer( # program capture m = copy.deepcopy(m_eager) - m = export_for_training( - m, - example_inputs, - ).module() + m = torch.export.export(m, example_inputs).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -356,7 +348,7 @@ def _test_quantizer( @skipIfNoInductorSupport -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2EArmInductor(ArmInductorQuantTestCase): @skipIfNoArm def test_conv2d(self): @@ -580,7 +572,7 @@ def _test_linear_unary_helper( Test pattern of linear with unary post ops (e.g. relu) with ArmInductorQuantizer. """ use_bias_list = [True, False] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of export inplace_list = [False] if post_op_algo_list is None: post_op_algo_list = [None] @@ -720,7 +712,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op is supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of export inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = ArmInductorQuantizer().set_global( @@ -1082,7 +1074,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = torch.export.export(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index a1b43b4f3a..90050c4c9f 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -33,10 +33,7 @@ OP_TO_ANNOTATOR, QuantizationConfig, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import torch_version_at_least class TestHelperModules: @@ -100,7 +97,7 @@ def forward(self, x): @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestDuplicateDQPass(QuantizationTestCase): def _test_duplicate_dq( self, @@ -112,7 +109,7 @@ def _test_duplicate_dq( # program capture m = copy.deepcopy(m_eager) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index c9fa3960ee..eee33e3b13 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -20,7 +20,7 @@ get_symmetric_quantization_config, ) from torchao.testing.pt2e._xnnpack_quantizer_utils import OP_TO_ANNOTATOR -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import torch_version_at_least class TestHelperModules: @@ -64,7 +64,7 @@ def _tag_partitions( # TODO: rename to TestPortMetadataPass to align with the util name? @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestMetaDataPorting(QuantizationTestCase): def _test_quant_tag_preservation_through_decomp( self, model, example_inputs, from_node_to_tags @@ -107,7 +107,7 @@ def _test_metadata_porting( # program capture m = copy.deepcopy(m_eager) - m = torch.export.export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 80648f6c77..75e9688806 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -18,41 +18,45 @@ prepare_for_propagation_comparison, ) from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase -from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 +from torchao.utils import torch_version_at_least -if TORCH_VERSION_AT_LEAST_2_8: - from torch.export import export_for_training +# Increase cache size limit to avoid FailOnRecompileLimitHit error when running multiple tests +# that use torch.export.export, which causes many dynamo recompilations +if torch_version_at_least("2.8.0"): + torch._dynamo.config.cache_size_limit = 128 @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8 and above, including nightly" + not torch_version_at_least("2.8.0"), + "Requires torch 2.8 and above, including nightly", ) @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase): - @unittest.skip( - "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..." - ) def test_simple(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map = self._extract_from_node_source(m) - self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map)) + self.assertEqual( + len(set(from_node_source_map.values())), len(from_node_source_map) + ) @unittest.skip("debug flow not working on model with conditional control flow") def test_control_flow(self): m = TestHelperModules.ControlFlow() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map = self._extract_from_node_source(m) - self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map)) + self.assertEqual( + len(set(from_node_source_map.values())), len(from_node_source_map) + ) def test_copy_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() @@ -60,74 +64,68 @@ def test_copy_preserve_handle(self): ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map_ref = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map_ref = self._extract_from_node_source(m) ep_copy = copy.copy(ep) - debug_handle_map = self._extract_debug_handles(ep_copy.module()) + from_node_source_map = self._extract_from_node_source(ep_copy.module()) - self._assert_each_node_has_debug_handle(ep) - self.assertEqual(debug_handle_map, debug_handle_map_ref) + self._assert_each_node_has_from_node_source(ep) + self.assertEqual(from_node_source_map, from_node_source_map_ref) def test_deepcopy_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() ep = torch.export.export(m, example_inputs, strict=True) - debug_handle_map_ref = self._extract_debug_handles(ep.module()) + from_node_source_map_ref = self._extract_from_node_source(ep.module()) ep_copy = copy.deepcopy(ep) - debug_handle_map = self._extract_debug_handles(ep_copy.module()) + from_node_source_map = self._extract_from_node_source(ep_copy.module()) - self._assert_each_node_has_debug_handle(ep.module()) - self.assertEqual(debug_handle_map, debug_handle_map_ref) + self._assert_each_node_has_from_node_source(ep.module()) + self.assertEqual(from_node_source_map, from_node_source_map_ref) + self.assertEqual( + set(from_node_source_map.values()), set(from_node_source_map_ref.values()) + ) - @unittest.skip( - "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..." - ) def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map_ref = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map_ref = self._extract_from_node_source(m) - ep_reexport = export_for_training(m, example_inputs, strict=True) + ep_reexport = torch.export.export(m, example_inputs, strict=True) m_reexport = ep_reexport.module() - self._assert_each_node_has_debug_handle(m_reexport) - debug_handle_map = self._extract_debug_handles(m_reexport) + self._assert_each_node_has_from_node_source(m_reexport) + from_node_source_map = self._extract_from_node_source(m_reexport) - self.assertEqual(debug_handle_map, debug_handle_map_ref) + self.assertEqual(from_node_source_map, from_node_source_map_ref) - @unittest.skip( - "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..." - ) def test_run_decompositions_same_handle_id(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map_ref = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map_ref = self._extract_from_node_source(m) ep_copy = copy.copy(ep) ep_copy = ep_copy.run_decompositions() m_decomposed = ep_copy.module() - self._assert_each_node_has_debug_handle(m_decomposed) - debug_handle_map = self._extract_debug_handles(m_decomposed) + self._assert_each_node_has_from_node_source(m_decomposed) + from_node_source_map = self._extract_from_node_source(m_decomposed) # checking the map still has the same ids, the node may change self.assertEqual( - set(debug_handle_map.values()), set(debug_handle_map_ref.values()) + set(from_node_source_map.values()), set(from_node_source_map_ref.values()) ) - @unittest.skip( - "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..." - ) def test_run_decompositions_map_handle_to_new_nodes(self): test_models = [ TestHelperModules.TwoLinearModule(), @@ -136,31 +134,32 @@ def test_run_decompositions_map_handle_to_new_nodes(self): for m in test_models: example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - pre_decomp_to_debug_handle_map_ref = ( - self._extract_debug_handles_with_prev_decomp_op(m) + self._assert_each_node_has_from_node_source(m) + pre_decomp_to_from_node_source_map_ref = ( + self._extract_from_node_source_with_prev_decomp_op(m) ) ep_copy = copy.copy(ep) ep_copy = ep_copy.run_decompositions() m_decomposed = ep_copy.module() - self._assert_each_node_has_debug_handle(m_decomposed) - pre_decomp_to_debug_handle_map = ( - self._extract_debug_handles_with_prev_decomp_op(m_decomposed) + self._assert_each_node_has_from_node_source(m_decomposed) + pre_decomp_to_from_node_source_map = ( + self._extract_from_node_source_with_prev_decomp_op(m_decomposed) ) - # checking the map still has the same ids, the node may change + # checking the map still has the same infos, the node may change self.assertEqual( - pre_decomp_to_debug_handle_map, pre_decomp_to_debug_handle_map_ref + pre_decomp_to_from_node_source_map, + pre_decomp_to_from_node_source_map_ref, ) def test_prepare_for_propagation_comparison(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() m_logger = prepare_for_propagation_comparison(m) ref = m(*example_inputs) @@ -176,20 +175,20 @@ def test_prepare_for_propagation_comparison(self): def test_added_node_gets_unique_id(self) -> None: m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = torch.export.export(m, example_inputs, strict=True) - ref_handles = self._extract_debug_handles(ep.module()) - ref_counter = Counter(ref_handles.values()) + ref_from_node_source = self._extract_from_node_source(ep.module()) + ref_counter = Counter(ref_from_node_source.values()) for k, v in ref_counter.items(): self.assertEqual( v, 1, - msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1", + msg=f"For from_node info {k}, there were {v} nodes with that info, but expected only 1", ) - # Now that we have unique ids, add a new node into the graph and re-generate - # to make sure that the new node gets a unique id. + # Now that we have unique infos, add a new node into the graph and re-generate + # to make sure that the new node gets a unique info. last_node = next(iter(reversed(ep.graph.nodes))) with ep.graph.inserting_before(last_node): arg = last_node.args[0] @@ -200,30 +199,39 @@ def test_added_node_gets_unique_id(self) -> None: arg.replace_all_uses_with(n, lambda x: x != n) ep.graph_module.recompile() - # Regenerate handles, make sure only the new relu node has a new id, and - # it doesn't clash with any of the existing ids. + # Regenerate from_node info, make sure only the new relu node has a new info, and + # it doesn't clash with any of the existing infos. m = ep.module() - self._assert_each_node_has_debug_handle(m) - handles_after_modification = self._extract_debug_handles(m) - handles_counter = Counter(handles_after_modification.values()) - for name, handle in ref_handles.items(): - self.assertIn(name, handles_after_modification) - # Check that handle was unchanged. - self.assertEqual(handles_after_modification[name], handle) + self._assert_each_node_has_from_node_source(m) + from_node_source_after_modification = self._extract_from_node_source(m) + from_node_source_counter = Counter(from_node_source_after_modification.values()) + for name, from_node_source in ref_from_node_source.items(): + self.assertIn(name, from_node_source_after_modification) + # Check that from_node info was unchanged. + self.assertEqual( + from_node_source_after_modification[name], from_node_source + ) # Check that total count was unchanged. - ref_count = ref_counter[handle] - after_count = handles_counter[handle] + ref_count = ref_counter[from_node_source] + after_count = from_node_source_counter[from_node_source] self.assertEqual( after_count, ref_count, - msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}", + msg=f"For from_node info {from_node_source}, there were {after_count} nodes with that info, but expected only {ref_count}", ) - # Check for relu specifically. Avoid hardcoding the handle id since it + # Check for relu specifically. Avoid hardcoding the from_node info since it # may change with future node ordering changes. - self.assertNotIn(handles_after_modification["relu_default"], ref_counter) - self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1) + self.assertNotIn( + from_node_source_after_modification["relu_default"], ref_counter + ) + self.assertEqual( + from_node_source_counter[ + from_node_source_after_modification["relu_default"] + ], + 1, + ) if __name__ == "__main__": diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 19f208a55c..482e97e3ce 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -57,6 +57,7 @@ from torchao.quantization.pt2e.quantizer.embedding_quantizer import ( # noqa: F811 EmbeddingQuantizer, ) +from torchao.testing.model_architectures import ConvWithSharedWeightInExportedModel from torchao.testing.pt2e._xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, @@ -66,15 +67,11 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - +from torchao.utils import torch_version_at_least DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) -if TORCH_VERSION_AT_LEAST_2_7: +if torch_version_at_least("2.7.0"): from torch.testing._internal.common_utils import ( TEST_HPU, ) @@ -83,7 +80,7 @@ @skipIfNoQNNPACK -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2E(PT2EQuantizationTestCase): def test_simple_quantizer(self): # TODO: use OP_TO_ANNOTATOR @@ -154,6 +151,34 @@ def validate(self, model: torch.fx.GraphModule) -> None: node_list, ) + def test_chunked_bn_fusion(self): + batch_size = 1 + n_chunks = 3 + in_channels = 1 + out_channels = 32 + m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels) + m.bn.running_var = torch.nn.Parameter( + torch.rand(out_channels) * 1e-2, requires_grad=False + ) + + m.eval() + example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),) + ref_outputs = m(*example_inputs) + traced_model = torch.export.export(m, example_inputs, strict=True).module() + traced_outputs = traced_model(*example_inputs) + prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer()) + prepared_outputs = prepared_model(*example_inputs) + + if isinstance(ref_outputs, (tuple, list)): + for ref, prepared, traced in zip( + ref_outputs, prepared_outputs, traced_outputs + ): + torch.testing.assert_close(ref, traced) + torch.testing.assert_close(traced, prepared) + else: + torch.testing.assert_close(ref_outputs, traced_outputs) + torch.testing.assert_close(traced_outputs, prepared_outputs) + def test_wo_annotate_conv_output_quantizer(self): # TODO: use OP_TO_ANNOTATOR class BackendAQuantizer(Quantizer): @@ -793,7 +818,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, BackendAQuantizer()) # make sure the two observers for input are shared conv_output_obs = [] @@ -853,7 +878,7 @@ def _test_transitive_sharing_with_cat_helper(self, quantizer): ) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) # make sure the two input observers and output are shared @@ -1172,7 +1197,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: ) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = BackendAQuantizer() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -1193,7 +1218,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: @parametrize("dtype", (torch.float32, torch.bfloat16)) @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn)) def test_quantization_dtype(self, dtype, quant_dtype): - if TORCH_VERSION_AT_LEAST_2_7 and TEST_HPU: + if torch_version_at_least("2.7.0") and TEST_HPU: unittest.SkipTest("test doesn't currently work with HPU") class DtypeActQuantizer(Quantizer): @@ -1324,7 +1349,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: m = M().eval() example_inputs = torch.randn(1, 2, 3, 3) - m = export_for_training(m, (example_inputs,), strict=True).module() + m = torch.export.export(m, (example_inputs,), strict=True).module() with self.assertRaises(Exception): m = prepare_pt2e(m, BackendAQuantizer()) @@ -1332,7 +1357,7 @@ def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): # resetting dynamo cache torch._dynamo.reset() - m = export_for_training( + m = torch.export.export( m, example_inputs, ).module() @@ -1481,7 +1506,7 @@ def forward(self, x): quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() weight_meta = None for n in m.graph.nodes: if ( @@ -1569,7 +1594,7 @@ def forward(self, x): m = M().eval() quantizer = TestQuantizer() example_inputs = (torch.randn(1, 2, 3, 3),) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1620,7 +1645,7 @@ def forward(self, x, y, z): torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3), ) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1875,7 +1900,7 @@ def forward(self, x): example_inputs = (torch.randn(1),) m = M().train() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() if inplace: target = torch.ops.aten.dropout_.default else: @@ -1937,7 +1962,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() # Assert that batch norm op exists and is in train mode bn_node = self._get_node(m, bn_train_op) @@ -1968,7 +1993,7 @@ def test_disallow_eval_train(self): m.train() # After export: this is not OK - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -1990,7 +2015,7 @@ def test_disallow_eval_train(self): m.train() def test_allow_exported_model_train_eval(self): - if TORCH_VERSION_AT_LEAST_2_7 and TEST_HPU: + if torch_version_at_least("2.7.0") and TEST_HPU: unittest.SkipTest("test doesn't currently work with HPU") class M(torch.nn.Module): @@ -2011,7 +2036,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): targets = [n.target for n in m.graph.nodes] @@ -2077,7 +2102,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() torchao.quantization.pt2e.allow_exported_model_train_eval(m) # Mock m.recompile() to count how many times it's been called @@ -2109,7 +2134,7 @@ def _fake_recompile(): def test_model_is_exported(self): m = TestHelperModules.ConvWithBNRelu(relu=True) example_inputs = (torch.rand(3, 3, 5, 5),) - exported_gm = export_for_training(m, example_inputs, strict=True).module() + exported_gm = torch.export.export(m, example_inputs, strict=True).module() fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs) self.assertTrue( torchao.quantization.pt2e.export_utils.model_is_exported(exported_gm) @@ -2127,7 +2152,7 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=True, is_qat=True) ) - m.conv_bn_relu = export_for_training( + m.conv_bn_relu = torch.export.export( m.conv_bn_relu, example_inputs, strict=True ).module() m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) @@ -2137,7 +2162,7 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_module_type( torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) ) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) @@ -2300,7 +2325,7 @@ def test_speed(self): def dynamic_quantize_pt2e(model, example_inputs): torch._dynamo.reset() - model = export_for_training(model, example_inputs, strict=True).module() + model = torch.export.export(model, example_inputs, strict=True).module() # Per channel quantization for weight # Dynamic quantization for activation # Please read a detail: https://fburl.com/code/30zds51q @@ -2707,7 +2732,7 @@ def forward(self, x): example_inputs = (torch.randn(1, 3, 5, 5),) m = M() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(), ) @@ -2789,7 +2814,7 @@ def prepare_obs_or_fq_callback( edge_or_node_to_obs_or_fq[x] = new_observer example_inputs = (torch.rand(1, 32, 16, 16),) - gm = export_for_training(Model().eval(), example_inputs, strict=True).module() + gm = torch.export.export(Model().eval(), example_inputs, strict=True).module() gm = prepare_pt2e(gm, BackendAQuantizer()) gm = convert_pt2e(gm) for n in gm.graph.nodes: @@ -2816,7 +2841,7 @@ def check_nn_module(node): "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1] ) - m.conv_bn_relu = export_for_training( + m.conv_bn_relu = torch.export.export( m.conv_bn_relu, example_inputs, strict=True ).module() for node in m.conv_bn_relu.graph.nodes: @@ -2901,7 +2926,7 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool: quantizer = TestQuantizer() example_inputs = (torch.randn(1, 2, 3, 3),) quantizer.set_example_inputs(example_inputs) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() # Check that the model has in-place ops self.assertTrue(has_inplace_ops(m)) m = prepare_pt2e(m, quantizer) @@ -2920,13 +2945,14 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool: @skipIfNoQNNPACK -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase): def test_channel_group_quantization(self): + from torchao.quantization import PerGroup, PerToken from torchao.quantization.pt2e._affine_quantization import ( AffineQuantizedMinMaxObserver, ) - from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken + from torchao.quantization.pt2e.observer import MappingType class BackendAQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: @@ -3006,13 +3032,13 @@ def forward(self, x): def test_dynamic_affine_act_per_channel_weights(self): import operator + from torchao.quantization import PerToken from torchao.quantization.pt2e._affine_quantization import ( AffineQuantizedMovingAverageMinMaxObserver, ) from torchao.quantization.pt2e.observer import ( MappingType, PerChannelMinMaxObserver, - PerToken, ) class BackendAQuantizer(Quantizer): @@ -3097,12 +3123,14 @@ def forward(self, x): def test_dynamic_per_tok_act_per_group_weights(self): import operator + from torchao.quantization import PerGroup, PerToken + # TODO: merge into torchao observer from torchao.quantization.pt2e._affine_quantization import ( AffineQuantizedMinMaxObserver, AffineQuantizedPlaceholderObserver, ) - from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken + from torchao.quantization.pt2e.observer import MappingType class BackendAQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index d8a2c8df03..fb1b17ce9f 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -51,10 +51,7 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import torch_version_at_least class PT2EQATTestCase(QuantizationTestCase): @@ -151,7 +148,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper( is_per_channel=is_per_channel, is_qat=True ) ) - model_pt2e = export_for_training( + model_pt2e = torch.export.export( model_pt2e, example_inputs, strict=True ).module() model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer) @@ -250,7 +247,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper( quantizer.set_global( get_symmetric_quantization_config(is_per_channel, is_qat=True) ) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -426,7 +423,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper( ) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase): """ Base TestCase to be used for all conv-bn[-relu] fusion patterns. @@ -640,7 +637,7 @@ def forward(self, x): m = M(self.conv_class, self.bn_class, backbone) quantizer = XNNPACKQuantizer() quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) @@ -698,7 +695,7 @@ def get_source_fn(node: torch.fx.Node): def test_qat_conv_bn_bias_derived_qspec(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -745,7 +742,7 @@ def test_qat_conv_bn_bias_derived_qspec(self): def test_qat_per_channel_weight_custom_dtype(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = ConvBnInt32WeightQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -799,7 +796,7 @@ def test_qat_conv_transpose_bn_relu(self): def test_qat_conv_bn_per_channel_weight_bias(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True) m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -856,7 +853,7 @@ def test_fold_bn_erases_bn_node(self): it into conv in `convert_pt2e` even in train mode. """ m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) - m = export_for_training(m, self.example_inputs, strict=True).module() + m = torch.export.export(m, self.example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config(is_per_channel=False, is_qat=True), @@ -869,7 +866,7 @@ def test_fold_bn_erases_bn_node(self): @skipIfNoQNNPACK -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): dim = 1 example_inputs = (torch.randn(1, 3, 5),) @@ -879,7 +876,7 @@ class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base): @skipIfNoQNNPACK -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base): dim = 2 example_inputs = (torch.randn(1, 3, 5, 5),) @@ -1048,7 +1045,7 @@ def validate(self, model: torch.fx.GraphModule): @skipIfNoQNNPACK -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2EQATModels(PT2EQATTestCase): @skip_if_no_torchvision @skipIfNoQNNPACK @@ -1071,7 +1068,7 @@ def test_qat_mobilenet_v2(self): self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizeMixQATAndPTQ(QuantizationTestCase): class TwoLinear(torch.nn.Module): def __init__(self) -> None: @@ -1108,7 +1105,7 @@ def _prepare_qat_linears(self, model): in_channels = child.linear1.weight.size(1) example_input = (torch.rand((1, in_channels)),) - traced_child = export_for_training( + traced_child = torch.export.export( child, example_input, strict=True ).module() quantizer = XNNPACKQuantizer() @@ -1141,7 +1138,7 @@ def test_mixing_qat_ptq(self): self._convert_qat_linears(model) model(*example_inputs) - model_pt2e = export_for_training(model, example_inputs, strict=True).module() + model_pt2e = torch.export.export(model, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_module_type(torch.nn.Linear, None) diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index 2123995a4b..cd431c4ccb 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -27,14 +27,11 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import torch_version_at_least @skipIfNoQNNPACK -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestPT2ERepresentation(QuantizationTestCase): def _test_representation( self, @@ -48,7 +45,7 @@ def _test_representation( ) -> torch.nn.Module: # resetting dynamo cache torch._dynamo.reset() - model = export_for_training(model, example_inputs, strict=True).module() + model = torch.export.export(model, example_inputs, strict=True).module() model_copy = copy.deepcopy(model) model = prepare_pt2e(model, quantizer) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index fa981dc4d6..cfe8d790e9 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -16,7 +16,6 @@ from torch._inductor import config from torch._inductor.test_case import TestCase, run_tests from torch._inductor.utils import run_and_get_code -from torch.export import export_for_training from torch.testing._internal.common_quantization import ( skipIfNoDynamoSupport, skipIfNoONEDNN, @@ -26,6 +25,7 @@ IS_FBCODE, IS_LINUX, IS_X86, + TEST_ACL, instantiate_parametrized_tests, parametrize, ) @@ -45,15 +45,7 @@ X86InductorQuantizer, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_8, -) - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.testing._internal.common_utils import TEST_ACL -else: - TEST_ACL = False +from torchao.utils import torch_version_at_least # The dict value is match_nodes(computation_op+unary_op) unary_list = { @@ -98,6 +90,10 @@ lambda x, y: x.add_(y), ] +skipIfNoFloat8Support = unittest.skipIf( + not torch_version_at_least("2.9.0"), "Float8 requires torch 2.9+" +) + def get_default_quantizer(is_qat, is_dynamic): quantizer = X86InductorQuantizer() @@ -109,24 +105,127 @@ def get_default_quantizer(is_qat, is_dynamic): return quantizer +class FP8QDQLinear(torch.nn.Module): + def __init__(self, in_features, out_features, has_bias): + super().__init__() + self.qtype = torch.float8_e4m3fn + self.weight = torch.randn((out_features, in_features)).to(self.qtype) + self.weight_scale = 2.0 + self.scale = 2.0 + self.bias = None + if has_bias: + self.bias = torch.randn((out_features,)) + + def forward(self, input): + weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=self.weight.data, + scale=torch.tensor([self.weight_scale]), + output_dtype=torch.float, + ) + + q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( + tensor=input, + scale=torch.tensor([self.scale]), + float8_dtype=self.qtype, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=q_input, + scale=torch.tensor([self.scale]), + output_dtype=torch.float, + ) + + out = torch.nn.functional.linear(dq_input, weight, self.bias) + return out + + +def qdq(input, scale): + dtype = input.dtype + q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( + input, + torch.tensor([scale]), + torch.float8_e4m3fn, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + q_input, + torch.tensor([scale]), + dtype, + ) + return dq_input + + +def fp8_convert_(model): + def generate_model_info(model): + from collections import namedtuple + + mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"]) + parent_child_mod_dict = {} + + def create_mod_info_recursion(parent): + for name, mod in parent.named_children(): + parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent) + create_mod_info_recursion(mod) + + create_mod_info_recursion(model) + return parent_child_mod_dict + + parent_child_mod_dict = generate_model_info(model) + for name, mod in model.named_modules(): + mod_type_str = mod.__class__.__name__ + if mod_type_str not in [ + "Linear", + ]: + continue + param = mod.weight + xmax = torch.max(param) + weight_scale = xmax / torch.finfo(torch.float8_e4m3fn).max + mod.weight_scale = weight_scale + q_param = torch.clamp( + (param / weight_scale), + torch.finfo(torch.float8_e4m3fn).min, + torch.finfo(torch.float8_e4m3fn).max, + ).to(torch.float8_e4m3fn) + mod.weight.data = q_param + if mod_type_str in ["Linear"]: + patched_mod = FP8QDQLinear(mod.in_features, mod.out_features, False) + patched_mod.bias = mod.bias + patched_mod.weight_scale = weight_scale.item() + patched_mod.weight.data = q_param + + parent = parent_child_mod_dict[mod].parent + name = parent_child_mod_dict[mod].name + setattr(parent, name, patched_mod) + + def _generate_qdq_quantized_model( - mod, inputs, is_qat=False, is_dynamic=False, quantizer=None + mod, + inputs, + is_qat=False, + is_dynamic=False, + quantizer=None, + is_fp8=False, ): maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: - export_model = export_for_training(mod, inputs, strict=True).module() - quantizer = ( - quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) - ) - prepare_model = ( - prepare_qat_pt2e(export_model, quantizer) - if is_qat - else prepare_pt2e(export_model, quantizer) - ) - prepare_model(*inputs) - torchao.quantization.pt2e.move_exported_model_to_eval(prepare_model) - convert_model = convert_pt2e(prepare_model) - return convert_model + if is_fp8: + # fp8_convert_ not support dynamic and qat yet + assert not is_dynamic + assert not is_qat + fp8_convert_(mod) + return mod + else: + export_model = torch.export.export(mod, inputs, strict=True).module() + quantizer = ( + quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) + ) + prepare_model = ( + prepare_qat_pt2e(export_model, quantizer) + if is_qat + else prepare_pt2e(export_model, quantizer) + ) + prepare_model(*inputs) + torchao.quantization.pt2e.move_exported_model_to_eval(prepare_model) + convert_model = convert_pt2e(prepare_model) + return convert_model def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): @@ -195,6 +294,7 @@ def _test_common( is_dynamic=False, quantizer=None, compile_options={}, # noqa: B006 + is_fp8=False, ): if not hasattr(self, "device"): has_xpu = any( @@ -225,7 +325,7 @@ def _test_common( maybe_autocast = contextlib.nullcontext() if check_quantization: convert_model = _generate_qdq_quantized_model( - mod, inputs, is_qat, is_dynamic, quantizer + mod, inputs, is_qat, is_dynamic, quantizer, is_fp8 ) with torch.no_grad(), maybe_autocast: _ = torch.compile(convert_model)(*inputs) @@ -250,11 +350,14 @@ def _test_code_common( check_dynamic=None, num_include_ops=None, quantizer=None, + is_fp8=False, ): with torch.no_grad(): clone_inputs = self._clone_inputs(inputs) if check_quantization: - mod = _generate_qdq_quantized_model(mod, inputs, quantizer=quantizer) + mod = _generate_qdq_quantized_model( + mod, inputs, quantizer=quantizer, is_fp8=is_fp8 + ) expected = mod(*inputs) actual, (source_code,) = run_and_get_code( torch.compile(mod, fullgraph=True, dynamic=check_dynamic), @@ -277,7 +380,7 @@ def _test_code_common( torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8+") +@unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+") class TestPatternMatcher(TestPatternMatcherBase): def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): class M(torch.nn.Module): @@ -1342,12 +1445,13 @@ def _qlinear_test_helper( self, inputs, device="cpu", - int8_mixed_bf16=False, + mixed_bf16=False, do_permute=False, matcher_check_fn=None, bias=True, is_dynamic=False, is_qat=False, + is_fp8=False, ): class M(torch.nn.Module): def __init__(self, use_bias, do_permute=False): @@ -1382,10 +1486,11 @@ def _default_matcher_check_fn(): if matcher_check_fn is not None else _default_matcher_check_fn ), - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, check_quantization=True, is_qat=is_qat, is_dynamic=is_dynamic, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -1397,6 +1502,16 @@ def test_qlinear_cpu(self): for bias in [True, False]: self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_cpu(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias, is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_dynamic_qlinear_cpu(self): @@ -1433,13 +1548,26 @@ def test_dynamic_qlinear_input_dim_exceeds_2(self): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_int8_mixed_bf16(self): + def test_qlinear_mixed_bf16(self): + r""" + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. + """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True, bias=bias + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_mixed_bf16(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. """ for bias in [True, False]: self._qlinear_test_helper( - (torch.randn((2, 4)),), int8_mixed_bf16=True, bias=bias + (torch.randn((2, 4)),), mixed_bf16=True, bias=bias, is_fp8=True ) @skipIfNoDynamoSupport @@ -1451,16 +1579,39 @@ def test_qlinear_input_dim_exceeds_2(self): for bias in [True, False]: self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_input_dim_exceeds_2(self): + r""" + This testcase will quantize a single Linear Moduel. + """ + for bias in [True, False]: + self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias, is_fp8=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_mixed_bf16_input_dim_exceeds_2(self): + r""" + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. + """ + for bias in [True, False]: + self._qlinear_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, bias=bias + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): + @skipIfNoFloat8Support + def test_fp8_qlinear_mixed_bf16_input_dim_exceeds_2(self): r""" - This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. + This testcase will quantize a single Linear Moduel with mixed_bf16 quantization. """ for bias in [True, False]: self._qlinear_test_helper( - (torch.randn((2, 3, 4)),), int8_mixed_bf16=True, bias=bias + (torch.randn((2, 3, 4)),), mixed_bf16=True, bias=bias, is_fp8=True ) @skipIfNoDynamoSupport @@ -1489,10 +1640,38 @@ def matcher_check_fn(): bias=bias, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_input_dim_exceeds_2_and_not_contiguous(self): + r""" + This testcase will quantize a single Linear Module. + * Input dim exceeds 2 + * Input not contiguous + """ + for bias in [True, False]: + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 13 if bias else 12, + ) + + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + is_fp8=True, + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): + def test_qlinear_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): r""" This testcase will quantize a single Linear Module for int8_bf16. * Input dim exceeds 2 @@ -1511,14 +1690,49 @@ def matcher_check_fn(): self._qlinear_test_helper( (torch.randn((2, 4, 3, 4)),), - int8_mixed_bf16=True, + mixed_bf16=True, do_permute=True, matcher_check_fn=matcher_check_fn, bias=bias, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self): + r""" + This testcase will quantize a single Linear Module for int8_bf16. + * Input dim exceeds 2 + * Input not contiguous + """ + for bias in [True, False]: + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2 + ) + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], + 17 if bias else 16, + ) + + self._qlinear_test_helper( + (torch.randn((2, 4, 3, 4)),), + mixed_bf16=True, + do_permute=True, + matcher_check_fn=matcher_check_fn, + bias=bias, + is_fp8=True, + ) + def _qlinear_unary_test_helper( - self, inputs, unary_op=torch.nn.ReLU(), device="cpu", int8_mixed_bf16=False + self, + inputs, + unary_op=torch.nn.ReLU(), + device="cpu", + mixed_bf16=False, + is_fp8=False, ): class M(torch.nn.Module): def __init__(self, use_bias): @@ -1555,8 +1769,9 @@ def matcher_check_fn(): mod, inputs, matcher_check_fn, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, check_quantization=True, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -1567,14 +1782,35 @@ def test_qlinear_relu_cpu(self): """ self._qlinear_unary_test_helper((torch.randn((2, 4)),)) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_relu_cpu(self): + r""" + This testcase will quantize a Linear->ReLU pattern. + """ + self._qlinear_unary_test_helper((torch.randn((2, 4)),), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_relu_int8_mixed_bf16(self): + def test_qlinear_relu_mixed_bf16(self): r""" - This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. + This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. """ - self._qlinear_unary_test_helper((torch.randn((2, 4)),), int8_mixed_bf16=True) + self._qlinear_unary_test_helper((torch.randn((2, 4)),), mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_relu_mixed_bf16(self): + r""" + This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. + """ + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=True + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1584,14 +1820,35 @@ def test_qlinear_relu_input_dim_exceeds_2(self): """ self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),)) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_relu_input_dim_exceeds_2(self): + r""" + This testcase will quantize a Linear->ReLU pattern. + """ + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), is_fp8=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_relu_mixed_bf16_input_dim_exceeds_2(self): + r""" + This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. + """ + self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), mixed_bf16=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2(self): + @skipIfNoFloat8Support + def test_fp8_qlinear_relu_mixed_bf16_input_dim_exceeds_2(self): r""" - This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization. + This testcase will quantize a Linear->ReLU pattern with mixed_bf16 quantization. """ - self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), int8_mixed_bf16=True) + self._qlinear_unary_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=True + ) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1602,25 +1859,49 @@ def test_qlinear_gelu_cpu(self): for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_gelu_cpu(self): + r""" + This testcase will quantize a Linear->GELU pattern. + """ + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu, is_fp8=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_gelu_mixed_bf16(self): + r""" + This testcase will quantize a Linear->GELU pattern with mixed_bf16 quantization. + """ + for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: + self._qlinear_unary_test_helper( + (torch.randn((2, 4)),), gelu, mixed_bf16=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_gelu_int8_mixed_bf16(self): + @skipIfNoFloat8Support + def test_fp8_qlinear_gelu_mixed_bf16(self): r""" - This testcase will quantize a Linear->GELU pattern with int8_mixed_bf16 quantization. + This testcase will quantize a Linear->GELU pattern with mixed_bf16 quantization. """ for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]: self._qlinear_unary_test_helper( - (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True + (torch.randn((2, 4)),), gelu, mixed_bf16=True, is_fp8=True ) def _qlinear_add_test_helper( self, device="cpu", use_relu=False, - int8_mixed_bf16=False, + mixed_bf16=False, is_qat=True, is_dynamic=True, + is_fp8=False, ): r""" This testcase will quantize two consecutive Linear->Add(->relu) patterns as: @@ -1688,13 +1969,18 @@ def forward(self, x): res = self.relu2(res) return res + if is_fp8: + # fp8_convert_ not support dynamic and qat yet + assert not is_dynamic + assert not is_qat + add_fn_list = [ lambda x, y: x + y, lambda x, y: y + x, lambda x, y: x.add_(y), lambda x, y: y.add_(x), ] - fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False] + fake_quant_x2_list = [False, True] if mixed_bf16 and not is_fp8 else [False] shape_list = [(4, 4), (4, 4, 4)] cases = itertools.product(add_fn_list, fake_quant_x2_list, shape_list) for add_fn, fq_x2, shape in cases: @@ -1709,7 +1995,7 @@ def matcher_check_fn(): counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 ) # pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm] - nodes_per_match = 6 if int8_mixed_bf16 else 4 + nodes_per_match = 6 if mixed_bf16 else 4 if len(shape) == 3: # pattern = [dequant_per_tensor, (convert_dtype), (view), \ # dequant_per_channel, (convert_dtype), (view), permute, addmm] @@ -1745,9 +2031,10 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, is_qat=is_qat, is_dynamic=is_dynamic, + is_fp8=is_fp8, ) if TEST_ACL: @@ -1765,6 +2052,7 @@ def matcher_check_fn(): [], check_quantization=True, num_include_ops=[2, 2], + is_fp8=is_fp8, ) else: # For python wrapper @@ -1778,6 +2066,7 @@ def matcher_check_fn(): [], check_quantization=True, num_include_ops=[2, 2], + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -1798,19 +2087,34 @@ def test_qlinear_add_cpu(self, use_relu, is_qat, is_dynamic): @parametrize("is_dynamic", [True, False]) def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic): self._qlinear_add_test_helper( - int8_mixed_bf16=True, + mixed_bf16=True, use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + @parametrize("use_relu", [True, False]) + @parametrize("mixed_bf16", [True, False]) + def test_fp8_qlinear_add_cpu(self, use_relu, mixed_bf16): + self._qlinear_add_test_helper( + use_relu=use_relu, + mixed_bf16=mixed_bf16, + is_qat=False, + is_dynamic=False, + is_fp8=True, + ) + def _qlinear_dequant_promotion_test_helper( self, inputs, device="cpu", - int8_mixed_bf16=False, + mixed_bf16=False, is_dynamic=False, matcher_check_fn=None, + is_fp8=False, ): class M(torch.nn.Module): def __init__( @@ -1850,9 +2154,10 @@ def default_matcher_check_fn(): if matcher_check_fn is not None else default_matcher_check_fn ), - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float, check_quantization=True, is_dynamic=is_dynamic, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -1872,12 +2177,52 @@ def test_qlinear_dequant_promotion_cpu(self): """ self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),)) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_dequant_promotion_cpu(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),), is_fp8=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + def test_qlinear_dequant_promotion_mixed_bf16(self): + r""" + Test with mixed_bf16 quantization. + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 4)),), mixed_bf16=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_dequant_promotion_int8_mixed_bf16(self): + @skipIfNoFloat8Support + def test_fp8_qlinear_dequant_promotion_mixed_bf16(self): r""" - Test with int8_mixed_bf16 quantization. + Test with mixed_bf16 quantization. This testcase test if dequant node before linear is promoted correctly: X | @@ -1890,7 +2235,7 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16(self): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 4)),), int8_mixed_bf16=True + (torch.randn((2, 4)),), mixed_bf16=True, is_fp8=True ) @skipIfNoDynamoSupport @@ -1910,12 +2255,32 @@ def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): """ self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),)) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self): + r""" + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 3, 4)),), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self): + def test_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self): r""" - Test with int8_mixed_bf16 quantization. + Test with mixed_bf16 quantization. This testcase test if dequant node before linear is promoted correctly: X | @@ -1928,7 +2293,29 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self): Y """ self._qlinear_dequant_promotion_test_helper( - (torch.randn((2, 3, 4)),), int8_mixed_bf16=True + (torch.randn((2, 3, 4)),), mixed_bf16=True + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_dequant_promotion_mixed_bf16_input_dim_exceeds_2(self): + r""" + Test with mixed_bf16 quantization. + This testcase test if dequant node before linear is promoted correctly: + X + | + Linear1(X) + / \ + Linear2(X) Linear3(X) + \ / + Add + | + Y + """ + self._qlinear_dequant_promotion_test_helper( + (torch.randn((2, 3, 4)),), mixed_bf16=True, is_fp8=True ) @skipIfNoDynamoSupport @@ -1994,6 +2381,41 @@ def matcher_check_fn(): check_quantization=True, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_qlinear_mul_cpu(self): + r""" + This testcase will quantize a Linear->Mul pattern. + """ + + class M(torch.nn.Module): + def __init__(self, use_bias): + super().__init__() + self.linear = torch.nn.Linear(4, 5, use_bias) + + def forward(self, x1, x2): + return torch.mul(self.linear(x1), x2) + + bias_list = [True, False] + for bias in bias_list: + mod = M(bias).eval() + x1 = torch.randn((2, 4)) + x2 = torch.randn((2, 5)) + + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 + ) + + self._test_common( + mod, + (x1, x2), + matcher_check_fn, + check_quantization=True, + is_fp8=True, + ) + @skipIfNoDynamoSupport def test_qmaxpool2d(self): r""" @@ -2344,7 +2766,7 @@ def test_da8w8_sym_act_sym_wgt_with_int_mm( self, has_bias, dtype, dynamic, reshape_a, M, inplace_add, expand_a_scale ): r""" - This testcase check if we can match the int8_dynamic_activation_int8_weight int8 linear pattern from torchao, + This testcase check if we can match the Int8DynamicActivationInt8WeightConfig int8 linear pattern from torchao, when activation is symmetrically quantized dynamically & weights are symmetrically quantized (statically) The pattern is: (no bias) _int_mm -> convert_element_type -> ([expand_a] -> mul) -> mul @@ -2426,78 +2848,6 @@ def matcher_check_fn(): if test_for_pointwise_binary: self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) - @skipIfNoONEDNN - @parametrize("has_bias", [True, False]) - @parametrize("dtype", [torch.float32, torch.bfloat16]) - @parametrize("input_dim_exceeds_two", [True, False]) - @parametrize("check_reuse_input", [True, False]) - def test_scaled_mm(self, has_bias, dtype, input_dim_exceeds_two, check_reuse_input): - class FP8QDQLinear(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.qtype = torch.float8_e4m3fn - self.weight = torch.randn((out_features, in_features)).to(self.qtype) - self.weight_scale = 2.0 - self.scale = 2.0 - self.bias = None - if has_bias: - self.bias = torch.randn((out_features,)).to(dtype) - - def forward(self, input): - weight = torch.ops.torchao.dequantize_affine_float8( - tensor=self.weight.data, - scale=torch.tensor(self.weight_scale), - output_dtype=torch.float, - ) - if dtype != torch.float: - weight = weight.to(dtype) - - q_input = torch.ops.torchao.quantize_affine_float8( - tensor=input, - scale=torch.tensor(self.scale), - float8_dtype=self.qtype, - ) - dq_input = torch.ops.torchao.dequantize_affine_float8( - tensor=q_input, - scale=torch.tensor(self.scale), - output_dtype=torch.float, - ) - if dtype != torch.float: - dq_input = dq_input.to(dtype) - - out = torch.nn.functional.linear(dq_input, weight, self.bias) - return out - - class Mod(torch.nn.Module): - def __init__(self, in_features, out_features, check_reuse_input): - super().__init__() - self.l0 = FP8QDQLinear(in_features, out_features) - self.check_reuse_input = check_reuse_input - if self.check_reuse_input: - self.l1 = FP8QDQLinear(in_features, out_features) - - def forward(self, x): - y = self.l0(x) - if self.check_reuse_input: - z = self.l1(x) - y += z - return y - - M1, M2, N, K = 2, 3, 13, 16 - M = M1 * M2 - mod = Mod(N, K, check_reuse_input) - if input_dim_exceeds_two: - v = torch.randn(M1, M2, N) - else: - v = torch.randn(M, N) - v = v.to(dtype) - - def matcher_check_fn(): - counter = 2 if check_reuse_input else 1 - self.assertEqual(counters["inductor"]["scaled_mm_matcher_count"], counter) - - self._test_common(mod, (v,), matcher_check_fn) - @dynamo_config.patch( { @@ -2506,7 +2856,7 @@ def matcher_check_fn(): "specialize_float": True, } ) -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8+") +@unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+") class TestDynamicPatternMatcher(TestPatternMatcherBase): def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): r""" @@ -2589,16 +2939,14 @@ def matcher_check_fn(): is_qat=True, ) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - def test_q_attention_block(self): + def _test_q_attention_block_helper(self, annotate_matmul, is_fp8=False): class SelfAttnLikeModule(torch.nn.Module): def __init__( self, input_dim, - transpose_for_score=False, - num_attention_heads=None, - attention_head_size=None, + num_attention_heads, + attention_head_size, + annotate_matmul=False, ) -> None: super().__init__() self.input_dim = input_dim @@ -2606,12 +2954,16 @@ def __init__( self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False) self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False) self.softmax = torch.nn.Softmax(dim=-1) - self.transpose_for_score = transpose_for_score - if self.transpose_for_score: - assert num_attention_heads is not None - assert attention_head_size is not None - self.num_attention_heads = num_attention_heads - self.attention_head_size = attention_head_size + self.annotate_matmul = annotate_matmul + if self.annotate_matmul: + self.q_out_scale = 0.5 + self.k_out_scale = 0.6 + self.v_out_scale = 0.7 + self.attn_weights_scale = 0.8 + self.num_attention_heads = num_attention_heads + self.attention_head_size = attention_head_size + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + ( @@ -2625,46 +2977,71 @@ def forward(self, x): q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) - if self.transpose_for_score: - q = self.transpose_for_scores(q) - k = self.transpose_for_scores(k) - v = self.transpose_for_scores(v) - scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) + q = self.transpose_for_scores(q) + k = self.transpose_for_scores(k) + v = self.transpose_for_scores(v) + k = k.transpose(-1, -2) + if self.annotate_matmul: + q = qdq(q, self.q_out_scale) + k = qdq(k, self.k_out_scale) + scores = torch.matmul(q, k) / (self.input_dim**0.5) attention = self.softmax(scores) + if self.annotate_matmul: + attention = qdq(attention, self.attn_weights_scale) + v = qdq(v, self.v_out_scale) weighted = torch.matmul(attention, v) - return weighted + weighted = weighted.permute(0, 2, 1, 3).contiguous() + weighted = weighted.reshape( + weighted.size()[:-2] + (self.all_head_size,) + ) + return self.dense(weighted) - for annotate_matmul in [False, True]: - mod = SelfAttnLikeModule( - input_dim=64 * 16, - transpose_for_score=True, - num_attention_heads=16, - attention_head_size=64, - ).eval() - v = torch.randn(2, 384, 1024) + mod = SelfAttnLikeModule( + input_dim=64 * 16, + num_attention_heads=16, + attention_head_size=64, + annotate_matmul=annotate_matmul and is_fp8, + ).eval() + v = torch.randn(2, 384, 1024) - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3 - ) - self.assertEqual( - counters["inductor"]["qlinear_unary_matcher_count"], - 3 if annotate_matmul and not TEST_ACL else 0, - ) + def matcher_check_fn(): + self.assertEqual( + counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4 + ) + self.assertEqual( + counters["inductor"]["qlinear_unary_matcher_count"], + 3 if annotate_matmul and not TEST_ACL else 0, + ) - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) - if annotate_matmul: - quantizer.set_function_type_qconfig( - torch.matmul, quantizer.get_global_quantization_config() - ) + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + if annotate_matmul: + quantizer.set_function_type_qconfig( + torch.matmul, quantizer.get_global_quantization_config() + ) - self._test_common( - mod, - (v,), - matcher_check_fn, - check_quantization=True, - quantizer=quantizer, + self._test_common( + mod, + (v,), + matcher_check_fn, + check_quantization=True, + quantizer=quantizer, + is_fp8=is_fp8, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_q_attention_block(self): + for annotate_matmul in [True, False]: + self._test_q_attention_block_helper(annotate_matmul=annotate_matmul) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_fp8_q_attention_block(self): + for annotate_matmul in [True, False]: + self._test_q_attention_block_helper( + annotate_matmul=annotate_matmul, is_fp8=True ) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 4476b18697..0d46771a68 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -35,10 +35,7 @@ QUANT_ANNOTATION_KEY, X86InductorQuantizer, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training +from torchao.utils import torch_version_at_least class NodePosType(Enum): @@ -678,7 +675,7 @@ def _test_quantizer( # program capture m = copy.deepcopy(m_eager) - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -706,7 +703,7 @@ def _test_quantizer( @skipIfNoInductorSupport -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): @skipIfNoX86 def test_conv2d(self): @@ -1432,7 +1429,7 @@ def _test_linear_unary_helper( Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. """ use_bias_list = [True, False] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_list = [False] if post_op_algo_list is None: post_op_algo_list = [None] @@ -1572,7 +1569,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op is supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -1676,7 +1673,7 @@ def test_linear_binary2(self): Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 """ example_inputs = (torch.randn(2, 16),) - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_add_list = [False] is_qat_list = [False, True] is_dynamic_list = [False, True] @@ -1745,9 +1742,9 @@ def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op and relu as unary post op are supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of export_for_training + # TODO test for inplace add after refactoring of torch.export.export inplace_add_list = [False] - # TODO test for inplace relu after refactoring of export_for_training + # TODO test for inplace relu after refactoring of torch.export.export inplace_relu_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -2355,7 +2352,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py new file mode 100644 index 0000000000..9a638b8f8f --- /dev/null +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest +from contextlib import nullcontext +from typing import Tuple + +import torch +from torch._inductor.utils import run_and_get_code +from torch.testing import FileCheck +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + run_tests, +) + +from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + PerRow, + PerTensor, + quantize_, +) +from torchao.quantization.quantize_.common import KernelPreference +from torchao.quantization.utils import compute_error +from torchao.testing.utils import TorchAOIntegrationTestCase +from torchao.utils import ( + _is_fbgemm_genai_gpu_available, + is_sm_at_least_89, + is_sm_at_least_90, + torch_version_at_least, +) + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 128 + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear1 = torch.nn.Linear(in_features, out_features, bias=False) + self.linear2 = torch.nn.Linear(out_features, in_features, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +# TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations +@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") +class TestFloat8Tensor(TorchAOIntegrationTestCase): + def setUp(self): + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("mode", ["dynamic", "weight-only"]) + @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "kernel_preference", + [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], + ) + # Inputs are (M,..), K, N + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ], + ) + def test_fp8_linear_variants( + self, + dtype: torch.dtype, + mode: str, + compile: bool, + granularity, + kernel_preference: KernelPreference, + sizes: Tuple, + ): + if ( + isinstance(granularity, PerTensor) + and kernel_preference == KernelPreference.FBGEMM + ): + return unittest.skip( + "per tensor with fbgemm kernel preferece does not work yet" + ) + + error_message = None + if isinstance(granularity, PerRow): + if mode == "dynamic" and dtype != torch.bfloat16: + error_message = "PerRow quantization only works for bfloat16 precision" + + if mode == "weight-only" and kernel_preference != KernelPreference.AUTO: + return unittest.skip( + "weight only quant only uses AUTO kernel preference right now" + ) + + if kernel_preference == KernelPreference.FBGEMM and ( + (not _is_fbgemm_genai_gpu_available()) or (not is_sm_at_least_90()) + ): + return unittest.skip( + "Requires fbgemm_gpu_genai to run fbgemm kernel preference test" + ) + + error_context = ( + self.assertRaisesRegex(AssertionError, error_message) + if error_message + else nullcontext() + ) + + with error_context: + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + # Create a linear layer with bfloat16 dtype + model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + + quantized_model = copy.deepcopy(model) + + if mode == "dynamic": + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, + kernel_preference=kernel_preference, + ) + else: + assert mode == "weight-only", f"Unsupported mode: {mode}" + config = Float8WeightOnlyConfig() + + quantize_(quantized_model, config) + + if compile: + quantized_model = torch.compile(quantized_model, fullgraph=True) + + output_original = model(input_tensor) + output_quantized = quantized_model(input_tensor) + + error = compute_error(output_original, output_quantized) + assert compute_error(output_original, output_quantized) > 20, ( + f"Quantization error is too high got a SQNR of {error}" + ) + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @unittest.skipIf( + not is_sm_at_least_90(), + "Failing in SM89 right now: " + "AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later", + ) + def test_slice(self, granularity): + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + dtype = torch.bfloat16 + device = "cuda" + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) + dummy1.weight = torch.nn.Parameter( + dummy.weight.narrow(0, 0, 64), requires_grad=False + ) + dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) + dummy2.weight = torch.nn.Parameter( + dummy.weight.narrow(1, 0, 128), requires_grad=False + ) + + quantize_(dummy, config) + weight1 = dummy.weight.clone().narrow(0, 0, 64) + weight2 = dummy.weight.clone().narrow(1, 0, 128) + self.assertEqual( + weight1.qdata, + dummy.weight.qdata.narrow(0, 0, 64), + ) + self.assertEqual( + weight2.qdata, + dummy.weight.qdata.narrow(1, 0, 128), + ) + if isinstance(granularity, PerRow): + self.assertEqual( + weight1.scale, + dummy.weight.scale.narrow(0, 0, 64), + ) + self.assertEqual( + weight2.scale, + dummy.weight.scale, + ) + else: + self.assertEqual( + weight1.scale, + dummy.weight.scale, + ) + self.assertEqual( + weight2.scale, + dummy.weight.scale, + ) + + # check for sliced weight, before and after float8 quantization + # does not differ too much + input = torch.randn(2, 256, dtype=dtype, device=device) + res_ref = dummy1(input) + dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 25, f"sqnr: {sqnr}") + + input = torch.randn(2, 128, dtype=dtype, device=device) + res_ref = dummy2(input) + dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False) + res = dummy(input) + sqnr = compute_error(res, res_ref) + self.assertTrue(sqnr > 15, f"sqnr: {sqnr}") + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + # Inputs are (M,..), K, N + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ], + ) + def test_kernel_preference_numerical_equivalence(self, granularity, sizes): + """Test different kernel preferences have the same numerics for float8 dynamic activation + and float8 weight config + """ + M, N, K = sizes + dtype = torch.bfloat16 + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + # Create a linear layer with bfloat16 dtype + model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + + # reference kernel preference and results + # we are using KerenelPreference.TORCH as the reference + kp_ref = KernelPreference.TORCH + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, kernel_preference=kp_ref + ) + quantized_model = copy.deepcopy(model) + quantize_(quantized_model, config) + res_ref = quantized_model(input_tensor) + + other_kernel_preferences = [ + KernelPreference.AUTO, + ] + if ( + _is_fbgemm_genai_gpu_available() + and is_sm_at_least_90() + and not isinstance(granularity, PerTensor) + ): + other_kernel_preferences.append(KernelPreference.FBGEMM) + + quantized_outputs = {} + for kp in other_kernel_preferences: + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, kernel_preference=kp + ) + quantized_model = copy.deepcopy(model) + quantize_(quantized_model, config) + quantized_outputs[kp] = quantized_model(input_tensor) + + from torchao.quantization.utils import compute_error + + # comparing numerics between different kernel preferences, using TORCH as the standard + kp_and_res = list(quantized_outputs.items()) + for i in range(len(kp_and_res)): + kp, res = kp_and_res[i] + self.assertTrue( + compute_error(res, res_ref) > 28, + f"mismatch between {kp=} and {kp_ref}, {sizes=}, {granularity=}", + ) + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_slice_preserves_aliasing(self, granularity): + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_slice_and_copy_similar_to_vllm(self, granularity): + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + self._test_slice_and_copy_similar_to_vllm(config) + + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + def test_bmm(self): + # only support per row quantization + config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + + class M(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x): + return torch.bmm(x, self.weight) + + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(10, 32, 128, dtype=dtype, device=device) + weight = torch.randn(10, 128, 256, dtype=dtype, device=device) + m = M(weight).eval() + original = m(input) + # we need to transpose the weight first for bmm + m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) + quantize_(m, config, filter_fn=lambda x, fqn: True) + quantized = m(input) + self.assertTrue(compute_error(original, quantized) > 20) + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ((2, 32, 128), 64, 256), + ], + ) + def test_to_device(self, granularity, sizes): + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + M, N, K = sizes + dtype = torch.bfloat16 + for device in self.GPU_DEVICES: + input_tensor = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device) + linear(input_tensor) + + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device=device) + linear(input_tensor) + + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device) + linear(input_tensor) + + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ((2, 32, 128), 64, 256), + ], + ) + def test_cat(self, granularity, sizes): + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + dtype = torch.bfloat16 + device = "cuda" + M, N, K = sizes + linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device) + linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device) + input_cat1 = torch.randn(*M, K, dtype=dtype, device=device) + + cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device) + + dummy_linear1.weight = torch.nn.Parameter(cat_weight1) + quantize_(dummy_linear1, config) + + quantize_(linear1, config) + quantize_(linear2, config) + + cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + self.assertTrue(cat_qweight1.shape, (2 * N, K)) + self.assertEqual( + dummy_linear1.weight.qdata, + cat_qweight1.qdata, + ) + self.assertEqual( + dummy_linear1.weight.scale, + cat_qweight1.scale, + ) + + # making sure cat_qweight1 can be used for inference + dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False) + dummy_linear1(input_cat1) + + # align the scale before concatenation + linear2.weight.scale = linear1.weight.scale + cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1) + self.assertTrue(cat_qweight2.shape, (N, 2 * K)) + ref_data = torch.cat( + [ + linear1.weight.qdata, + linear2.weight.qdata, + ], + dim=1, + ) + ref_scale = linear1.weight.scale + self.assertEqual(cat_qweight2.qdata, ref_data) + self.assertEqual(cat_qweight2.scale, ref_scale) + + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + def test_moe_weight_reshape_ops(self): + # only per row quantization is supported for bmm + granularity = PerRow() + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + self._test_moe_weight_reshape_ops(config) + + # TODO: we have some other tests living in https://github.com/pytorch/ao/blob/4ecc89edd7b5cfc12e6f80854c85d04c472a0eb0/test/dtypes/test_affine_quantized_float.py#L743 + # that should be moved here after v1 config is deprecated: + # https://github.com/pytorch/ao/issues/2649 + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + def test_expected_gpu_kernel_fbgemm(self): + """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels + and the bias add happens in the gemm kernel for per row quantization + """ + torch.compiler.reset() + + M, K, N = 128, 256, 512 + m = torch.nn.Sequential( + torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) + ) + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + kernel_preference=KernelPreference.FBGEMM, + ) + quantize_(m, config) + m = torch.compile(m) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + out, code = run_and_get_code(m, x) + + # 1. check at least one occurrence of the quantize op and rowwise gemm op + # 2. check that there are no additional kernels like `triton_poi_fused_add_0` + # are run, since the bias add should happen in the `f8f8bf16_rowwise.default` + # op instead of separately + FileCheck().check_count( + "torch.ops.triton.quantize_fp8_row.default(", 1 + ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default(", 1).check_not( + ".run(" + ).run(code[0]) + + +common_utils.instantiate_parametrized_tests(TestFloat8Tensor) + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py new file mode 100644 index 0000000000..56994b2639 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import ( + Int4WeightOnlyConfig, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.sparsity.sparse_api import apply_fake_sparsity +from torchao.testing.utils import skip_if_rocm +from torchao.utils import torch_version_at_least + +BF16_ACT_CONFIG = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="marlin_sparse", +) + + +@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +class TestInt4MarlinSparseTensor(TestCase): + def setUp(self): + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + @skip_if_rocm("ROCm enablement in progress") + @parametrize("config", [BF16_ACT_CONFIG]) + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 12), + ], + ) + def test_linear(self, config, sizes): + dtype = torch.float16 + device = "cuda" + + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) + + apply_fake_sparsity(linear) + original = linear(input) + quantize_(linear, config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + compiled_linear = torch.compile(linear) + quantized_and_compiled = compiled_linear(input) + self.assertTrue(compute_error(original, quantized_and_compiled) > 20) + + @skip_if_rocm("ROCm enablement in progress") + @unittest.skip("Fix later") + @parametrize("config", [BF16_ACT_CONFIG]) + def test_to_device(self, config): + for device in self.GPU_DEVICES: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device=device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device) + + @skip_if_rocm("ROCm enablement in progress") + @parametrize("config", [BF16_ACT_CONFIG]) + def test_module_path(self, config): + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear.cuda(), config) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + +instantiate_parametrized_tests(TestInt4MarlinSparseTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py new file mode 100644 index 0000000000..456f834389 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import ( + Int4WeightOnlyConfig, + quantize_, +) +from torchao.quantization.quantize_.common import SupportsActivationPreScaling +from torchao.quantization.utils import compute_error +from torchao.utils import ( + torch_version_at_least, +) + + +def get_config(group_size, use_hqq): + return Int4WeightOnlyConfig( + group_size=group_size, + int4_packing_format="opaque", + int4_choose_qparams_algorithm="hqq" if use_hqq else "tinygemm", + ) + + +@unittest.skipIf(not torch_version_at_least("2.6.0"), "Need pytorch 2.6+") +class TestInt4OpaqueTensor(TestCase): + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 12), + ], + ) + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + @parametrize("group_size", [32, 64, 128]) + @parametrize("use_hqq", [True, False]) + def test_linear(self, sizes, dtype, group_size, use_hqq): + device = "cpu" + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, get_config(group_size, use_hqq)) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + compiled_linear = torch.compile(linear) + quantized_and_compiled = compiled_linear(input) + self.assertTrue(compute_error(original, quantized_and_compiled) > 20) + + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + @parametrize("use_hqq", [True, False]) + def test_module_path(self, dtype, use_hqq): + linear = torch.nn.Linear(128, 256, dtype=dtype) + quantize_(linear, get_config(group_size=128, use_hqq=use_hqq)) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + @parametrize("use_hqq", [True, False]) + def test_activation_prescaling(self, use_hqq): + dtype = torch.bfloat16 + input = torch.randn(1, 128, dtype=dtype) + linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype) + original_output = linear(input) + quantize_(linear, get_config(group_size=128, use_hqq=use_hqq)) + qw = linear.weight + assert isinstance(qw, SupportsActivationPreScaling), ( + "Expected int4 tensor supports activation prescaling" + ) + assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" + _ACT_PRE_SCALE = 2 + manual_scaled_quantized = linear(input * _ACT_PRE_SCALE) + qw.act_pre_scale = _ACT_PRE_SCALE + auto_scaled_quantized = linear(input) + + # Making sure activation pre scaling is successfully applied to the activation. + self.assertEqual(manual_scaled_quantized, auto_scaled_quantized) + + # If pre-scaling is auto-applied, the quantization error should be low, + # i.e., compute_error (SQNR) is high + self.assertTrue( + compute_error(original_output * _ACT_PRE_SCALE, auto_scaled_quantized) > 20 + ) + + +instantiate_parametrized_tests(TestInt4OpaqueTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py new file mode 100644 index 0000000000..becb44a5e0 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import ( + Int4WeightOnlyConfig, + quantize_, +) +from torchao.quantization.quantize_.common import SupportsActivationPreScaling +from torchao.quantization.utils import compute_error +from torchao.utils import ( + torch_version_at_least, +) + + +def get_config(group_size): + return Int4WeightOnlyConfig( + group_size=group_size, + int4_packing_format="plain_int32", + ) + + +@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") +@unittest.skipIf(not torch.xpu.is_available(), "XPU not available") +class Int4PlainInt32Tensor(TestCase): + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 12), + ], + ) + @parametrize("dtype", [torch.bfloat16, torch.half]) + @parametrize("group_size", [32, 64, 128]) + def test_linear(self, sizes, dtype, group_size): + device = "xpu" + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, get_config(group_size)) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + compiled_linear = torch.compile(linear) + quantized_and_compiled = compiled_linear(input) + self.assertTrue(compute_error(original, quantized_and_compiled) > 20) + + @parametrize("dtype", [torch.bfloat16, torch.half]) + def test_module_path(self, dtype): + linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu") + quantize_(linear, get_config(group_size=128)) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + def test_activation_prescaling(self): + dtype = torch.bfloat16 + device = "xpu" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, get_config(128)) + qw = linear.weight + assert isinstance(qw, SupportsActivationPreScaling), ( + "Expected int4 tensor supports activation prescaling" + ) + assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" + _ACT_PRE_SCALE = 2 + qw.act_pre_scale = _ACT_PRE_SCALE + quantized = linear(input) + + # making sure activation pre scaling is successfully applied to the activation + self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20) + + +instantiate_parametrized_tests(Int4PlainInt32Tensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py new file mode 100644 index 0000000000..3c919740ae --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import ( + Float8DynamicActivationInt4WeightConfig, + Int4PreshuffledTensor, + Int4WeightOnlyConfig, + quantize_, +) +from torchao.quantization.utils import compute_error +from torchao.utils import ( + _is_fbgemm_genai_gpu_available, + is_sm_at_least_90, + torch_version_at_least, +) + +BF16_ACT_CONFIG = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="preshuffled", +) + +# only 128 group_size is supported +FP8_ACT_CONFIG = Float8DynamicActivationInt4WeightConfig( + int4_packing_format="preshuffled", +) + + +@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") +@unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" +) +class TestInt4PreshuffledTensor(TestCase): + def setUp(self): + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG]) + def test_linear(self, config): + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449` + # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG]) + @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG]) + def test_bmm(self, bmm_config): + class M(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x): + return torch.bmm(x, self.weight) + + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(10, 32, 128, dtype=dtype, device=device) + weight = torch.randn(10, 128, 256, dtype=dtype, device=device) + m = M(weight).eval() + original = m(input) + m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) + quantize_(m, bmm_config, filter_fn=lambda x, fqn: True) + quantized = m(input) + self.assertTrue(compute_error(original, quantized) > 18) + + def test_from_int4_tensor(self): + """Test that constructing Int4PreshuffledTensor from Int4Tensor + is the same as quantizing the original weight to Int4PreshuffledTensor + """ + int4_config = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="plain", + ) + int4_preshuffled_config = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="preshuffled", + ) + linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + linear2 = copy.deepcopy(linear1) + + quantize_(linear1, int4_config) + quantize_(linear2, int4_preshuffled_config) + + # now convert the linear1.weight to Int4PreshuffledTensor + w1_preshuffled = Int4PreshuffledTensor.from_int4_tensor(linear1.weight) + linear1.weight = torch.nn.Parameter(w1_preshuffled, requires_grad=False) + + example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),) + + output1 = linear1(*example_inputs) + output2 = linear2(*example_inputs) + self.assertEqual(output1, output2) + + @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG]) + def test_to_device(self, config): + for device in self.GPU_DEVICES: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device=device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + linear.to(device) + + @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG]) + def test_module_path(self, config): + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear, config) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + +instantiate_parametrized_tests(TestInt4PreshuffledTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py new file mode 100644 index 0000000000..f438d9c3db --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import Int4WeightOnlyConfig, quantize_ +from torchao.quantization.quantize_.common import SupportsActivationPreScaling +from torchao.quantization.utils import compute_error +from torchao.testing.utils import TorchAOIntegrationTestCase +from torchao.utils import ( + _is_fbgemm_genai_gpu_available, + is_sm_at_least_90, + torch_version_at_least, +) + + +@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") +@unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" +) +class TestInt4Tensor(TorchAOIntegrationTestCase): + def setUp(self): + self.config = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="plain", + ) + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + def test_linear(self): + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, self.config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + def test_slice(self): + dtype = torch.bfloat16 + device = "cuda" + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) + dummy1.weight = torch.nn.Parameter( + dummy.weight.narrow(0, 0, 64), requires_grad=False + ) + dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) + dummy2.weight = torch.nn.Parameter( + dummy.weight.narrow(1, 0, 128), requires_grad=False + ) + + quantize_(dummy, self.config) + weight1 = dummy.weight.narrow(0, 0, 64) + weight2 = dummy.weight.narrow(1, 0, 128) + self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64)) + self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64)) + self.assertEqual(weight1.zero_point, dummy.weight.zero_point.narrow(1, 0, 64)) + self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64)) + self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1)) + self.assertEqual(weight2.zero_point, dummy.weight.zero_point.narrow(0, 0, 1)) + + # check for sliced weight, before and after float8 quantization + # does not differ too much + input = torch.randn(2, 256, dtype=dtype, device=device) + res_ref = dummy1(input) + dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False) + res = dummy(input) + assert compute_error(res, res_ref) > 20 + + input = torch.randn(2, 128, dtype=dtype, device=device) + res_ref = dummy2(input) + dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False) + res = dummy(input) + assert compute_error(res, res_ref) > 15 + + def test_slice_preserves_aliasing(self): + config = self.config + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr() + + def test_slice_and_copy_similar_to_vllm(self): + self._test_slice_and_copy_similar_to_vllm(self.config) + + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + def test_bmm(self): + class M(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x): + return torch.bmm(x, self.weight) + + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(10, 32, 128, dtype=dtype, device=device) + weight = torch.randn(10, 128, 256, dtype=dtype, device=device) + m = M(weight).eval() + original = m(input) + # we need to transpose the weight first for bmm + m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) + quantize_(m, self.config, filter_fn=lambda x, fqn: True) + quantized = m(input) + self.assertTrue(compute_error(original, quantized) > 18) + + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ((2, 32, 128), 64, 256), + ], + ) + def test_to_device(self, sizes): + config = self.config + M, N, K = sizes + dtype = torch.bfloat16 + for device in self.GPU_DEVICES: + input_tensor = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device) + linear(input_tensor) + + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device=device) + linear(input_tensor) + + linear = torch.nn.Linear(K, N, dtype=dtype) + quantize_(linear, config) + linear.to(device) + linear(input_tensor) + + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 64, 256), + ((2, 32, 128), 64, 256), + ], + ) + def test_cat(self, sizes): + config = self.config + dtype = torch.bfloat16 + device = "cuda" + M, N, K = sizes + linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device) + linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device) + input_cat1 = torch.randn(*M, K, dtype=dtype, device=device) + + cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device) + + dummy_linear1.weight = torch.nn.Parameter(cat_weight1) + quantize_(dummy_linear1, config) + + quantize_(linear1, config) + quantize_(linear2, config) + + cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0) + self.assertTrue(cat_qweight1.shape, (2 * N, K)) + self.assertEqual( + dummy_linear1.weight.qdata, + cat_qweight1.qdata, + ) + self.assertEqual( + dummy_linear1.weight.scale, + cat_qweight1.scale, + ) + self.assertEqual( + dummy_linear1.weight.zero_point, + cat_qweight1.zero_point, + ) + + # making sure cat_qweight1 can be used for inference + dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False) + dummy_linear1(input_cat1) + + # align the scale and zero_point before concatenation + linear2.weight.scale = linear1.weight.scale + linear2.weight.zero_point = linear1.weight.zero_point + cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1) + self.assertTrue(cat_qweight2.shape, (N, 2 * K)) + ref_data = torch.cat( + [ + linear1.weight.qdata, + linear2.weight.qdata, + ], + dim=1, + ) + ref_scale = linear1.weight.scale + ref_zero_point = linear1.weight.zero_point + self.assertEqual(cat_qweight2.qdata, ref_data) + self.assertEqual(cat_qweight2.scale, ref_scale) + self.assertEqual(cat_qweight2.zero_point, ref_zero_point) + + def test_moe_weight_reshape_ops(self): + self._test_moe_weight_reshape_ops(self.config) + + def test_activation_prescaling(self): + dtype = torch.bfloat16 + device = "cuda" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, self.config) + qw = linear.weight + assert isinstance(qw, SupportsActivationPreScaling), ( + "Expected int4 tensor supports activation prescaling" + ) + assert qw.act_pre_scale is None, "Default `act_pre_scale` is None" + _ACT_PRE_SCALE = 2 + qw.act_pre_scale = _ACT_PRE_SCALE + quantized = linear(input) + + # making sure activation pre scaling is successfully applied to the activation + self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20) + + +instantiate_parametrized_tests(TestInt4Tensor) + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py new file mode 100644 index 0000000000..9fe9fddfb8 --- /dev/null +++ b/test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +from torchao.quantization import Int4WeightOnlyConfig, quantize_ +from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import ( + Int4TilePackedTo4dTensor, +) +from torchao.quantization.utils import compute_error +from torchao.testing.utils import TorchAOIntegrationTestCase +from torchao.utils import is_sm_at_least_90 + +INT4_CONFIG = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="tile_packed_to_4d", +) + +INT4_HQQ_CONFIG = Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", +) + + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+") +class TestInt4TilePackedTo4dTensor(TorchAOIntegrationTestCase): + def setUp(self): + self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] + + @parametrize( + "sizes", + [ + ((128,), 256, 128), + ((32, 128), 512, 128), + ((2, 32, 128), 256, 128), + ], + ) + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_linear(self, sizes, config): + dtype = torch.bfloat16 + device = "cuda" + + M, N, K = sizes + input = torch.randn(*M, K, dtype=dtype, device=device) + linear = torch.nn.Linear(K, N, dtype=dtype, device=device) + + original = linear(input) + quantize_(linear, config) + quantized = linear(input) + self.assertTrue(compute_error(original, quantized) > 20) + + compiled_linear = torch.compile(linear) + quantized_and_compiled = compiled_linear(input) + self.assertTrue(compute_error(original, quantized_and_compiled) > 20) + + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_module_path(self, config): + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + quantize_(linear.cuda(), config) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_slice(self, config): + """Note: we use multiples of 1024 for both in_features and out_features + so that padding does not affect the weight after slicing + """ + dtype = torch.bfloat16 + device = "cuda" + + # Create a 2048x2048 linear layer for testing + dummy = torch.nn.Linear(2048, 2048, bias=False, dtype=dtype, device=device) + + # Create reference sliced linear layers + dummy1 = torch.nn.Linear(2048, 1024, bias=False, dtype=dtype, device=device) + dummy1.weight = torch.nn.Parameter( + dummy.weight.narrow(0, 0, 1024), requires_grad=False + ) + dummy2 = torch.nn.Linear(1024, 2048, dtype=dtype, device=device) + dummy2.weight = torch.nn.Parameter( + dummy.weight.narrow(1, 0, 1024), requires_grad=False + ) + + # Quantize the main linear layer + quantize_(dummy, config) + + # Shape analysis for TilePackedTo4d format: + # Original weight shape: (2048, 2048) -> no padding needed (already multiple of 1024) + # n = 2048, k = 2048, inner_k_tiles = 8, group_size = 128 + # + # qdata shape: [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2] + # = [2048/8, 2048/(8*16), 32, 8/2] + # = [256, 16, 32, 4] + # + # scale_and_zero shape: [in_features/group_size, out_features, 2] (packed format) + # = [2048/128, 2048, 2] = [16, 2048, 2] + + # Test slicing along output dimension (dim=0: 2048 -> 1024) + weight1 = dummy.weight.narrow(0, 0, 1024) + + # qdata slicing: narrow from [256, 16, 32, 4] to [128, 16, 32, 4] + # Calculation: 1024 out_features / 2048 total * 256 qdata_dim0 = 128 + expected_qdata_slice_0 = dummy.weight.qdata.narrow(0, 0, 128) + self.assertEqual(weight1.qdata, expected_qdata_slice_0) + + # scale_and_zero slicing: narrow from [16, 2048, 2] to [16, 1024, 2] + # slicing 0th dim of qdata means we have to slice 1th dim of scale_and_zero + expected_scale_zero_slice_0 = dummy.weight.scale_and_zero.narrow(1, 0, 1024) + self.assertEqual(weight1.scale_and_zero, expected_scale_zero_slice_0) + + # Test slicing along input dimension (dim=1: 2048 -> 1024) + weight2 = dummy.weight.narrow(1, 0, 1024) + + # qdata slicing: narrow from [256, 16, 32, 4] to [256, 8, 32, 4] + # k = 2048 + # Calculation: 1024 in_features (1/2 of in_features) corresponds to 1/2 of qdata dimension 1 + # which is k / (inner_k_tiles * 16) / 2 = 2048 / (8 * 16) / 2 = 8 + expected_qdata_slice_1 = dummy.weight.qdata.narrow(1, 0, 8) + self.assertEqual(weight2.qdata, expected_qdata_slice_1) + + # scale_and_zero slicing: narrow from [16, 2048, 2] to [8, 2048, 2] + expected_scale_zero_slice_1 = dummy.weight.scale_and_zero.narrow(0, 0, 8) + self.assertEqual(weight2.scale_and_zero, expected_scale_zero_slice_1) + + # Verify that sliced weights produce similar results to reference implementations + input1 = torch.randn(2, 2048, dtype=dtype, device=device) + res_ref1 = dummy1(input1) + + # Create a new linear layer with the sliced weight + test_linear1 = torch.nn.Linear( + 2048, 1024, bias=False, dtype=dtype, device=device + ) + test_linear1.weight = torch.nn.Parameter( + weight1.contiguous(), requires_grad=False + ) + res1 = test_linear1(input1) + self.assertGreater(compute_error(res_ref1, res1), 14) + + input2 = torch.randn(2, 1024, dtype=dtype, device=device) + res_ref2 = dummy2(input2) + + # Create a new linear layer with the sliced weight + test_linear2 = torch.nn.Linear( + 1024, 2048, bias=False, dtype=dtype, device=device + ) + test_linear2.weight = torch.nn.Parameter( + weight2.contiguous(), requires_grad=False + ) + res2 = test_linear2(input2) + self.assertGreater(compute_error(res_ref2, res2), 14) + + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_slice_preserves_aliasing(self, config): + l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert ( + param.data.scale_and_zero.data_ptr() == param_data.scale_and_zero.data_ptr() + ) + + def test_cant_initialize_in_cpu(self): + config = INT4_CONFIG + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + # make sure there is no cpu implementation of the packing op currently + with self.assertRaisesRegex( + NotImplementedError, + "Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend. ", + ): + quantize_(linear, config) + + def test_to_device(self): + # test calling to on the tensor that's already on the same device works + config = INT4_CONFIG + + for device in self.GPU_DEVICES: + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + quantize_(linear, config) + linear.to(device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + quantize_(linear, config) + linear.to(device=device) + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device) + quantize_(linear, config) + linear.to(device) + + @parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG]) + def test_slice_and_copy_similar_to_vllm(self, config): + self._test_slice_and_copy_similar_to_vllm(config) + + @parametrize("device", ["cuda"]) + @parametrize("dtype", [torch.bfloat16]) + def test_mm_int4wo(self, device, dtype): + weight = torch.randn(512, 1024).to(device).to(dtype) + weight = weight.t() + + l = torch.nn.Linear(512, 1024).to(device).to(dtype) + l.weight = torch.nn.Parameter(weight) + quantize_(l, INT4_CONFIG) + # weight shape: 1024 x 512 + weight = l.weight + + input = torch.randn(1, 512, device=device, dtype=dtype) + # make sure it runs + torch.nn.functional.linear(input, weight) + + @parametrize("group_size", [32, 64, 128]) + def test_different_group_sizes(self, group_size): + """Test with different group sizes""" + dtype = torch.bfloat16 + device = "cuda" + hp_tensor = torch.randn(256, 512, dtype=dtype, device=device) + block_size = (1, group_size) + + tensor = Int4TilePackedTo4dTensor.from_hp(hp_tensor, block_size) + + self.assertEqual(tensor.shape, hp_tensor.shape) + self.assertEqual(tensor.block_size, block_size) + + def test_error_conditions(self): + """Test various error conditions""" + dtype = torch.bfloat16 + device = "cuda" + hp_tensor = torch.randn(128, 256, dtype=dtype, device=device) + + # Test invalid block_size length + with self.assertRaises(AssertionError): + Int4TilePackedTo4dTensor.from_hp( + hp_tensor, (64,) + ) # block_size length mismatch + + # Test non-groupwise quantization + with self.assertRaises(AssertionError): + Int4TilePackedTo4dTensor.from_hp( + hp_tensor, (2, 64) + ) # first element should be 1 + + +instantiate_parametrized_tests(TestInt4TilePackedTo4dTensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py new file mode 100644 index 0000000000..93458aaead --- /dev/null +++ b/test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import tempfile +import unittest + +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + MappingType, + quantize_, +) +from torchao.quantization.quantize_.workflows import IntxPackingFormat +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) +from torchao.quantization.utils import compute_error + + +def _get_accuracy_test_cases(): + MODEL_DTYPES = [ + torch.float32, + torch.bfloat16, + ] + + PACKING_FORMATS = [ + IntxPackingFormat.UNPACKED_TO_INT8, + IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI, + IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI, + IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT, + ] + + WEIGHT_DTYPES = [ + torch.int1, + torch.int2, + torch.int3, + torch.int4, + torch.int5, + torch.int6, + torch.int7, + torch.int8, + ] + + MAPPING_TYPES = [ + MappingType.SYMMETRIC, + MappingType.ASYMMETRIC, + MappingType.SYMMETRIC_NO_CLIPPING_ERR, + ] + + GRANULARITIES = [PerGroup(128), PerAxis(0)] + + def _is_valid_test_combination( + model_dtype, + packing_format, + weight_dtype, + weight_mapping_type, + weight_granularity, + ): + # ATEN restrictions + if packing_format == IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI: + if weight_dtype != torch.int4: + return False + if weight_mapping_type == MappingType.ASYMMETRIC: + return False + if model_dtype != torch.float32: + return False + + # TORCHAO_KLEIDIAI restrictions + if packing_format == IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI: + if weight_dtype != torch.int4: + return False + if weight_mapping_type == MappingType.ASYMMETRIC: + return False + + # SYMMETRIC_NO_CLIPPING_ERR does not work well with int1 + if ( + weight_dtype == torch.int1 + and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR + ): + return False + + return True + + test_cases = [ + param( + model_dtype=mdt, + packing_format=pf, + weight_dtype=dt, + weight_mapping_type=mt, + weight_granularity=gr, + ) + for mdt in MODEL_DTYPES + for pf in PACKING_FORMATS + for dt in WEIGHT_DTYPES + for mt in MAPPING_TYPES + for gr in GRANULARITIES + if _is_valid_test_combination(dt, pf, dt, mt, gr) + ] + + return test_cases + + +@unittest.skipIf(not _is_kernel_library_loaded(), "Kernel library not loaded") +class TestIntxOpaqueTensor(TestCase): + @parameterized.expand( + _get_accuracy_test_cases(), + name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", + ) + def test_accuracy( + self, + model_dtype, + packing_format, + weight_dtype, + weight_mapping_type, + weight_granularity, + ): + """ + Checks the accuracy of packed layouts + """ + m = 3 + n = 1071 + k = 2048 + activations = torch.randn(m, k).to(model_dtype) + model = torch.nn.Sequential( + *[torch.nn.Linear(k, k, bias=False), torch.nn.Linear(k, n, bias=True)] + ).to(model_dtype) + + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=weight_granularity, + weight_mapping_type=weight_mapping_type, + intx_packing_format=packing_format, + version=2, + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=weight_granularity, + weight_mapping_type=weight_mapping_type, + intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8, + version=2, + ), + ) + + with torch.no_grad(): + result = quantized_model(activations) + expected_result = quantized_model_reference(activations) + + sqnr = compute_error(result, expected_result) + self.assertTrue(sqnr > 30, f"Got SQNR of {sqnr}") + + def test_export_compile_aoti( + self, + ): + m = 3 + k0 = 512 + k1 = 256 + k2 = 128 + k3 = 1024 + weight_dtype = torch.int4 + weight_granularity = PerAxis(0) + weight_mapping_type = MappingType.ASYMMETRIC + + layers = [ + torch.nn.Linear(k0, k1, bias=False), + torch.nn.Linear(k1, k2, bias=True), + torch.nn.Linear(k2, k3, bias=False), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn(2, 1, m, k0, dtype=torch.float32) + dynamic_shapes = { + "input": { + 0: torch.export.Dim.AUTO, + 1: torch.export.Dim.STATIC, + 2: torch.export.Dim.AUTO, + 3: torch.export.Dim.STATIC, + } + } + + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=weight_granularity, + weight_mapping_type=weight_mapping_type, + intx_packing_format=IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + version=2, + ), + ) + eager_results = model(activations) + + # Export + exported = torch.export.export( + model, (activations,), strict=True, dynamic_shapes=dynamic_shapes + ) + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) + + # Compile + compiled = torch.compile(model) + with torch.no_grad(): + compiled_results = compiled(activations) + self.assertTrue(torch.allclose(eager_results, compiled_results)) + + # AOTI + with tempfile.TemporaryDirectory() as tmpdirname: + package_path = f"{tmpdirname}/model.pt2" + torch._inductor.aoti_compile_and_package( + exported, package_path=package_path + ) + fn = torch._inductor.aoti_load_package(package_path) + aoti_results = fn(activations) + self.assertTrue(torch.allclose(eager_results, aoti_results)) + + @parameterized.expand( + [ + param(packing_format=pf) + for pf in [ + IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI, + ] + ], + name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", + ) + def test_serialization(self, packing_format): + layers = [ + torch.nn.Linear(512, 256), + ] + model = torch.nn.Sequential(*layers) + model2 = torch.nn.Sequential(*layers) + activations = torch.randn(1, 512, dtype=torch.float32) + + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(64), + intx_packing_format=packing_format, + version=2, + ), + ) + expected = model(activations) + + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(model.state_dict(), f"{tmpdirname}/model.pt") + state_dict = torch.load( + f"{tmpdirname}/model.pt", map_location="cpu", weights_only=True + ) + + # Load deserialized weights into model2 and check result + model2.load_state_dict(state_dict, assign=True) + actual = model2(activations) + self.assertTrue(torch.allclose(expected, actual)) + + def test_moe_quant_intx(self): + from torchao.prototype.moe_quant.quantizable_moe_modules import ( + MOEFeedForwardAOQuantizable, + ) + from torchao.prototype.moe_quant.utils import ( + FakeExtraDimTensor, + MoEQuantConfig, + UseFakeExtraDimTensor, + cond_ffn_filter, + ) + from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + quantize_, + ) + from torchao.quantization.utils import compute_error + + with torch.device("cpu"): + model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to( + torch.float32 + ) + x = torch.randn(8, 512, dtype=torch.float32) + + out = model(x).clone() + + base_config = Int8DynamicActivationIntxWeightConfig( + intx_packing_format=IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + version=2, + ) + moe_config = MoEQuantConfig( + base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + ) + + quantize_(model, moe_config, cond_ffn_filter) + + out_q = model(x).clone() + assert isinstance(model.experts.w1, FakeExtraDimTensor) + + mod_c = torch.compile(model, mode="reduce-overhead") + + mod_c(x) + mod_c(x) + + out_qc = mod_c(x).clone() + + self.assertTrue(compute_error(out_q, out) > 30) + self.assertTrue(compute_error(out_qc, out) > 30) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py new file mode 100644 index 0000000000..9284c1890e --- /dev/null +++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py @@ -0,0 +1,448 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import tempfile +import unittest + +import torch +from parameterized import param, parameterized +from torch.testing import FileCheck +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao.dtypes import QDQLayout +from torchao.quantization import ( + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + MappingType, + quantize_, +) +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig +from torchao.quantization.quantize_.workflows import IntxPackingFormat +from torchao.quantization.utils import compute_error +from torchao.utils import torch_version_at_least, unwrap_tensor_subclass + + +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Need pytorch 2.7+") +class TestIntxUnpackedToInt8Tensor(TestCase): + def setUp(self): + self.config = IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + version=2, + ) + + def test_embedding(self): + dtype = torch.bfloat16 + device = "cpu" + input = torch.randint(low=0, high=128, size=(10,), device=device) + embedding = torch.nn.Embedding(128, 256, dtype=dtype, device=device) + original = embedding(input) + quantize_(embedding, self.config) + quantized = embedding(input) + error = compute_error(original, quantized) + self.assertTrue(error > 20) + + def test_linear(self): + dtype = torch.bfloat16 + device = "cpu" + input = torch.randn(1, 128, dtype=dtype, device=device) + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + original = linear(input) + quantize_(linear, self.config) + quantized = linear(input) + error = compute_error(original, quantized) + self.assertTrue(error > 20) + + def test_slice(self): + dtype = torch.bfloat16 + device = "cpu" + dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device) + + dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device) + dummy1.weight = torch.nn.Parameter( + dummy.weight.narrow(0, 0, 64), requires_grad=False + ) + + dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device) + dummy2.weight = torch.nn.Parameter( + dummy.weight.narrow(1, 0, 128), requires_grad=False + ) + + quantize_(dummy, self.config) + weight1 = dummy.weight.narrow(0, 0, 64) + weight2 = dummy.weight.narrow(1, 0, 128) + + self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64)) + self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64)) + + self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128)) + self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4)) + + # check for sliced weight, before and after float8 quantization + # does not differ too much + input = torch.randn(2, 256, dtype=dtype, device=device) + res_ref = dummy1(input) + dummy.weight = torch.nn.Parameter(weight1, requires_grad=False) + res = dummy(input) + assert compute_error(res, res_ref) > 20 + + input = torch.randn(2, 128, dtype=dtype, device=device) + res_ref = dummy2(input) + dummy.weight = torch.nn.Parameter(weight2, requires_grad=False) + res = dummy(input) + assert compute_error(res, res_ref) > 15 + + def test_slice_and_copy_(self): + device = "cpu" + l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16) + quantize_(l, self.config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, 512) + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr() + + # dummy_l has random input (shouldn't be 0) + dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16) + quantize_(dummy_l, self.config) + quantized = dummy_l.weight + quantized = quantized.narrow(0, 0, 512) + + param_data.copy_(quantized) + + # making sure param.data is updated + assert param.data.qdata[0][0] == quantized.qdata[0][0] + + def test_to_dtype(self): + activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16) + activations_fp32 = torch.randn(1, 128, dtype=torch.float32) + activations_fp16 = torch.randn(1, 128, dtype=torch.float16) + + linear = torch.nn.Linear(128, 256) + quantize_(linear, self.config) + + linear.to(dtype=torch.float16) + linear(activations_fp16) + + linear.to(dtype=torch.float32) + linear(activations_fp32) + + linear.to(dtype=torch.bfloat16) + linear(activations_bf16) + + def test_export_intx_weight_only_config(self): + linear = torch.nn.Linear(128, 256) + quantize_(linear, self.config) + ep = torch.export.export(linear, (torch.randn(1, 128),)) + assert "torch.ops.torchao.dequantize_affine.default" in ep.graph_module.code + + def test_export_int8_dyn_act_intx_weight_config(self): + layers = [ + torch.nn.Linear(512, 256, bias=False), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn(1, 512, dtype=torch.float32) + + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerAxis(0), + weight_mapping_type=MappingType.SYMMETRIC, + intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8, + version=2, + ), + ) + eager_results = model(activations) + + exported = torch.export.export(model, (activations,)) + + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) + + expected_counts = { + "torch.ops.torchao.choose_qparams_affine.default": 1, + "torch.ops.torchao.quantize_affine.default": 1, + "torch.ops.torchao.dequantize_affine.default": 2, + "torch.ops.aten.linear.default": 1, + "torch.ops.aten.reshape.default": 0, + } + for line, count in expected_counts.items(): + FileCheck().check_count(line, count, exactly=True).run( + exported.graph_module.code + ) + + def test_export_int8_dyn_act_intx_weight_config_with_unwrap(self): + layers = [ + torch.nn.Linear(512, 256, bias=False), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn(1, 512, dtype=torch.float32) + + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(64), + weight_mapping_type=MappingType.SYMMETRIC, + intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8, + version=2, + ), + ) + eager_results = model(activations) + + unwrap_tensor_subclass(model) + + exported = torch.export.export(model, (activations,)) + + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) + + expected_counts = { + "torch.ops.torchao.choose_qparams_affine.default": 1, + "torch.ops.torchao.quantize_affine.default": 1, + "torch.ops.torchao.dequantize_affine.default": 2, + "torch.ops.aten.linear.default": 1, + "torch.ops.aten.reshape.default": 0, + } + for line, count in expected_counts.items(): + FileCheck().check_count(line, count, exactly=True).run( + exported.graph_module.code + ) + + def test_serialization_int8_dyn_act_intx_weight_config(self): + layers = [ + torch.nn.Linear(512, 256), + ] + model = torch.nn.Sequential(*layers) + model2 = torch.nn.Sequential(*layers) + activations = torch.randn(1, 512, dtype=torch.float32) + + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(64), + intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8, + version=2, + ), + ) + expected = model(activations) + + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(model.state_dict(), f"{tmpdirname}/model.pt") + state_dict = torch.load( + f"{tmpdirname}/model.pt", map_location="cpu", weights_only=True + ) + + # Load deserialized weights into model2 and check result + model2.load_state_dict(state_dict, assign=True) + actual = model2(activations) + self.assertTrue(torch.allclose(expected, actual)) + + def test_serialization_intx_weight_only_config(self): + layers = [ + torch.nn.Linear(512, 256), + ] + model = torch.nn.Sequential(*layers) + model2 = torch.nn.Sequential(*layers) + activations = torch.randn(1, 512, dtype=torch.float32) + + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(64), + intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8, + version=2, + ), + ) + expected = model(activations) + + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(model.state_dict(), f"{tmpdirname}/model.pt") + state_dict = torch.load( + f"{tmpdirname}/model.pt", map_location="cpu", weights_only=True + ) + + # Load deserialized weights into model2 and check result + model2.load_state_dict(state_dict, assign=True) + actual = model2(activations) + self.assertTrue(torch.allclose(expected, actual)) + + @parameterized.expand( + [ + param( + weight_dtype=weight_dtype, + group_size=group_size, + mapping_type=mapping_type, + scale_dtype=scale_dtype, + model_dtype=model_dtype, + ) + for weight_dtype in list(getattr(torch, f"int{x}") for x in range(1, 9)) + for group_size in [32, 64, 128] + for mapping_type in [MappingType.SYMMETRIC] + for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] + for model_dtype in [torch.float32, torch.bfloat16, torch.float16] + ], + name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", + ) + def test_qat_int8_dyn_act_intx_weight_config( + self, weight_dtype, group_size, mapping_type, scale_dtype, model_dtype + ): + activation_config = IntxFakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False, scale_precision=scale_dtype + ) + weight_config = IntxFakeQuantizeConfig( + weight_dtype, + group_size=group_size, + mapping_type=mapping_type, + scale_precision=scale_dtype, + ) + qat_config_prepare = QATConfig( + activation_config=activation_config, + weight_config=weight_config, + step="prepare", + ) + qat_config_convert = QATConfig( + step="convert", + ) + quant_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_config.dtype, + weight_granularity=PerGroup(group_size), + weight_mapping_type=mapping_type, + weight_scale_dtype=scale_dtype, + intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8, + version=2, + ) + + k0 = 512 + k1 = 256 + layers = [ + torch.nn.Linear(k0, k1), + torch.nn.Linear(k1, k0), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn( + k0, + ) + model = model.to(model_dtype) + activations = activations.to(model_dtype) + + quantize_(model, qat_config_prepare) + prepared_out = model(activations) + + quantize_(model, qat_config_convert) + converted_out = model(activations) + + quantize_( + model, + quant_config, + ) + quantizeed_out = model(activations) + + sqnr = compute_error(prepared_out, converted_out).item() + sqnr = compute_error(prepared_out, quantizeed_out).item() + + if model_dtype == scale_dtype: + self.assertTrue( + sqnr == float("inf"), + f"Got SQNR of {sqnr} between prepared and quantized", + ) + else: + # There is slight difference in how v2 does dynamic activation quantization + # It uses the model_dtype, whereas v1 always uses float32 + self.assertTrue( + sqnr > 35, f"Got SQNR of {sqnr} between prepared and quantized" + ) + + @parameterized.expand( + [ + param( + weight_dtype=weight_dtype, + group_size=group_size, + mapping_type=mapping_type, + act_mapping_type=act_mapping_type, + scale_dtype=scale_dtype, + model_dtype=model_dtype, + ) + for weight_dtype in list(getattr(torch, f"int{x}") for x in range(1, 9)) + for group_size in [32, 64, 128] + for mapping_type in [MappingType.SYMMETRIC] + for act_mapping_type in [MappingType.ASYMMETRIC] + for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] + for model_dtype in [torch.float32, torch.bfloat16, torch.float16] + ], + name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", + ) + def test_intx_unpacked_v2_is_close_to_qdq_v1( + self, + weight_dtype, + group_size, + mapping_type, + act_mapping_type, + scale_dtype, + model_dtype, + ): + k0 = 512 + k1 = 256 + layers = [ + torch.nn.Linear(k0, k1), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn( + k0, + ) + + model = model.to(model_dtype) + activations = activations.to(model_dtype) + + model_v1 = copy.deepcopy(model) + quantize_( + model_v1, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=PerGroup(group_size), + weight_mapping_type=mapping_type, + weight_scale_dtype=scale_dtype, + act_mapping_type=act_mapping_type, + version=1, + layout=QDQLayout(), + ), + ) + out_v1 = model_v1(activations) + + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=PerGroup(group_size), + weight_mapping_type=mapping_type, + weight_scale_dtype=scale_dtype, + act_mapping_type=act_mapping_type, + intx_packing_format=IntxPackingFormat.UNPACKED_TO_INT8, + version=2, + ), + ) + out_v2 = model(activations) + sqnr = compute_error(out_v1, out_v2).item() + + if model_dtype == torch.float32 and model_dtype == torch.float32: + self.assertTrue(sqnr == float("inf"), f"Got SQNR of {sqnr}") + else: + # There is slight difference in how v2 does dynamic activation quantization + # It uses the model_dtype, whereas v1 always uses float32 + self.assertTrue(sqnr > 35, f"Got SQNR of {sqnr}") + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py new file mode 100644 index 0000000000..80094beb2d --- /dev/null +++ b/test/quantization/test_da8w4_cpu.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +from torch._dynamo.utils import counters +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao import quantize_ +from torchao.dtypes import ( + Int8DynamicActInt4WeightCPULayout, + PlainLayout, +) +from torchao.quantization.quant_api import ( + Int8DynamicActivationInt4WeightConfig, +) +from torchao.quantization.quant_primitives import MappingType +from torchao.utils import torch_version_at_least + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=64, n=32, k=64, bias=False): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=bias).to(torch.float) + self.linear2 = torch.nn.Linear(n, k, bias=bias).to(torch.float) + + def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestDa8w4Cpu(TestCase): + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.7.0"), "Test only enabled for 2.7+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + @common_utils.parametrize("sym_quant_a", [True, False]) + def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a): + if sym_quant_a and not torch_version_at_least("2.8.0"): + # not supported until PT 2.8 + return + device = "cpu" + m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) + m2 = copy.deepcopy(m) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout + # is that the former packs two int4 weights into one int8, while the latter does not. + quantize_( + m, + Int8DynamicActivationInt4WeightConfig( + group_size=32, + layout=Int8DynamicActInt4WeightCPULayout(), + act_mapping_type=MappingType.SYMMETRIC + if sym_quant_a + else MappingType.ASYMMETRIC, + ), + ) + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0] + quantize_( + m2, + Int8DynamicActivationInt4WeightConfig( + group_size=32, + layout=PlainLayout(), + act_mapping_type=MappingType.SYMMETRIC + if sym_quant_a + else MappingType.ASYMMETRIC, + ), + ) + torch._dynamo.reset() # may segfault without this + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + atol, rtol = 4e-7, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 1e-2, 3e-3 + elif dtype == torch.half: + atol, rtol = 6e-3, 2e-3 + assert torch.allclose(y, y2, atol=atol, rtol=rtol) + # Test get_plain by dequantize() + dqw1 = m.linear1.weight.original_weight_tensor.dequantize() + dqw2 = m.linear2.weight.original_weight_tensor.dequantize() + dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() + dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() + assert torch.allclose(dqw1, dqw1_ref) + assert torch.allclose(dqw2, dqw2_ref) + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.8.0"), "Test only enabled for 2.8+") + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + def test_8da4w_concat_linear_cpu(self, x_dim, bias): + N, K = 64, 128 + + class Mod(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear1 = torch.nn.Linear(K, N, bias=bias) + self.linear2 = torch.nn.Linear(K, N, bias=bias) + self.linear3 = torch.nn.Linear(K, N, bias=bias) + + def forward(self, x): + a = self.linear1(x) + b = self.linear2(x) + c = self.linear3(x) + return a + b + c + + dtype = torch.bfloat16 + device = "cpu" + m = Mod(bias).eval().to(dtype).to(device) + x_shape = [2] * x_dim + x_shape[-1] = K + x = torch.rand(x_shape, dtype=dtype, device=device) + with torch.no_grad(): + quantize_( + m, + Int8DynamicActivationInt4WeightConfig( + group_size=32, + layout=Int8DynamicActInt4WeightCPULayout(), + act_mapping_type=MappingType.SYMMETRIC, + ), + ) + # Need to turn on freezing to get the pattern + # set enable_concat_linear to true to enable the fusion + with torch._inductor.config.patch( + {"freezing": True, "cpp.enable_concat_linear": True} + ): + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + x, + ) + # ensure the expected op occurs only once in the code after fusion + # The trailing "(" is to avoid matching the op in the comment + assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1 + + # Ensure that when concat linear is enabled, fxgraph cache works + # without being bypassed (fxgraph_cache_bypass = 0), indicating that + # DA8W4ConcatLinearCPUPass properly implements the CustomGraphPass + # interface and uuid() function, allowing fxgraph to be saved and hit + # on subsequent runs (fxgraph_cache_hit > 0). + fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"] + assert fx_cache_bypass_count == 0 + + with torch._inductor.config.patch( + {"freezing": True, "cpp.enable_concat_linear": False} + ): + y_ref, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + x, + ) + assert torch.allclose(y, y_ref) + + # Ensure that the fxgraph cache is also not bypassed when concat linear is disabled + fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"] + assert fx_cache_bypass_count == 0 + + +common_utils.instantiate_parametrized_tests(TestDa8w4Cpu) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/test_gptq.py b/test/quantization/test_gptq.py index 98760f8cf6..6f7ac10d45 100644 --- a/test/quantization/test_gptq.py +++ b/test/quantization/test_gptq.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + import unittest from pathlib import Path @@ -12,9 +18,6 @@ from torchao._models.llama.tokenizer import get_tokenizer from torchao.quantization import Int4WeightOnlyConfig, quantize_ from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, -) torch.manual_seed(0) @@ -101,7 +104,6 @@ def test_gptq_quantizer_int4_weight_only(self): class TestMultiTensorFlow(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from torchao.quantization.GPTQ import MultiTensor @@ -114,7 +116,6 @@ def test_multitensor_add_tensors(self): self.assertTrue(torch.equal(mt.values[0], tensor1)) self.assertTrue(torch.equal(mt.values[1], tensor2)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_pad_unpad(self): from torchao.quantization.GPTQ import MultiTensor @@ -126,7 +127,6 @@ def test_multitensor_pad_unpad(self): mt.unpad() self.assertEqual(mt.count, 1) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_inplace_operation(self): from torchao.quantization.GPTQ import MultiTensor @@ -179,7 +179,7 @@ def test_gptq_with_input_recorder(self): model2 = copy.deepcopy(model) out = model(*test_input) - quantize_(model2, Int4WeightOnlyConfig()) + quantize_(model2, Int4WeightOnlyConfig(version=1)) outq = model2(*test_input) del model2 diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py similarity index 88% rename from torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py rename to test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py index 08548b9e9e..224e745ac4 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024-2025 Arm Limited and affiliates. # All rights reserved. # # This source code is licensed under the license found in the @@ -9,15 +10,19 @@ import unittest import torch -from parameterized import param, parameterized from torch.testing import FileCheck +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, +) from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.qat import ( - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, Int8DynActInt4WeightQATQuantizer, + IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig, ) from torchao.quantization.quant_api import ( @@ -26,44 +31,42 @@ MappingType, quantize_, ) +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) from torchao.quantization.utils import compute_error -class TestInt8DynamicActivationIntxWeight(unittest.TestCase): - TEST_ACCURACY_CASES = [ - param( - layout=layout, - weight_dtype=weight_dtype, - weight_mapping_type=weight_mapping_type, - weight_granularity=weight_granularity, - ) - for layout in [ - PackedLinearInt8DynamicActivationIntxWeightLayout(), - PackedLinearInt8DynamicActivationIntxWeightLayout(target="universal"), - ] - for weight_dtype in [ - torch.int1, - torch.int2, - torch.int3, - torch.int4, - torch.int5, - torch.int6, - torch.int7, - torch.int8, - ] - for weight_mapping_type in [ - MappingType.SYMMETRIC, - MappingType.ASYMMETRIC, - ] - for weight_granularity in [ - PerGroup(128), - PerAxis(0), - ] - ] - - @parameterized.expand( - TEST_ACCURACY_CASES, - name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", +@unittest.skipIf(not _is_kernel_library_loaded(), "Kernel library not loaded") +class TestInt8DynamicActivationIntxWeight(TestCase): + @parametrize( + "layout, weight_dtype, weight_mapping_type, weight_granularity", + [ + (layout, weight_dtype, weight_mapping_type, weight_granularity) + for layout in [ + PackedLinearInt8DynamicActivationIntxWeightLayout(), + PackedLinearInt8DynamicActivationIntxWeightLayout(target="universal"), + ] + for weight_dtype in [ + torch.int1, + torch.int2, + torch.int3, + torch.int4, + torch.int5, + torch.int6, + torch.int7, + torch.int8, + ] + for weight_mapping_type in [ + MappingType.SYMMETRIC, + MappingType.ASYMMETRIC, + MappingType.SYMMETRIC_NO_CLIPPING_ERR, + ] + for weight_granularity in [ + PerGroup(128), + PerAxis(0), + ] + ], ) def test_accuracy( self, layout, weight_dtype, weight_mapping_type, weight_granularity @@ -71,6 +74,12 @@ def test_accuracy( """ Checks the accuracy of packed layouts """ + if ( + weight_dtype == torch.int1 + and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR + ): + return + m = 3 n = 1071 k = 2048 @@ -93,6 +102,7 @@ def test_accuracy( weight_mapping_type=weight_mapping_type, weight_scale_dtype=weight_scale_dtype, layout=layout, + version=1, ), ) @@ -105,6 +115,7 @@ def test_accuracy( weight_mapping_type=weight_mapping_type, weight_scale_dtype=weight_scale_dtype, layout=self._reference_layout(), + version=1, ), ) @@ -138,6 +149,7 @@ def test_accuracy_kleidiai(self): layout=PackedLinearInt8DynamicActivationIntxWeightLayout( target="kleidiai" ), + version=1, ), ) @@ -150,6 +162,7 @@ def test_accuracy_kleidiai(self): weight_mapping_type=weight_mapping_type, weight_scale_dtype=weight_scale_dtype, layout=self._reference_layout(), + version=1, ), ) @@ -191,6 +204,7 @@ def test_accuracy_aten(self): weight_mapping_type=weight_mapping_type, weight_scale_dtype=weight_scale_dtype, layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"), + version=1, ), ) @@ -203,6 +217,7 @@ def test_accuracy_aten(self): weight_mapping_type=weight_mapping_type, weight_scale_dtype=weight_scale_dtype, layout=self._reference_layout(), + version=1, ), ) @@ -213,7 +228,7 @@ def test_accuracy_aten(self): self._assert_close(result, expected_result) def _assert_close( - self, result, expected_result, mse_tol=1e-6, atol=1e-2, rtol=1e-5 + self, result, expected_result, mse_tol=1e-5, atol=5e-2, rtol=5e-5 ): mse_loss = torch.nn.functional.mse_loss(result, expected_result) self.assertTrue( @@ -262,6 +277,7 @@ def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout( weight_mapping_type=weight_mapping_type, weight_scale_dtype=torch.bfloat16, layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + version=1, ), ) eager_results = model(activations) @@ -323,6 +339,7 @@ def test_export_dynamic_shape_PackedLinearInt8DynamicActivationIntxWeightLayout( weight_mapping_type=weight_mapping_type, weight_scale_dtype=torch.bfloat16, layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + version=1, ), ) eager_results = model(activations) @@ -351,6 +368,7 @@ def test_export_QDQLayout(self): weight_granularity=PerGroup(64), weight_mapping_type=MappingType.SYMMETRIC, layout=QDQLayout(), + version=1, ), ) eager_results = model(activations) @@ -375,15 +393,12 @@ def test_export_QDQLayout(self): exported.graph_module.code ) - @parameterized.expand( + @parametrize( + "layout", [ - param(layout=layout) - for layout in [ - PackedLinearInt8DynamicActivationIntxWeightLayout(), - QDQLayout(), - ] + PackedLinearInt8DynamicActivationIntxWeightLayout(), + QDQLayout(), ], - name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) def test_serialization(self, layout): layers = [ @@ -399,6 +414,7 @@ def test_serialization(self, layout): weight_dtype=torch.int4, weight_granularity=PerGroup(64), layout=layout, + version=1, ), ) expected = model(activations) @@ -414,32 +430,16 @@ def test_serialization(self, layout): actual = model2(activations) self.assertTrue(torch.allclose(expected, actual)) - def test_moved_error(self): - from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig - - with self.assertRaisesRegex( - NotImplementedError, - "Int8DynamicActivationIntxWeightConfig has moved from torchao.experimental.quant_api to torchao.quantization.quant_api", - ): - config = Int8DynamicActivationIntxWeightConfig( # noqa: F841 - weight_dtype=torch.int4, - granularity=PerGroup(64), - ) - - @parameterized.expand( + @parametrize( + "group_size, mapping_type, act_mapping_type", [ - param( - group_size=group_size, - mapping_type=mapping_type, - act_mapping_type=act_mapping_type, - ) + (group_size, mapping_type, act_mapping_type) for group_size, mapping_type, act_mapping_type in zip( [32, 64], [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], ) ], - name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) def test_identical_to_Int8DynamicActivationInt4WeightConfig( self, group_size, mapping_type, act_mapping_type @@ -465,6 +465,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig( weight_mapping_type=mapping_type, weight_scale_dtype=None, act_mapping_type=act_mapping_type, + version=1, ), ) quantize_( @@ -479,15 +480,16 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig( sqnr = compute_error(model(activations), model_copy(activations)).item() self.assertTrue(sqnr == float("inf")) - @parameterized.expand( + @parametrize( + "weight_dtype, group_size, mapping_type, act_mapping_type, scale_dtype, model_dtype", [ - param( - weight_dtype=weight_dtype, - group_size=group_size, - mapping_type=mapping_type, - act_mapping_type=act_mapping_type, - scale_dtype=scale_dtype, - model_dtype=model_dtype, + ( + weight_dtype, + group_size, + mapping_type, + act_mapping_type, + scale_dtype, + model_dtype, ) for weight_dtype in list(getattr(torch, f"int{x}") for x in range(1, 9)) for group_size in [32, 64, 128] @@ -496,7 +498,6 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig( for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] for model_dtype in [torch.float32, torch.bfloat16, torch.float16] ], - name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) def test_identical_to_IntXQuantizationAwareTrainingConfig( self, @@ -530,12 +531,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( model = model.to(model_dtype) activations = activations.to(model_dtype) - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=is_act_symmetric, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( weight_dtype, group_size=group_size, is_symmetric=is_symmetric, @@ -563,6 +564,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( weight_mapping_type=mapping_type, weight_scale_dtype=scale_dtype, act_mapping_type=act_mapping_type, + version=1, ), ) converted_out = model(activations) @@ -570,18 +572,14 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( sqnr = compute_error(prepared_out, converted_out).item() self.assertTrue(sqnr == float("inf")) - @parameterized.expand( + @parametrize( + "group_size, scale_dtype, model_dtype", [ - param( - group_size=group_size, - scale_dtype=scale_dtype, - model_dtype=model_dtype, - ) + (group_size, scale_dtype, model_dtype) for group_size in [32, 64, 128] for scale_dtype in [torch.float32, torch.bfloat16, torch.float16] for model_dtype in [torch.float32, torch.bfloat16, torch.float16] ], - name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", ) def test_identical_to_Int8DynActInt4WeightQATQuantizer( self, group_size, scale_dtype, model_dtype @@ -617,6 +615,7 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer( weight_mapping_type=MappingType.SYMMETRIC, weight_scale_dtype=scale_dtype, act_mapping_type=MappingType.ASYMMETRIC, + version=1, ), ) converted_out1 = model(activations) @@ -655,7 +654,7 @@ def test_moe_quant_intx(self): out = model(x).clone() base_config = Int8DynamicActivationIntxWeightConfig( - layout=PackedLinearInt8DynamicActivationIntxWeightLayout() + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), version=1 ) moe_config = MoEQuantConfig( base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE @@ -677,5 +676,7 @@ def test_moe_quant_intx(self): self.assertGreater(compute_error(out_qc, out), 30) +instantiate_parametrized_tests(TestInt8DynamicActivationIntxWeight) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index 8fe21c6bd3..e0733520ff 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -16,7 +16,7 @@ unpack_from_marlin_qqq, ) from torchao.quantization.quant_api import ( - int8_dynamic_activation_int4_weight, + Int8DynamicActivationInt4WeightConfig, quantize_, ) from torchao.quantization.quant_primitives import ( @@ -24,7 +24,6 @@ _choose_qparams_and_quantize_affine_qqq, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @skip_if_rocm("ROCm enablement in progress") @@ -54,7 +53,7 @@ def test_marlin_qqq(self): modelq = copy.deepcopy(self.model) quantize_( modelq, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=group_size, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -67,7 +66,6 @@ def test_marlin_qqq(self): "Results are not close" ) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm development in progress") def test_marlin_qqq_compile(self): @@ -79,7 +77,7 @@ def test_marlin_qqq_compile(self): modelq = copy.deepcopy(self.model) quantize_( modelq, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=group_size, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index 425b881dba..61000babc1 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + import unittest import pytest @@ -27,11 +33,7 @@ quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - is_sm_at_least_90, -) +from torchao.utils import is_sm_at_least_90 if torch.version.hip is not None: pytest.skip( @@ -116,11 +118,10 @@ def _test_impl_moe_quant( def test_int4wo_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( - Int4WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE + Int4WeightOnlyConfig(version=1), + use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, ) tensor_impl_class = TensorCoreTiledAQTTensorImpl @@ -142,10 +143,8 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): self.skipTest("Need CUDA available") if not is_sm_at_least_90(): self.skipTest("Requires CUDA capability >= 9.0") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") - config = MoEQuantConfig(Int4WeightOnlyConfig()) + config = MoEQuantConfig(Int4WeightOnlyConfig(version=1)) tensor_impl_class = TensorCoreTiledAQTTensorImpl self._test_impl_moe_quant( @@ -164,8 +163,6 @@ def test_int4wo_base(self, name, num_tokens, fullgraph): def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE @@ -188,8 +185,6 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph): def test_int8wo_base(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_6: - self.skipTest("Test only enabled for 2.6+") config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -208,9 +203,6 @@ def test_int8wo_base(self, name, num_tokens, fullgraph): ] ) def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): - if not TORCH_VERSION_AT_LEAST_2_6: - self.skipTest("Test only enabled for 2.6+") - config = MoEQuantConfig(Int8WeightOnlyConfig()) tensor_impl_class = PlainAQTTensorImpl @@ -230,8 +222,6 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph): def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig( Int8DynamicActivationInt8WeightConfig(), @@ -255,8 +245,6 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph): def test_int8dq_base(self, name, num_tokens, fullgraph): if not torch.cuda.is_available(): self.skipTest("Need CUDA available") - if not TORCH_VERSION_AT_LEAST_2_5: - self.skipTest("Test only enabled for 2.5+") config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig()) base_class = LinearActivationQuantizedTensor diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index f51b89d6cd..84428ba8d7 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -14,20 +14,14 @@ from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase -from torchao.quantization.granularity import ( - PerAxis, - PerTensor, -) +from torchao.quantization.granularity import PerAxis, PerTensor from torchao.quantization.observer import ( + AffineQuantizedFixedQParamObserver, AffineQuantizedMinMaxObserver, + AffineQuantizedMSEObserver, ) -from torchao.quantization.quant_api import ( - insert_observers_, -) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) +from torchao.quantization.quant_api import insert_observers_ +from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain class TestQuantFlow(TestCase): @@ -145,6 +139,56 @@ def test_block_size_row_errors(self): for example_input in example_inputs: obs(example_input) + def test_mse_observer(self): + obs = AffineQuantizedMSEObserver( + MappingType.SYMMETRIC, + torch.int8, + granularity=PerAxis(0), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + zero_point_dtype=torch.int, + zero_point_domain=ZeroPointDomain.NONE, + steps=100, + run_once=True, + ) + example_input = torch.randn(10, 2048) + obs(example_input) + + scale, zero_point = obs.calculate_qparams() + self.assertIsNone(zero_point) + + minmax_obs = AffineQuantizedMinMaxObserver( + MappingType.SYMMETRIC, + torch.int8, + granularity=PerAxis(0), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + zero_point_dtype=torch.int, + zero_point_domain=ZeroPointDomain.NONE, + ) + minmax_obs(example_input) + min_val, max_val = minmax_obs.min_val, minmax_obs.max_val + assert torch.all( + obs.loss_fn(example_input, obs.min_val, obs.max_val) + <= obs.loss_fn(example_input, min_val, max_val) + 1e6 + ) + + def test_fixed_qparams_observer(self): + obs = AffineQuantizedFixedQParamObserver( + MappingType.SYMMETRIC, + torch.float8_e4m3fn, + granularity=PerAxis(0), + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float, + zero_point_dtype=torch.int, + zero_point_domain=ZeroPointDomain.NONE, + ) + example_input = torch.randn(10, 2048) + obs(example_input) + obs.set_qparams(torch.ones(2048)) + scale, zero_point = obs.calculate_qparams() + self.assertTrue(torch.allclose(scale, torch.ones(2048))) + class TestLinearObserver(TestCase): @common_utils.parametrize("observe_weight", [True, False]) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index f0404a2ac2..a6ef09e6e8 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -9,21 +9,28 @@ import copy import unittest -from typing import List +import warnings +from typing import List, Type import torch import torch.nn.functional as F -from parameterized import parameterized from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, +) from torchao import quantize_ -from torchao.float8.config import ScalingGranularity -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_tensor import LinearMMConfig +from torchao.core.config import AOBaseConfig +from torchao.float8.config import e4m3_dtype +from torchao.quantization import Float8Tensor from torchao.quantization.granularity import ( + Granularity, PerAxis, PerGroup, PerRow, + PerTensor, PerToken, ) from torchao.quantization.linear_quant_modules import ( @@ -32,18 +39,23 @@ ) from torchao.quantization.qat.api import ( ComposableQATQuantizer, - FakeQuantizeConfig, + FromIntXQuantizationAwareTrainingConfig, IntXQuantizationAwareTrainingConfig, - from_intx_quantization_aware_training, + QATConfig, + QATStep, initialize_fake_quantizers, - intx_quantization_aware_training, ) from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) +from torchao.quantization.qat.fake_quantize_config import ( + Float8FakeQuantizeConfig, + Int4WeightFakeQuantizeConfig, + IntxFakeQuantizeConfig, +) from torchao.quantization.qat.fake_quantizer import ( - FakeQuantizer, - _Float8RowwiseActivationFakeQuantizer, + Float8FakeQuantizer, + IntxFakeQuantizer, ) from torchao.quantization.qat.linear import ( FakeQuantizedLinear, @@ -54,11 +66,15 @@ from torchao.quantization.qat.utils import ( _fake_quantize_per_channel_group, _fake_quantize_per_token, - _Float8RowwiseFakeQuantize, _get_qmin_qmax, ) from torchao.quantization.quant_api import ( - int8_dynamic_activation_int4_weight, + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, ) from torchao.quantization.quant_primitives import ( MappingType, @@ -69,6 +85,7 @@ dequantize_affine, quantize_affine, ) +from torchao.quantization.quantize_.workflows import Int4PackingFormat from torchao.quantization.unified import ( TwoStepQuantizer, ) @@ -80,9 +97,9 @@ groupwise_affine_quantize_tensor, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_6, + _is_fbgemm_genai_gpu_available, + is_fbcode, + is_sm_at_least_89, ) # TODO: put this in a common test utils file @@ -108,19 +125,26 @@ def __init__(self): self.sub = Sub() self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float) - def example_inputs(self): - return (torch.randn(1, 512).to(torch.float),) + def example_inputs(self, device: torch.device = None): + return (torch.randn((1, 512), device=device).to(torch.float),) - def _get_all_weight_qparams(self) -> List[torch.Tensor]: + def _get_all_weight_scales(self) -> List[torch.Tensor]: return [ self.linear1.weight_fake_quantizer.scale, - self.linear1.weight_fake_quantizer.zero_point, self.sub.linear.weight_fake_quantizer.scale, - self.sub.linear.weight_fake_quantizer.zero_point, self.linear2.weight_fake_quantizer.scale, + ] + + def _get_all_weight_zero_points(self) -> List[torch.Tensor]: + return [ + self.linear1.weight_fake_quantizer.zero_point, + self.sub.linear.weight_fake_quantizer.zero_point, self.linear2.weight_fake_quantizer.zero_point, ] + def _get_all_weight_qparams(self) -> List[torch.Tensor]: + return self._get_all_weight_scales() + self._get_all_weight_zero_points() + def forward(self, x): x = self.linear1(x) x = self.sub(x) @@ -187,12 +211,9 @@ def forward(self, x): return x -class TestQAT(unittest.TestCase): +class TestQAT(TestCase): SEED = 123 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) @@ -237,9 +258,6 @@ def test_fake_quantize_per_channel_group(self): ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_per_token(self): (qmin, qmax) = _get_qmin_qmax(8) @@ -337,9 +355,6 @@ def _set_ptq_weight( else: raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_linear(self): from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear @@ -370,9 +385,6 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer(self): from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -408,9 +420,6 @@ def test_qat_8da4w_quantizer(self): ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -422,9 +431,6 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -483,9 +489,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. @@ -582,9 +585,6 @@ def _test_qat_quantized_gradients(self, quantizer): optimizer.step() current_step += 1 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_quantizer_gradients(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -651,9 +651,6 @@ def test_qat_4w_primitives(self): self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): from torchao.quantization.GPTQ import WeightOnlyInt4Linear @@ -689,18 +686,12 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_4w_quantizer_gradients(self): from torchao.quantization.qat import Int4WeightOnlyQATQuantizer quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer @@ -786,9 +777,6 @@ def test_composable_qat_quantizer(self): values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"] ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_4w_embedding(self): from torchao._executorch_ops import ( _quantized_decomposed_quantize_per_channel_group_wrapper, @@ -829,26 +817,28 @@ def test_qat_4w_embedding(self): def test_fake_quantize_config_granularity(self): """ - Test initialization and property setting of `FakeQuantizeConfig`'s granularity. + Test initialization and property setting of `IntxFakeQuantizeConfig`'s granularity. """ # per token - per_token_config1 = FakeQuantizeConfig(torch.int8, PerToken()) - per_token_config2 = FakeQuantizeConfig(torch.int8, "per_token") + per_token_config1 = IntxFakeQuantizeConfig(torch.int8, PerToken()) + per_token_config2 = IntxFakeQuantizeConfig(torch.int8, "per_token") self.assertIsInstance(per_token_config1.granularity, PerToken) self.assertIsInstance(per_token_config2.granularity, PerToken) # per channel - per_channel_config1 = FakeQuantizeConfig(torch.int8, PerAxis(0)) - per_channel_config2 = FakeQuantizeConfig(torch.int8, "per_channel") + per_channel_config1 = IntxFakeQuantizeConfig(torch.int8, PerAxis(0)) + per_channel_config2 = IntxFakeQuantizeConfig(torch.int8, "per_channel") self.assertIsInstance(per_channel_config1.granularity, PerAxis) self.assertIsInstance(per_channel_config2.granularity, PerAxis) self.assertEqual(per_channel_config1.granularity.axis, 0) self.assertEqual(per_channel_config2.granularity.axis, 0) # per group - per_group_config1 = FakeQuantizeConfig(torch.int8, PerGroup(32)) - per_group_config2 = FakeQuantizeConfig(torch.int8, "per_group", group_size=32) - per_group_config3 = FakeQuantizeConfig(torch.int8, group_size=32) + per_group_config1 = IntxFakeQuantizeConfig(torch.int8, PerGroup(32)) + per_group_config2 = IntxFakeQuantizeConfig( + torch.int8, "per_group", group_size=32 + ) + per_group_config3 = IntxFakeQuantizeConfig(torch.int8, group_size=32) self.assertIsInstance(per_group_config1.granularity, PerGroup) self.assertIsInstance(per_group_config2.granularity, PerGroup) self.assertIsInstance(per_group_config3.granularity, PerGroup) @@ -869,48 +859,48 @@ def test_fake_quantize_config_granularity(self): def test_fake_quantize_config_granularity_error_cases(self): """ - Test incorrect settings of `FakeQuantizeConfig`'s granularity. + Test incorrect settings of `IntxFakeQuantizeConfig`'s granularity. """ # no granularity provided with self.assertRaisesRegex( ValueError, "`granularity` or `group_size` must be set" ): - FakeQuantizeConfig(torch.int8) + IntxFakeQuantizeConfig(torch.int8) # group_size with conflicting granularity msg = "`group_size` conflicts with granularity" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, PerToken(), group_size=32) + IntxFakeQuantizeConfig(torch.int8, PerToken(), group_size=32) with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, PerGroup(64), group_size=32) + IntxFakeQuantizeConfig(torch.int8, PerGroup(64), group_size=32) with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_token", group_size=32) + IntxFakeQuantizeConfig(torch.int8, "per_token", group_size=32) # 'per_group' but no group_size msg = "Granularity was 'per_group' but no `group_size` was set" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_group") + IntxFakeQuantizeConfig(torch.int8, "per_group") # not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig(torch.int8, PerRow()) + IntxFakeQuantizeConfig(torch.int8, PerRow()) with self.assertRaisesRegex(ValueError, "Only axis=0 is supported"): - FakeQuantizeConfig(torch.int8, PerAxis(1)) + IntxFakeQuantizeConfig(torch.int8, PerAxis(1)) with self.assertRaisesRegex(ValueError, "Unexpected granularity"): - FakeQuantizeConfig(torch.int8, "blah") + IntxFakeQuantizeConfig(torch.int8, "blah") with self.assertRaisesRegex(ValueError, "unexpected type"): - FakeQuantizeConfig(torch.int8, 1234) + IntxFakeQuantizeConfig(torch.int8, 1234) def test_fake_quantize_config_mapping_type(self): """ - Test initialization and property setting of `FakeQuantizeConfig`'s mapping type. + Test initialization and property setting of `IntxFakeQuantizeConfig`'s mapping type. """ # symmetric - symmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token") - symmetric_config2 = FakeQuantizeConfig( + symmetric_config1 = IntxFakeQuantizeConfig(torch.int8, "per_token") + symmetric_config2 = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=True ) - symmetric_config3 = FakeQuantizeConfig( + symmetric_config3 = IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.SYMMETRIC ) self.assertEqual(symmetric_config1.mapping_type, MappingType.SYMMETRIC) @@ -921,10 +911,10 @@ def test_fake_quantize_config_mapping_type(self): self.assertTrue(symmetric_config3.is_symmetric) # asymmetric - asymmetric_config1 = FakeQuantizeConfig( + asymmetric_config1 = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False ) - asymmetric_config2 = FakeQuantizeConfig( + asymmetric_config2 = IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.ASYMMETRIC ) self.assertEqual(asymmetric_config1.mapping_type, MappingType.ASYMMETRIC) @@ -940,66 +930,62 @@ def test_fake_quantize_config_mapping_type(self): # bad config1: both mapping_type and is_symmetric are set msg = "Cannot set both `mapping_type` and `is_symmetric`" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False ) # bad config2: not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR ) def test_fake_quantize_config_dtype(self): """ - Test that unsupported dtypes are caught in `FakeQuantizeConfig`. + Test that unsupported dtypes are caught in `IntxFakeQuantizeConfig`. """ msg = "Unsupported dtype" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int16, "per_token") + IntxFakeQuantizeConfig(torch.int16, "per_token") with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int32, "per_token") + IntxFakeQuantizeConfig(torch.int32, "per_token") with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.bfloat16, "per_token") + IntxFakeQuantizeConfig(torch.bfloat16, "per_token") with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.float32, "per_token") + IntxFakeQuantizeConfig(torch.float32, "per_token") # OK - if TORCH_VERSION_AT_LEAST_2_3: - FakeQuantizeConfig(torch.uint1, "per_token") - FakeQuantizeConfig(torch.uint2, "per_token") - FakeQuantizeConfig(torch.uint3, "per_token") - FakeQuantizeConfig(torch.uint4, "per_token") - FakeQuantizeConfig(torch.uint5, "per_token") - FakeQuantizeConfig(torch.uint6, "per_token") - FakeQuantizeConfig(torch.uint7, "per_token") - FakeQuantizeConfig(torch.uint8, "per_token") - FakeQuantizeConfig(TorchAODType.INT1, "per_token") - FakeQuantizeConfig(TorchAODType.INT2, "per_token") - FakeQuantizeConfig(TorchAODType.INT3, "per_token") - FakeQuantizeConfig(TorchAODType.INT4, "per_token") - FakeQuantizeConfig(TorchAODType.INT5, "per_token") - FakeQuantizeConfig(TorchAODType.INT6, "per_token") - FakeQuantizeConfig(TorchAODType.INT7, "per_token") - FakeQuantizeConfig(torch.int8, "per_token") + IntxFakeQuantizeConfig(torch.uint1, "per_token") + IntxFakeQuantizeConfig(torch.uint2, "per_token") + IntxFakeQuantizeConfig(torch.uint3, "per_token") + IntxFakeQuantizeConfig(torch.uint4, "per_token") + IntxFakeQuantizeConfig(torch.uint5, "per_token") + IntxFakeQuantizeConfig(torch.uint6, "per_token") + IntxFakeQuantizeConfig(torch.uint7, "per_token") + IntxFakeQuantizeConfig(torch.uint8, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT1, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT2, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT3, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT4, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT5, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT6, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT7, "per_token") + IntxFakeQuantizeConfig(torch.int8, "per_token") def test_fake_quantize_config_dynamic_and_range_learning(self): """ Test that `is_dynamic` and `range_learning` cannot both be set. """ - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=True, range_learning=False ) - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, range_learning=True ) with self.assertRaisesRegex(ValueError, "not compatible"): - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=True, range_learning=True ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_linear_8da4w(self): """ Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. @@ -1010,10 +996,12 @@ def test_fake_quantized_linear_8da4w(self): 256, 688, bias=False, - activation_config=FakeQuantizeConfig( + activation_config=IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False ), - weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), + weight_config=IntxFakeQuantizeConfig( + TorchAODType.INT4, group_size=group_size + ), ) def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: @@ -1051,15 +1039,12 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_8da4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_linear_4w(self): """ Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. """ group_size = 128 - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.uint4, group_size=group_size, is_symmetric=False, @@ -1100,9 +1085,6 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_replace_linear_8da4w(self): module = torch.nn.ModuleList( [ @@ -1122,9 +1104,6 @@ def test_replace_linear_8da4w(self): assert isinstance(module[0], Int8DynActInt4WeightQATLinear) assert isinstance(module[1], Int8DynActInt4WeightQATLinear) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_replace_linear_int4(self): module = torch.nn.ModuleList( [torch.nn.Linear(in_features=256, out_features=50, bias=True)] @@ -1157,9 +1136,6 @@ def test_replace_linear_int4(self): ) assert isinstance(module[0], Int4WeightOnlyQATLinear) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantized_embedding_4w(self): """ Test that we can express int4 per group symmetric weight only fake quantization @@ -1172,7 +1148,9 @@ def test_fake_quantized_embedding_4w(self): fq_embedding = FakeQuantizedEmbedding( num_embeddings, embedding_dim, - weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), + weight_config=IntxFakeQuantizeConfig( + TorchAODType.INT4, group_size=group_size + ), ) def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: @@ -1195,9 +1173,6 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_prototype_bc(self): """ Just to make sure we can import all the old prototype paths. @@ -1223,8 +1198,8 @@ def test_qat_prototype_bc(self): Int8DynActInt4WeightQATQuantizerModuleSwap, ) from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( # noqa: F401, F811 - AffineFakeQuantizedTensor, - to_affine_fake_quantized, + _AffineFakeQuantizedTensor, + _to_affine_fake_quantized, ) from torchao.quantization.prototype.qat.api import ( # noqa: F401, F811 ComposableQATQuantizer, @@ -1251,14 +1226,66 @@ def test_qat_prototype_bc(self): Int8DynActInt4WeightQATQuantizer, ) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) - def test_quantize_api_standalone(self): + def test_qat_config_init(self): + """ + Test that the correct errors are thrown if `QATConfig` is not instantiated properly. + """ + base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) + fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel") + + # OK + QATConfig(base_config, step="prepare") + QATConfig(base_config, step="convert") + QATConfig(base_config, step=QATStep.PREPARE) + QATConfig(base_config, step=QATStep.CONVERT) + QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare") + QATConfig(weight_config=fq_config, step="prepare") + QATConfig(step="convert") + + # OK: good step values + self.assertEqual(QATConfig(base_config).step, "prepare") + self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare") + self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert") + + # Bad step + with self.assertRaisesRegex(ValueError, "`step` must be one of"): + QATConfig(base_config, step="blah") + + # Step was not a keyword arg + with self.assertRaisesRegex( + TypeError, "4 positional arguments but 5 were given" + ): + QATConfig(base_config, None, None, "prepare") + + # No configs were provided in prepare step + with self.assertRaisesRegex( + ValueError, + "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step", + ): + QATConfig(step="prepare") + + # Clashing configs are provided + with self.assertRaisesRegex(ValueError, "Cannot specify both"): + QATConfig(base_config, weight_config=fq_config, step="prepare") + with self.assertRaisesRegex(ValueError, "Cannot specify both"): + QATConfig(base_config, activation_config=fq_config, step="prepare") + with self.assertRaisesRegex( + ValueError, "Cannot specify .* in the convert step" + ): + QATConfig(weight_config=fq_config, step="convert") + + # FakeQuantizeConfigBase was specified as base_config + with self.assertRaisesRegex( + ValueError, + "was passed as `base_config`. Did you mean to do the following instead?", + ): + QATConfig(fq_config, step="prepare") + + def test_quantize_api_prepare(self): """ Test that the following: - quantize_(model, intx_quantization_aware_training(...)) + quantize_(model, QATConfig(...)) can produce the same results as `ComposableQATQuantizer`. """ @@ -1283,20 +1310,15 @@ def test_quantize_api_standalone(self): baseline_model = baseline_quantizer.prepare(baseline_model) # quantize_ API - activation_config = FakeQuantizeConfig( - torch.int8, - "per_token", - is_symmetric=False, + act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + qat_config1 = QATConfig( + activation_config=act_config, weight_config=weight_config ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + qat_config2 = QATConfig(weight_config=weight_config) + quantize_(m, qat_config1) quantize_( - m, - intx_quantization_aware_training(activation_config, weight_config), - ) - quantize_( - m, - intx_quantization_aware_training(weight_config=weight_config), - filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding) ) # Compare model values @@ -1307,45 +1329,31 @@ def test_quantize_api_standalone(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_api_errors(self): """ Test that we throw exceptions with helpful error messages if `quantize_` runs into unexpected configurations. """ - my_config = FakeQuantizeConfig(torch.int8, group_size=32) + fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32) + qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config) m = M3() # Embedding currently only supports weight-only quantization with self.assertRaisesRegex( ValueError, "Activation fake quantization is not supported for embedding" ): - quantize_( - m, - intx_quantization_aware_training(my_config, my_config), - lambda m, _: isinstance(m, torch.nn.Embedding), - ) + quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding)) # Only linear and embedding are supported currently with self.assertRaisesRegex(ValueError, "does not have QAT support"): - quantize_( - m, - intx_quantization_aware_training(my_config, my_config), - lambda m, _: isinstance(m, torch.nn.ReLU), - ) + quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) - def test_quantize_api_convert_path(self): + def test_quantize_api_e2e(self): """ Test that the following: - quantize_(model, intx_quantization_aware_training(...)) - quantize_(model, from_intx_quantization_aware_training(...)) - quantize_(model, int8_dynamic_activation_int4_weight()) + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare")) + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert")) can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert. """ @@ -1363,16 +1371,8 @@ def test_quantize_api_convert_path(self): baseline_model = baseline_quantizer.prepare(baseline_model) # quantize_ prepare - activation_config = FakeQuantizeConfig( - torch.int8, - "per_token", - is_symmetric=False, - ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) - quantize_( - m, - intx_quantization_aware_training(activation_config, weight_config), - ) + base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size) + quantize_(m, QATConfig(base_config, step="prepare")) # Compare prepared values torch.manual_seed(self.SEED) @@ -1386,8 +1386,7 @@ def test_quantize_api_convert_path(self): baseline_model = baseline_quantizer.convert(baseline_model) # quantize_ convert - quantize_(m, from_intx_quantization_aware_training()) - quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size)) + quantize_(m, QATConfig(base_config, step="convert")) # Compare converted values torch.manual_seed(self.SEED) @@ -1397,16 +1396,13 @@ def test_quantize_api_convert_path(self): baseline_out = baseline_model(*x2) torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_fake_quantize_config_torch_intx(self): """ - Test that `FakeQuantizeConfig` works with torch.intx. + Test that `IntxFakeQuantizeConfig` works with torch.intx. """ group_size = 16 - config1 = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) - config2 = FakeQuantizeConfig(torch.int4, group_size=group_size) + config1 = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + config2 = IntxFakeQuantizeConfig(torch.int4, group_size=group_size) linear1 = FakeQuantizedLinear(32, 64, weight_config=config1) linear2 = FakeQuantizedLinear(32, 64, weight_config=config2) linear2.weight = linear1.weight @@ -1417,64 +1413,50 @@ def test_fake_quantize_config_torch_intx(self): out2 = linear2(*x2) torch.testing.assert_close(out1, out2, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_fake_quantizer_repr(self): """ - Test that `repr(FakeQuantizer(config))` exposes useful config details. + Test that `repr(IntxFakeQuantizer(config))` exposes useful config details. """ - config = FakeQuantizeConfig(torch.int4, group_size=128) - fake_quantizer = FakeQuantizer(config) + config = IntxFakeQuantizeConfig(torch.int4, group_size=128) + fake_quantizer = IntxFakeQuantizer(config) fake_quantizer_repr = repr(fake_quantizer) self.assertTrue("dtype=torch.int4" in fake_quantizer_repr) self.assertTrue("group_size=128" in fake_quantizer_repr) self.assertTrue("PerGroup" in fake_quantizer_repr) self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_linear_bias(self): """ Test that QAT supports linear bias. """ m = ModelWithLinearBias() - activation_config = FakeQuantizeConfig( - torch.int8, "per_token", is_symmetric=False - ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32) - quantize_( - m, - intx_quantization_aware_training(activation_config, weight_config), + act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32) + qat_config = QATConfig( + activation_config=act_config, weight_config=weight_config ) + quantize_(m, qat_config) example_inputs = m.example_inputs() m(*example_inputs) - @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): """ Test that the following produce the exact same numerics: - 1. FakeQuantizer with asymmetric per_token config + 1. IntxFakeQuantizer with asymmetric per_token config 2. torchao.quantization.utils.per_token_dynamic_quant """ from torchao.quantization.utils import per_token_dynamic_quant torch.manual_seed(self.SEED) x = torch.randn(1, 235, 2048).to(dtype) - config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - fake_quantizer = FakeQuantizer(config) + config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + fake_quantizer = IntxFakeQuantizer(config) fake_quantizer_out = fake_quantizer(x) baseline_out = per_token_dynamic_quant(x) torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0) - @parameterized.expand([(torch.float32,), (torch.bfloat16,), (torch.float16,)]) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) + @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): """ Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces @@ -1513,12 +1495,9 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): ) self.assertEqual(len(non_inf_sqnr), 0, fail_message) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_config_eps(self): """ - Test that users can set arbitrary eps value in `FakeQuantizeConfig`. + Test that users can set arbitrary eps value in `IntxFakeQuantizeConfig`. """ eps = 0.00123 x = torch.randn(2, 3).to(torch.float32) @@ -1532,19 +1511,16 @@ def test_fake_quantize_config_eps(self): eps=eps, ) expected_out = _fake_quantize_per_token(x, scale, zp, -128, 127) - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, eps=eps, ) - fake_quantizer = FakeQuantizer(config) + fake_quantizer = IntxFakeQuantizer(config) actual_out = fake_quantizer(x) torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_qat_8da4w_eps(self): """ Test that the 8da4w QAT flow uses the expected eps. @@ -1591,22 +1567,21 @@ def test_qat_8da4w_eps(self): actual_out = converted_model.linear1(x) torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) - def test_fake_quantizer_range_learning(self): + @parametrize("is_symmetric", [True, False]) + def test_fake_quantizer_range_learning(self, is_symmetric): """ - Test that range learning requires `FakeQuantizer`s to be initialized correctly. + Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly. """ - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, range_learning=True, scale_precision=torch.float32, zero_point_precision=torch.float32, + is_symmetric=is_symmetric, ) - fake_quantizer = FakeQuantizer(config) + fake_quantizer = IntxFakeQuantizer(config) example_inputs = (torch.randn(2, 3),) # Not initialized, should fail @@ -1624,29 +1599,32 @@ def test_fake_quantizer_range_learning(self): initialize_fake_quantizers(fake_quantizer, example_inputs) self.assertTrue(fake_quantizer._initialized) self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter) - self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter) self.assertTrue(fake_quantizer.scale.requires_grad) - self.assertTrue(fake_quantizer.zero_point.requires_grad) + if config.is_symmetric: + self.assertFalse(isinstance(fake_quantizer.zero_point, torch.nn.Parameter)) + self.assertTrue(torch.all(fake_quantizer.zero_point == 0)) + else: + self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter) + self.assertTrue(fake_quantizer.zero_point.requires_grad) fake_quantizer(*example_inputs) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) - def test_qat_range_learning(self): + @parametrize("is_symmetric", [True, False]) + def test_qat_range_learning(self, is_symmetric): """ Test end-to-end QAT flow with range learning. """ - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, range_learning=True, scale_precision=torch.float32, zero_point_precision=torch.float32, + is_symmetric=is_symmetric, ) m = M() example_inputs = m.example_inputs() - quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config)) + quantize_(m, QATConfig(weight_config=config)) # Not initialized, should fail for t in m._get_all_weight_qparams(): @@ -1662,10 +1640,21 @@ def test_qat_range_learning(self): # All scales and zero points should be in `m.parameters()` initialize_fake_quantizers(m, example_inputs) params = set(m.parameters()) - for t in m._get_all_weight_qparams(): - self.assertIsInstance(t, torch.nn.Parameter) - self.assertTrue(t.requires_grad) - self.assertTrue(t in params) + + for scale in m._get_all_weight_scales(): + self.assertIsInstance(scale, torch.nn.Parameter) + self.assertTrue(scale.requires_grad) + self.assertTrue(scale in params) + + for zero_point in m._get_all_weight_zero_points(): + if config.is_symmetric: + self.assertFalse(isinstance(zero_point, torch.nn.Parameter)) + self.assertTrue(torch.all(zero_point == 0)) + else: + self.assertIsInstance(zero_point, torch.nn.Parameter) + self.assertTrue(zero_point.requires_grad) + self.assertTrue(zero_point in params) + m(*example_inputs) # Simulate training @@ -1694,27 +1683,6 @@ def test_qat_range_learning(self): self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) self.assertFalse(torch.equal(new_weight, prev_weight)) - def test_float8_rowwise_fake_quantize(self): - """ - Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8Tensor`. - """ - torch.manual_seed(self.SEED) - dtype = torch.float8_e4m3fn - x = torch.randn(32, 64) - axiswise_dim = 0 - out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim) - out_expected = hp_tensor_to_float8_dynamic( - x, - dtype, - LinearMMConfig(), - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=axiswise_dim, - ).to_original_precision() - torch.testing.assert_close(out, out_expected, atol=0, rtol=0) - - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" - ) def test_qat_fp8a4w_quantizer(self): """ Test basic model training with `Float8ActInt4WeightQATQuantizer`. @@ -1726,9 +1694,10 @@ def test_qat_fp8a4w_quantizer(self): for linear in [m.linear1, m.sub.linear, m.linear2]: self.assertIsInstance(linear, FakeQuantizedLinear) self.assertIsInstance( - linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer + linear.activation_fake_quantizer, + Float8FakeQuantizer, ) - self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer) + self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer) prev_weight = copy.deepcopy(m.linear1.weight) # Simulate training @@ -1749,6 +1718,646 @@ def test_qat_fp8a4w_quantizer(self): self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0) self.assertFalse(torch.equal(new_weight, prev_weight)) + def test_legacy_quantize_api_e2e(self): + """ + Test that the following two APIs are numerically equivalent: + + New API: + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare")) + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert")) + + Old API: + quantize_(model, IntXQuantizationAwareTrainingConfig(...)) + quantize_(model, FromIntXQuantizationAwareTrainingConfig()) + quantize_(model, Int8DynamicActivationInt4WeightConfig()) + """ + group_size = 16 + torch.manual_seed(self.SEED) + m = M() + baseline_model = copy.deepcopy(m) + + # Baseline prepare + act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config) + quantize_(baseline_model, old_qat_config) + + # QATConfig prepare + base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size) + quantize_(m, QATConfig(base_config, step="prepare")) + + # Compare prepared values + torch.manual_seed(self.SEED) + x = m.example_inputs() + x2 = copy.deepcopy(x) + out = m(*x) + baseline_out = baseline_model(*x2) + torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) + + # Baseline convert + quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig()) + quantize_(baseline_model, base_config) + + # quantize_ convert + quantize_(m, QATConfig(base_config, step="convert")) + + # Compare converted values + torch.manual_seed(self.SEED) + x = m.example_inputs() + x2 = copy.deepcopy(x) + out = m(*x) + baseline_out = baseline_model(*x2) + torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) + + def test_qat_api_deprecation(self): + """ + Test that the appropriate deprecation warning is logged exactly once per class. + """ + from torchao.quantization.qat import ( + FakeQuantizeConfig, + FakeQuantizer, + from_intx_quantization_aware_training, + intx_quantization_aware_training, + ) + + # Reset deprecation warning state, otherwise we won't log warnings here + warnings.resetwarnings() + + # Map from deprecated API to the args needed to instantiate it + deprecated_apis_to_args = { + IntXQuantizationAwareTrainingConfig: (), + FromIntXQuantizationAwareTrainingConfig: (), + intx_quantization_aware_training: (), + from_intx_quantization_aware_training: (), + FakeQuantizeConfig: (torch.int8, "per_channel"), + FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),), + } + + with warnings.catch_warnings(record=True) as _warnings: + # Call each deprecated API twice + for cls, args in deprecated_apis_to_args.items(): + cls(*args) + cls(*args) + + # Each call should trigger the warning only once + self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) + for w in _warnings: + self.assertIn( + "is deprecated and will be removed in a future release", + str(w.message), + ) + + def test_qat_api_convert_no_quantization(self): + """ + Test that `QATConfig(step="convert")` swaps back to nn modules without quantization. + """ + torch.manual_seed(self.SEED) + m = M() + baseline_model = copy.deepcopy(m) + + # Prepare swaps to FakeQuantizedLinear + quantize_(m, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare")) + self.assertEqual(type(m.linear1), FakeQuantizedLinear) + self.assertEqual(type(m.sub.linear), FakeQuantizedLinear) + self.assertEqual(type(m.linear2), FakeQuantizedLinear) + + # Convert without a `base_config` swaps back to nn.Linear + quantize_(m, QATConfig(step="convert")) + self.assertEqual(type(m.linear1), torch.nn.Linear) + self.assertEqual(type(m.sub.linear), torch.nn.Linear) + self.assertEqual(type(m.linear2), torch.nn.Linear) + + # Model weights should be identical to before + torch.manual_seed(self.SEED) + x = m.example_inputs() + x2 = copy.deepcopy(x) + out = m(*x) + baseline_out = baseline_model(*x2) + torch.testing.assert_close(out, baseline_out, atol=0, rtol=0) + + def test_float8_fake_quantize_config(self): + """ + Test that the correct errors are thrown if `Float8FakeQuantizeConfig` is not instantiated properly. + """ + # OK + Float8FakeQuantizeConfig(torch.float8_e4m3fn) + Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerRow()) + Float8FakeQuantizeConfig(torch.float8_e4m3fn, PerTensor()) + + with self.assertRaisesRegex(ValueError, "not a float8 dtype"): + Float8FakeQuantizeConfig(torch.int8) + with self.assertRaisesRegex( + ValueError, "Please specify the granularity object instead of the class" + ): + Float8FakeQuantizeConfig(granularity=PerRow) + with self.assertRaisesRegex( + ValueError, "Expected PerRow or PerTensor granularity" + ): + Float8FakeQuantizeConfig(granularity=PerToken()) + + @parametrize("granularity", [PerTensor(), PerRow()]) + def test_float8_fake_quantize(self, granularity: Granularity): + """ + Test that `Float8FakeQuantizer` is numerically close to `Float8Tensor`. + """ + dtype = torch.float8_e4m3fn + fq_config = Float8FakeQuantizeConfig(dtype, granularity) + fake_quantizer = Float8FakeQuantizer(fq_config) + torch.manual_seed(self.SEED) + x = torch.randn(32, 64) + out = fake_quantizer(x) + out_expected = Float8Tensor.from_hp(x, dtype, granularity).dequantize() + sqnr = compute_error(out, out_expected) + self.assertGreater(sqnr, 16) + + def _test_quantize_api_against_ptq( + self, + base_config: AOBaseConfig, + target_prepare_sqnr: float, + target_convert_sqnr: float, + dtype: torch.dtype = torch.bfloat16, + module_type: str = "linear", + ): + """ + Test the following: + + quantize_(model, QATConfig(base_config, step="prepare")) + quantize_(model, QATConfig(base_config, step="convert")) + + and compare model outputs of each step against: + + quantize_(model, base_config) + """ + torch.manual_seed(self.SEED) + + if module_type == "linear": + m = M().to(dtype).cuda() + example_inputs = (m.example_inputs()[0].to(dtype).cuda(),) + filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear) + elif module_type == "embedding": + m = M3().to(dtype).cuda() + example_inputs = (m.example_inputs()[0].cuda(),) + filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding) + else: + raise ValueError(f"Unknown module type {module_type}") + + # baseline + m_baseline = copy.deepcopy(m) + quantize_(m_baseline, base_config, filter_fn) + out_baseline = m_baseline(*example_inputs) + + # compare prepare + quantize_(m, QATConfig(base_config, step="prepare"), filter_fn) + out_prepared = m(*example_inputs) + prepare_sqnr = compute_error(out_prepared, out_baseline) + + self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr) + + # compare convert + quantize_(m, QATConfig(base_config, step="convert"), filter_fn) + out_converted = m(*example_inputs) + convert_sqnr = compute_error(out_converted, out_baseline) + self.assertGreaterEqual(convert_sqnr, target_convert_sqnr) + + @parametrize("granularity", [PerTensor(), PerRow()]) + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") + def test_quantize_api_fp8_fp8(self, granularity: Granularity): + """ + Test the following: + quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="prepare")) + quantize_(model, QATConfig(Float8DynamicActivationFloat8Weight(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), + target_prepare_sqnr=15, + target_convert_sqnr=float("inf"), + ) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + def test_quantize_api_fp8_int4(self): + """ + Test the following: + quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="prepare")) + quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Float8DynamicActivationInt4WeightConfig(), + target_prepare_sqnr=22, + target_convert_sqnr=float("inf"), + ) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + @unittest.skipIf(is_fbcode(), "cutlass cannot initialize") + @parametrize("version", [1, 2]) + @parametrize( + "packing_format", [Int4PackingFormat.PLAIN, Int4PackingFormat.PRESHUFFLED] + ) + def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat): + """ + Test the following: + quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="prepare")) + quantize_(model, QATConfig(Int4WeightOnlyConfig(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Int4WeightOnlyConfig(version=version, int4_packing_format=packing_format), + target_prepare_sqnr=45 if version == 2 else 12, + target_convert_sqnr=float("inf"), + ) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_quantize_api_int8_int4(self): + """ + Test the following: + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare")) + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Int8DynamicActivationInt4WeightConfig(group_size=32), + target_prepare_sqnr=30, + target_convert_sqnr=float("inf"), + ) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @parametrize( + "weight_dtype, weight_granularity, dtype", + [ + (weight_dtype, weight_granularity, dtype) + for weight_dtype in [getattr(torch, f"int{i}") for i in range(2, 9)] + for weight_granularity in [PerGroup(32), PerAxis(0)] + for dtype in [torch.bfloat16, torch.float32] + ], + ) + def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype): + """ + Test the following: + quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="prepare")) + quantize_(model, QATConfig(Int8DynamicActivationIntxWeightConfig(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, weight_granularity=weight_granularity + ), + target_prepare_sqnr=float("inf"), + target_convert_sqnr=float("inf"), + dtype=dtype, + ) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @parametrize( + "weight_dtype, granularity, dtype, module_type", + [ + (weight_dtype, granularity, dtype, module_type) + for weight_dtype in [getattr(torch, f"int{i}") for i in range(2, 9)] + for granularity in [PerGroup(32), PerAxis(0)] + for dtype in [torch.bfloat16, torch.float32] + for module_type in ["linear", "embedding"] + ], + ) + def test_quantize_api_intx(self, weight_dtype, granularity, dtype, module_type): + """ + Test the following: + quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="prepare")) + quantize_(model, QATConfig(IntxWeightOnlyConfig(), step="convert")) + """ + self._test_quantize_api_against_ptq( + IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), + target_prepare_sqnr=float("inf"), + target_convert_sqnr=float("inf"), + dtype=dtype, + module_type=module_type, + ) + + def test_infer_fp8_int4_config(self): + """ + Test that fake quantize configs are correctly inferred from + `Float8DynamicActivationInt4WeightConfig`. + """ + from torchao.quantization.qat.fake_quantize_config import ( + _infer_fake_quantize_configs, + ) + + base_config = Float8DynamicActivationInt4WeightConfig() + (act_config, weight_config) = _infer_fake_quantize_configs(base_config) + self.assertIsInstance(act_config, Float8FakeQuantizeConfig) + self.assertEqual(act_config.dtype, e4m3_dtype) + self.assertIsInstance(act_config.granularity, PerRow) + self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig) + self.assertEqual(weight_config.group_size, 128) + self.assertEqual(weight_config.activation_dtype, e4m3_dtype) + + def test_infer_int4_weight_only_config(self): + """ + Test that fake quantize configs are correctly inferred from `Int4WeightOnlyConfig`. + """ + from torchao.quantization.qat.fake_quantize_config import ( + _infer_fake_quantize_configs, + ) + + base_config = Int4WeightOnlyConfig(version=1) + (act_config, weight_config) = _infer_fake_quantize_configs(base_config) + self.assertIsNone(act_config) + self.assertIsInstance(weight_config, IntxFakeQuantizeConfig) + self.assertEqual(weight_config.dtype, torch.uint4) + self.assertEqual(weight_config.group_size, 128) + self.assertFalse(weight_config.is_symmetric) + + base_config = Int4WeightOnlyConfig(version=2) + (act_config, weight_config) = _infer_fake_quantize_configs(base_config) + self.assertIsNone(act_config) + self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig) + self.assertEqual(weight_config.group_size, 128) + self.assertEqual(weight_config.activation_dtype, torch.bfloat16) + + @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") + def test_quantize_api_nvfp4(self): + """ + Test the following: + quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare")) + quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert")) + """ + from torchao.prototype.mx_formats import NVFP4InferenceConfig + + self._test_quantize_api_against_ptq( + NVFP4InferenceConfig(), + target_prepare_sqnr=8, + target_convert_sqnr=float("inf"), + ) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @parametrize("use_per_tensor_scale", [True, False]) + def test_qat_nvfp4(self, use_per_tensor_scale: bool): + """ + Test QAT with `NVFP4FakeQuantizeConfig`. + """ + from torchao.prototype.qat import NVFP4FakeQuantizeConfig + + torch.manual_seed(self.SEED) + m = M().cuda() + baseline_model = copy.deepcopy(m) + qat_config = QATConfig( + activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale), + weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale), + step="prepare", + ) + quantize_(m, qat_config) + + # Compare prepared values + torch.manual_seed(self.SEED) + x = m.example_inputs("cuda") + out = m(*x) + baseline_out = baseline_model(*x) + sqnr = compute_error(out, baseline_out).item() + self.assertGreater(sqnr, 24) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + @unittest.skipIf(is_fbcode(), "triton compilation error") + def test_fbgemm_fp8_primitives(self): + """ + Compare numerics between: + (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row + (2) Our reference QAT version in `Float8FakeQuantizer` + """ + from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row + + from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, + ) + + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x2 = copy.deepcopy(x1) + + # (1) Just call `quantize_fp8_row` + (q1, scale1) = quantize_fp8_row(x1) + + # (2) Our reference implementation for QAT without the dequantize + scale2 = _choose_scale_float8( + x2, + (1, x2.shape[-1]), + torch.float8_e4m3fn, + hp_value_lb=1e-12, + ) + q2 = _quantize_affine_float8(x2, scale2, torch.float8_e4m3fn) + sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32)) + scale_sqnr = compute_error( + scale1.to(torch.float32).flatten(), + scale2.to(torch.float32).flatten(), + ) + self.assertGreater(sqnr, 40) + self.assertGreater(scale_sqnr, 50) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + @unittest.skipIf(is_fbcode(), "triton compilation error") + def test_fbgemm_fp8_int4_preshuffled_primitives(self): + """ + Compare numerics between: + (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle + (2) Our reference QAT version in `Int4WeightFakeQuantizer` + """ + from fbgemm_gpu.experimental.gen_ai.quantize import ( + int4_row_quantize, + pack_int4, + quantize_fp8_row, + quantize_int4_preshuffle, + ) + + from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, + _quantize_affine_no_dtype_cast, + ) + + group_size = 128 + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x2 = copy.deepcopy(x1) + x3 = copy.deepcopy(x1) + + # (1) Just call `quantize_int4_preshuffle` + (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8") + + # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling + (q2, _) = quantize_fp8_row(x2) + (q2, scale2) = int4_row_quantize(q2, group_size) + + # (3) Reference implementation for QAT without the dequantize + fp8_scale = _choose_scale_float8( + x3, + (1, x3.shape[-1]), + torch.float8_e4m3fn, + hp_value_lb=1e-12, + ) + x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn) + x3_fp8 = x3_fp8.to(torch.float32) + x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size) + max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False) + scale = torch.clamp(max_abs / 8, min=1e-6) + zero_point = torch.zeros_like(scale) + q3 = _quantize_affine_no_dtype_cast( + x3_fp8, + (1, group_size), + scale, + zero_point, + quant_min=-8, + quant_max=7, + ) + scale3 = scale + + def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + t = pack_int4(t.to(torch.int8)) + return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0] + + # First, sanity check that shuffle_and_pack(q2) == q1 + torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0) + + # Now check q2 vs q3 with and without shuffle + sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32)) + sqnr_q2_q3_preshuffle = compute_error( + shuffle_and_pack(q2, scale2).to(torch.float32), + shuffle_and_pack(q3, scale3).to(torch.float32), + ) + self.assertGreater(sqnr_q2_q3, 32) + self.assertGreater(sqnr_q2_q3_preshuffle, 32) + + # Now check shuffle_and_pack(q3) vs q1 + sqnr_q1_q3_preshuffle = compute_error( + q1.to(torch.float32), + shuffle_and_pack(q3, scale3).to(torch.float32), + ) + self.assertGreater(sqnr_q1_q3_preshuffle, 32) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + @unittest.skipIf(is_fbcode(), "triton compilation error") + def test_fbgemm_int4_weight_only_primitives(self): + """ + Compare numerics between: + (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp + (2) Our reference QAT version in `Int4WeightFakeQuantizer` + """ + from fbgemm_gpu.experimental.gen_ai.quantize import ( + int4_row_quantize_zp, + pack_int4, + quantize_int4_preshuffle, + ) + + group_size = 128 + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x2 = copy.deepcopy(x1) + x3 = copy.deepcopy(x1) + + # (1) Just call `quantize_int4_preshuffle` with dtype="bf16" + (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="bf16") + + # (2) Call `int4_row_quantize_zp`, which should be the same as (1) + # but without the packing and shuffling + (q2, scale2, _) = int4_row_quantize_zp(x2, group_size) + + # (3) Reference implementation for QAT without the dequantize + eps = 1e-6 + qmin, qmax = 0, 15 + fbgemm_symmetric_qmax = 8 + w_grouped = x3.to(torch.float32).view(x3.shape[0], -1, group_size) + max_val = torch.amax(w_grouped, dim=-1, keepdim=True) + min_val = torch.amin(w_grouped, dim=-1, keepdim=True) + scale3 = torch.clamp(max_val - min_val, min=eps) / qmax + q3 = (w_grouped.sub(min_val).div(scale3)).round().clamp_(qmin, qmax) + q3 = q3 - fbgemm_symmetric_qmax + q3 = q3.view(x3.shape) + + def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + t = pack_int4(t.to(torch.int8)) + return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.bfloat16))[0] + + # First, sanity check that shuffle_and_pack(q2) == q1 + torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0) + + # Now check q2 vs q3 with and without shuffle + torch.testing.assert_close(q2.to(torch.float32), q3, atol=0, rtol=0) + torch.testing.assert_close( + shuffle_and_pack(q2, scale2).to(torch.float32), + shuffle_and_pack(q3, scale3).to(torch.float32), + atol=0, + rtol=0, + ) + + # Now check shuffle_and_pack(q3) vs q1 + torch.testing.assert_close( + q1.to(torch.float32), + shuffle_and_pack(q3, scale3).to(torch.float32), + atol=0, + rtol=0, + ) + + @parametrize( + "base_config_cls", + [ + IntxWeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationIntxWeightConfig, + ], + ) + def test_range_learning_convert_pass_qparams( + self, base_config_cls: Type[AOBaseConfig] + ): + """ + Verify that range learning QAT can pass qparams from the prepared + model to the convert model. + """ + group_size = 32 + config = IntxFakeQuantizeConfig( + torch.int4, + group_size=group_size, + is_symmetric=True, + is_dynamic=False, + range_learning=True, + ) + m = M() + example_inputs = m.example_inputs() + quantize_(m, QATConfig(weight_config=config, step="prepare")) + initialize_fake_quantizers(m, example_inputs) + + # convert and verify scales are what we expect + scale1 = m.linear1.weight_fake_quantizer.scale + scale2 = m.linear2.weight_fake_quantizer.scale + sub_scale = m.sub.linear.weight_fake_quantizer.scale + if base_config_cls == Int8DynamicActivationInt4WeightConfig: + base_config = base_config_cls() + quantize_(m, QATConfig(base_config, step="convert")) + torch.testing.assert_close( + m.linear1.weight.original_weight_tensor.tensor_impl.scale, scale1 + ) + torch.testing.assert_close( + m.linear2.weight.original_weight_tensor.tensor_impl.scale, scale2 + ) + torch.testing.assert_close( + m.sub.linear.weight.original_weight_tensor.tensor_impl.scale, sub_scale + ) + else: + base_config = base_config_cls(torch.int4, PerGroup(group_size)) + quantize_(m, QATConfig(base_config, step="convert")) + torch.testing.assert_close(m.linear1.weight.scale, scale1) + torch.testing.assert_close(m.linear2.weight.scale, scale2) + torch.testing.assert_close(m.sub.linear.weight.scale, sub_scale) + + +instantiate_parametrized_tests(TestQAT) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 2bb20d5afd..b5ea7bf09a 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -10,6 +10,7 @@ import gc import tempfile import unittest +import warnings from pathlib import Path import torch @@ -29,9 +30,7 @@ AffineQuantizedTensor, Int4CPULayout, Int4XPULayout, - Int8DynamicActInt4WeightCPULayout, PlainLayout, - QDQLayout, TensorCoreTiledLayout, ) from torchao.quantization import ( @@ -39,27 +38,28 @@ PerGroup, ) from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Float8StaticActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8DynamicActivationIntxWeightConfig, Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, Quantizer, TwoStepQuantizer, + UIntXWeightOnlyConfig, _replace_with_custom_fn_if_matches_filter, - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, ) from torchao.quantization.quant_primitives import MappingType +from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import ( + IntxUnpackedToInt8Tensor, +) from torchao.quantization.subclass import ( Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, @@ -67,14 +67,9 @@ from torchao.quantization.utils import compute_error from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_7, - TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_89, is_sm_at_least_90, + torch_version_at_least, unwrap_tensor_subclass, ) @@ -129,7 +124,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: - quantize_(model, int8_dynamic_activation_int8_weight()) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) return model @@ -152,32 +147,6 @@ def forward(self, x): return x -def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): - """ - The deprecated implementation for int8 dynamic quant API, used as a reference for - numerics and performance - """ - from torchao.quantization.quant_api import ( - _get_subclass_inserter, - _in_features_greater_than_16, - _is_linear, - ) - from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight - - if filter_fn is None: - filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( - *args - ) - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs - ), - filter_fn, - ) - - def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass): def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): """ @@ -215,7 +184,7 @@ class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() - quantize_(m, int8_dynamic_activation_int8_weight()) + quantize_(m, Int8DynamicActivationInt8WeightConfig()) m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) @@ -253,12 +222,12 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) @unittest.skipIf(not torch.xpu.is_available(), "Need XPU available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "only works for torch 2.8+") + @unittest.skipIf(not torch_version_at_least("2.8.0"), "only works for torch 2.8+") def test_int4_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() def api(model): - quantize_(model, int4_weight_only(layout=Int4XPULayout())) + quantize_(model, Int4WeightOnlyConfig(layout=Int4XPULayout(), version=1)) unwrap_tensor_subclass(model) api(m) @@ -281,12 +250,11 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() def api(model): - quantize_(model, int8_weight_only()) + quantize_(model, Int8WeightOnlyConfig()) unwrap_tensor_subclass(model) api(m) @@ -306,11 +274,10 @@ def api(model): example_inputs = map(lambda x: x.cuda(), example_inputs) res = m2(*example_inputs) - torch.testing.assert_close(ref, res.cpu()) + # TODO: figure out why ROCm has a larger error + atol, rtol = (1e-2, 1e-2) if torch.version.hip else (None, None) + torch.testing.assert_close(ref, res.cpu(), atol=atol, rtol=rtol) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" - ) def test_8da4w_quantizer(self): from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -323,9 +290,6 @@ def test_8da4w_quantizer(self): assert isinstance(m.linear2, Int8DynActInt4WeightLinear) m(*example_inputs) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" - ) def test_8da4w_quantizer_linear_bias(self): from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -444,7 +408,6 @@ def test_eval_wrapper_llama3(self): ) # TODO: move to a separate test file - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @common_utils.parametrize( "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR] ) @@ -455,7 +418,7 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): example_inputs = m.example_inputs() quantize_( m, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=group_size, mapping_type=mapping_type ), ) @@ -484,8 +447,6 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") def test_quantized_tensor_subclass_int4(self): for device in self.GPU_DEVICES: @@ -497,10 +458,13 @@ def test_quantized_tensor_subclass_int4(self): group_size = 32 if device == "xpu": quantize_( - m, int4_weight_only(group_size=group_size, layout=Int4XPULayout()) + m, + Int4WeightOnlyConfig( + group_size=group_size, layout=Int4XPULayout(), version=1 + ), ) else: - quantize_(m, int4_weight_only(group_size=group_size)) + quantize_(m, Int4WeightOnlyConfig(group_size=group_size, version=1)) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -512,14 +476,13 @@ def test_quantized_tensor_subclass_int4(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_wo(self): m = ToyLinearModel().eval().to(torch.bfloat16) m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -532,57 +495,13 @@ def test_quantized_tensor_subclass_int8_wo(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below") - def test_quantized_tensor_subclass_int8_dyn_quant(self): - # use multiples of 1024 so that we don't need padding - m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") - m_copy = copy.deepcopy(m) - # setting batch_size to 20 to be compatible with the kernel - example_inputs = m.example_inputs( - batch_size=20, dtype=torch.bfloat16, device="cuda" - ) - quantize_(m, int8_dynamic_activation_int8_weight()) - - assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance( - m.linear1.weight.original_weight_tensor, AffineQuantizedTensor - ) - assert isinstance( - m.linear2.weight.original_weight_tensor, AffineQuantizedTensor - ) - - # reference - _ref_change_linear_weights_to_int8_dqtensors(m_copy) - - res = m(*example_inputs) - ref = m_copy(*example_inputs) - - self.assertTrue(torch.equal(res, ref)) - - # workaround for export path - from torchao.utils import unwrap_tensor_subclass - - m_unwrapped = unwrap_tensor_subclass(m) - - m = torch.export.export(m_unwrapped, example_inputs, strict=True).module() - exported_model_res = m(*example_inputs) - - self.assertTrue(torch.equal(exported_model_res, ref)) - - # make sure it compiles - torch._export.aot_compile(m_unwrapped, example_inputs) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load(self): m = ToyLinearModel().eval().to(torch.bfloat16) m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16) - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) with tempfile.NamedTemporaryFile() as f: torch.save(m.state_dict(), f) @@ -594,13 +513,12 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu") - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) example_inputs_cuda = (example_inputs[0].to("cuda"),) @@ -608,31 +526,12 @@ def test_int8wo_quantized_model_to_device(self): cuda_res = m(*example_inputs_cuda) self.assertEqual(cuda_res.cpu(), ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+") - def test_int4wo_quantized_model_to_device(self): - # TODO: change initial model to "cpu" - devices = ["cuda", "cuda:0"] - for device in devices: - m = ToyLinearModel().eval().to(torch.bfloat16).to(device) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) - - quantize_(m, int4_weight_only()) - ref = m(*example_inputs) - - example_inputs_cuda = (example_inputs[0].to(device),) - m.to(device=device) - cuda_res = m(*example_inputs_cuda) - self.assertEqual(cuda_res.cpu(), ref) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load_map_location(self): m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda") example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") - quantize_(m, int8_weight_only()) + quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) with tempfile.NamedTemporaryFile() as f: torch.save(m.state_dict(), f) @@ -648,7 +547,6 @@ def test_quantized_tensor_subclass_save_load_map_location(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_model_streaming(self): def reset_memory(): @@ -658,20 +556,19 @@ def reset_memory(): reset_memory() m = ToyLinearModel() - quantize_(m.to(device="cuda"), int8_weight_only()) + quantize_(m.to(device="cuda"), Int8WeightOnlyConfig()) memory_baseline = torch.cuda.max_memory_allocated() del m reset_memory() m = ToyLinearModel() - quantize_(m, int8_weight_only(), device="cuda") + quantize_(m, Int8WeightOnlyConfig(), device="cuda") memory_streaming = torch.cuda.max_memory_allocated() for param in m.parameters(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) @common_utils.parametrize("use_hqq", [True, False]) @@ -685,8 +582,8 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): with torch.no_grad(): quantize_( m, - int4_weight_only( - group_size=32, layout=Int4CPULayout(), use_hqq=use_hqq + Int4WeightOnlyConfig( + group_size=32, layout=Int4CPULayout(), use_hqq=use_hqq, version=1 ), ) # ensure the expected op is in the code @@ -697,123 +594,56 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] - @unittest.skipIf( - "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"), - reason="cpp kernels not built", - ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Test only enabled for 2.7+") - @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) - @common_utils.parametrize("x_dim", [2, 3]) - @common_utils.parametrize("bias", [True, False]) - @common_utils.parametrize("bs", [1, 160]) - @common_utils.parametrize("sym_quant_a", [True, False]) - def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a): - if sym_quant_a and not TORCH_VERSION_AT_LEAST_2_8: - # not supported until PT 2.8 - return - device = "cpu" - m = ToyLinearModel(bias=bias).eval().to(dtype).to(device) - m2 = copy.deepcopy(m) - example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) - if x_dim == 3: - example_inputs = (example_inputs[0].unsqueeze(0),) - - with torch.no_grad(): - # Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout - # is that the former packs two int4 weights into one int8, while the latter does not. - quantize_( - m, - Int8DynamicActivationInt4WeightConfig( - group_size=32, - layout=Int8DynamicActInt4WeightCPULayout(), - act_mapping_type=MappingType.SYMMETRIC - if sym_quant_a - else MappingType.ASYMMETRIC, - ), - ) - y, code = torch._inductor.utils.run_and_get_code( - torch.compile(m, fullgraph=True, dynamic=True), - *example_inputs, - ) - # ensure the expected op is in the code - assert "torch.ops.torchao.da8w4_linear_cpu.default" in code[0] - quantize_( - m2, - int8_dynamic_activation_int4_weight( - group_size=32, - layout=PlainLayout(), - act_mapping_type=MappingType.SYMMETRIC - if sym_quant_a - else MappingType.ASYMMETRIC, - ), - ) - torch._dynamo.reset() # may segfault without this - y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) - atol, rtol = 4e-7, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 1e-2, 3e-3 - elif dtype == torch.half: - atol, rtol = 6e-3, 2e-3 - assert torch.allclose(y, y2, atol=atol, rtol=rtol) - # Test get_plain by dequantize() - dqw1 = m.linear1.weight.original_weight_tensor.dequantize() - dqw2 = m.linear2.weight.original_weight_tensor.dequantize() - dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() - dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() - assert torch.allclose(dqw1, dqw1_ref) - assert torch.allclose(dqw2, dqw2_ref) - # TODO(#1690): move to new config names - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( "config", [ - int4_weight_only(), - float8_weight_only(), - float8_dynamic_activation_float8_weight(), - float8_static_activation_float8_weight(scale=torch.tensor([1.0])), - int4_dynamic_activation_int4_weight(), - int8_dynamic_activation_int8_weight(), - int8_dynamic_activation_int4_weight(), - int8_weight_only(), - fpx_weight_only(ebits=4, mbits=3), - gemlite_uintx_weight_only(), - uintx_weight_only(dtype=torch.uint4), + Int4WeightOnlyConfig(version=1), + Float8WeightOnlyConfig(), + Float8DynamicActivationFloat8WeightConfig(), + Float8StaticActivationFloat8WeightConfig(scale=torch.tensor([1.0])), + Int4DynamicActivationInt4WeightConfig(), + Int8DynamicActivationInt8WeightConfig(), + Int8DynamicActivationInt4WeightConfig(), + Int8WeightOnlyConfig(), + FPXWeightOnlyConfig(ebits=4, mbits=3), + GemliteUIntXWeightOnlyConfig(), + UIntXWeightOnlyConfig(dtype=torch.uint4), ], ) @skip_if_rocm("ROCm enablement in progress") def test_workflow_e2e_numerics(self, config): """ - Simple test of e2e int4_weight_only workflow, comparing numerics + Simple test of e2e Int4WeightOnlyConfig workflow, comparing numerics to a bfloat16 baseline. """ if ( isinstance( config, ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, + Float8DynamicActivationFloat8WeightConfig, + Float8StaticActivationFloat8WeightConfig, ), ) and not is_sm_at_least_89() ): return unittest.skip("requires CUDA capability 8.9 or greater") elif ( - isinstance(config, int4_dynamic_activation_int4_weight) + isinstance(config, Int4DynamicActivationInt4WeightConfig) and is_sm_at_least_90() ): return unittest.skip("only supported on CUDA capability 8.9, not greater") - elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite: + elif isinstance(config, GemliteUIntXWeightOnlyConfig) and not has_gemlite: return unittest.skip("gemlite not available") # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability - if isinstance(config, float8_static_activation_float8_weight): + if isinstance(config, Float8StaticActivationFloat8WeightConfig): config.scale = config.scale.to("cuda") dtype = torch.bfloat16 - if isinstance(config, gemlite_uintx_weight_only): + if isinstance(config, GemliteUIntXWeightOnlyConfig): dtype = torch.float16 # set up inputs @@ -835,7 +665,7 @@ def test_workflow_e2e_numerics(self, config): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_module_fqn_to_config_default(self): - config1 = Int4WeightOnlyConfig(group_size=32) + config1 = Int4WeightOnlyConfig(group_size=32, version=1) config2 = Int8WeightOnlyConfig() config = ModuleFqnToConfig({"_default": config1, "linear2": config2}) model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) @@ -849,7 +679,7 @@ def test_module_fqn_to_config_default(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_module_fqn_to_config_module_name(self): - config1 = Int4WeightOnlyConfig(group_size=32) + config1 = Int4WeightOnlyConfig(group_size=32, version=1) config2 = Int8WeightOnlyConfig() config = ModuleFqnToConfig({"linear1": config1, "linear2": config2}) model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) @@ -861,7 +691,6 @@ def test_module_fqn_to_config_module_name(self): assert isinstance(model.linear2.weight, AffineQuantizedTensor) assert isinstance(model.linear2.weight._layout, PlainLayout) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+") def test_module_fqn_to_config_embedding_linear(self): weight_dtype = torch.int8 granularity = PerGroup(8) @@ -870,10 +699,12 @@ def test_module_fqn_to_config_embedding_linear(self): weight_dtype=weight_dtype, granularity=granularity, mapping_type=mapping_type, - scale_dtype=None, ) # example model linear is Linear(16, 8) - linear_config = Int8DynamicActivationInt4WeightConfig(group_size=16) + linear_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=PerGroup(16), + ) config = ModuleFqnToConfig({"emb": embedding_config, "linear": linear_config}) indices = torch.randint(0, 10, (32,)) @@ -889,13 +720,12 @@ def test_module_fqn_to_config_embedding_linear(self): ) model(*example_inputs) - assert isinstance(model.emb.weight, AffineQuantizedTensor) - assert isinstance(model.emb.weight._layout, QDQLayout) - assert isinstance(model.linear.weight, LinearActivationQuantizedTensor) + assert isinstance(model.emb.weight, IntxUnpackedToInt8Tensor) + assert isinstance(model.linear.weight, IntxUnpackedToInt8Tensor) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_module_fqn_to_config_skip(self): - config1 = Int4WeightOnlyConfig(group_size=32) + config1 = Int4WeightOnlyConfig(group_size=32, version=1) config = ModuleFqnToConfig({"_default": config1, "linear2": None}) model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) @@ -907,7 +737,7 @@ def test_module_fqn_to_config_skip(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int4wo_cuda_serialization(self): - config = Int4WeightOnlyConfig(group_size=32) + config = Int4WeightOnlyConfig(group_size=32, version=1) model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) # quantize in cuda quantize_(model, config) @@ -924,6 +754,56 @@ def test_int4wo_cuda_serialization(self): # load state_dict in cuda model.load_state_dict(sd, assign=True) + def test_config_deprecation(self): + """ + Test that old config functions like `int4_weight_only` trigger deprecation warnings. + """ + from torchao.quantization import ( + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, + fpx_weight_only, + gemlite_uintx_weight_only, + int4_dynamic_activation_int4_weight, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, + uintx_weight_only, + ) + + # Reset deprecation warning state, otherwise we won't log warnings here + warnings.resetwarnings() + + # Map from deprecated API to the args needed to instantiate it + deprecated_apis_to_args = { + float8_dynamic_activation_float8_weight: (), + float8_static_activation_float8_weight: (torch.randn(3)), + float8_weight_only: (), + fpx_weight_only: (3, 2), + gemlite_uintx_weight_only: (), + int4_dynamic_activation_int4_weight: (), + int4_weight_only: (), + int8_dynamic_activation_int4_weight: (), + int8_dynamic_activation_int8_weight: (), + int8_weight_only: (), + uintx_weight_only: (torch.uint4,), + } + + with warnings.catch_warnings(record=True) as _warnings: + # Call each deprecated API twice + for cls, args in deprecated_apis_to_args.items(): + cls(*args) + cls(*args) + + # Each call should trigger the warning only once + self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) + for w in _warnings: + self.assertIn( + "is deprecated and will be removed in a future release", + str(w.message), + ) + common_utils.instantiate_parametrized_tests(TestQuantFlow) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index ac2a42b9cf..bed8421671 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -16,6 +16,7 @@ _choose_qparams_affine_tinygemm, _fake_quantize_affine, _fake_quantize_affine_cachemask, + _maybe_expand_scale_to_tensor_shape, choose_qparams_affine, dequantize_affine, quantize_affine, @@ -23,16 +24,12 @@ # TODO: remove test for utils? from torchao.quantization.utils import ( + _quantize_activation_per_token_absmax, get_group_qparams_symmetric, groupwise_affine_dequantize_tensor_from_qparams, groupwise_affine_quantize_tensor_from_qparams, - quantize_activation_per_token_absmax, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, check_cpu_version, check_xpu_version, is_fbcode, @@ -132,11 +129,10 @@ def _groupwise_affine_quantize_tensor_from_qparams( .reshape_as(w) ) - if TORCH_VERSION_AT_LEAST_2_5: - if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): - w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) - if check_xpu_version(w.device): - w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) + if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): + w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) + if check_xpu_version(w.device): + w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) return w_int4x8 @@ -175,9 +171,6 @@ def _groupwise_affine_dequantize_tensor_from_qparams( class TestQuantPrimitives(unittest.TestCase): SEED = 123 - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) def test_get_group_qparams_symmetric(self): """ Test that `get_group_qparams_symmetric` produces the exact same scales as @@ -264,34 +257,21 @@ def test_choose_qparams_group_sym_no_clipping_err(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (1, 10) - if TORCH_VERSION_AT_LEAST_2_6: - scale, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - eps=torch.finfo(torch.float32).eps, - scale_dtype=torch.float64, - zero_point_dtype=torch.int64, - ) - else: - scale, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - eps=torch.finfo(torch.float32).eps, - ) - + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float64, + zero_point_dtype=torch.int64, + ) scale_ref, zp_ref = ( torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( input, dtype @@ -347,12 +327,9 @@ def test_choose_qparams_tensor_sym(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_activation_per_token_abs_max(self): input = torch.randn(10, 10) - quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) + quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) mapping_type = MappingType.SYMMETRIC block_size = list(input.shape) @@ -380,33 +357,24 @@ def test_quantize_activation_per_token_abs_max(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(scale, scale_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_activation_per_token_abs_max_zero_input(self): input = torch.zeros(10, 10) # make sure it still works - quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) + quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_quantize_activation_per_token_abs_max_dtype(self): input = torch.zeros(10, 10, dtype=torch.bfloat16) - quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) + quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.bfloat16) input = torch.zeros(10, 10, dtype=torch.float32) - quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) + quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.float32) input = torch.zeros(10, 10, dtype=torch.float16) - quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) + quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.float32) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) @@ -449,9 +417,6 @@ def test_quantize_dequantize_group_sym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) @@ -493,9 +458,6 @@ def test_quantize_dequantize_channel_asym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) @@ -535,9 +497,6 @@ def test_quantize_dequantize_tensor_asym(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) @@ -578,9 +537,6 @@ def test_quantize_dequantize_channel_asym_4d(self): self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" - ) def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC @@ -726,32 +682,22 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]: if zero_point_domain == ZeroPointDomain.INT: zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32) - if TORCH_VERSION_AT_LEAST_2_5: - input_tmp = input - if (not (check_cpu_version(input.device))) and ( - not (check_xpu_version(input.device)) - ): - input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - if check_xpu_version(input.device): - input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( - input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain - ) - else: - if zero_point_domain == ZeroPointDomain.INT: - continue - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( - input, scales, zeros, n_bit, groupsize - ) + input_tmp = input + if (not (check_cpu_version(input.device))) and ( + not (check_xpu_version(input.device)) + ): + input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) + if check_xpu_version(input.device): + input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain + ) w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams( input, scales, zeros, n_bit, groupsize, zero_point_domain ) self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_affine(self): input = torch.randn(10, 10) @@ -785,9 +731,6 @@ def test_fake_quantize_affine(self): ) torch.testing.assert_close(dequantized, fake_quantized) - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" - ) def test_fake_quantize_affine_cachemask(self): input = torch.randn(10, 10) @@ -829,6 +772,32 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + def test_maybe_expand_scale_to_tensor_shape(self): + # rowwise quantization: if all dimensions match except for the last one, + # and the last dimension is 1, then just return the scale as is + scale = torch.randn([3, 2, 1]) + target_shape = torch.Size([3, 2, 8]) + new_scale = _maybe_expand_scale_to_tensor_shape(scale, target_shape) + self.assertIs(scale, new_scale) + # other broadcastable shapes + scale1 = torch.randn([3, 1, 1]) + scale2 = torch.randn([1, 2, 1]) + scale3 = torch.randn([1, 1, 8]) + scale4 = torch.randn([1, 1, 1]) + new_scale1 = _maybe_expand_scale_to_tensor_shape(scale1, target_shape) + new_scale2 = _maybe_expand_scale_to_tensor_shape(scale2, target_shape) + new_scale3 = _maybe_expand_scale_to_tensor_shape(scale3, target_shape) + new_scale4 = _maybe_expand_scale_to_tensor_shape(scale4, target_shape) + self.assertIs(scale1, new_scale1) + self.assertIs(scale2, new_scale2) + self.assertIs(scale3, new_scale3) + self.assertIs(scale4, new_scale4) + # blockwise quantization: scales are repeated to fit target_shape + scale5 = torch.randn([3, 2, 2]) + new_scale5 = _maybe_expand_scale_to_tensor_shape(scale5, target_shape) + self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8])) + self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2])) + if __name__ == "__main__": unittest.main() diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 804a585dd8..424306f897 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -15,7 +15,7 @@ swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode +from torchao.utils import is_fbcode class ToyModel(nn.Module): @@ -32,7 +32,6 @@ def forward(self, x): class TestRuntimeSemiStructuredSparsity(TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") @@ -81,7 +80,6 @@ def test_runtime_weight_sparsification(self): for name, mod in model_c.named_modules(): assert not isinstance(mod, SemiSparseLinear) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @unittest.skip("Temporarily skipping to unpin nightlies") diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 783de6c6ae..e602210ee5 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -11,7 +11,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests from torchao.dtypes import MarlinSparseLayout -from torchao.quantization.quant_api import int4_weight_only, quantize_ +from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, @@ -20,7 +20,6 @@ from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.testing.utils import skip_if_rocm -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class SparseMarlin24(TestCase): @@ -48,17 +47,18 @@ def test_quant_sparse_marlin_layout_eager(self): model_copy = copy.deepcopy(self.model) # Quantized - quantize_(model_copy.bfloat16(), int4_weight_only()) + quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1)) dense_result = model_copy(self.input.bfloat16()).half() # Sparse + quantized - quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) + quantize_( + self.model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1) + ) sparse_result = self.model(self.input) assert torch.allclose(dense_result, sparse_result, atol=3e-1), ( "Results are not close" ) - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_compile(self): @@ -66,12 +66,14 @@ def test_quant_sparse_marlin_layout_compile(self): model_copy = copy.deepcopy(self.model) # Quantized - quantize_(model_copy.bfloat16(), int4_weight_only()) + quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1)) model_copy.foward = torch.compile(model_copy.forward, fullgraph=True) dense_result = model_copy(self.input.bfloat16()).half() # Sparse + quantized - quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) + quantize_( + self.model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1) + ) self.model.forward = torch.compile(self.model.forward, fullgraph=True) sparse_result = self.model(self.input) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 5e3086c411..003a50c4d1 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -12,18 +12,17 @@ from torch.testing._internal import common_utils from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout +from torchao.quantization import ( + Float8DynamicActivationFloat8SemiSparseWeightConfig, + Float8DynamicActivationFloat8WeightConfig, +) from torchao.quantization.quant_api import ( - int4_weight_only, - int8_dynamic_activation_int8_weight, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, -) +from torchao.utils import is_sm_at_least_90 logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO @@ -31,7 +30,6 @@ class TestSemiStructuredSparse(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skip("Temporarily skipping to unpin nightlies") def test_sparse(self): @@ -59,7 +57,6 @@ def test_sparse(self): class TestQuantSemiSparse(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [False]) @unittest.skip("Temporarily skip to unbreak CI") @@ -84,12 +81,12 @@ def test_quant_semi_sparse(self, compile): ) apply_fake_sparsity(model) model_copy = copy.deepcopy(model) - quantize_(model_copy, int8_dynamic_activation_int8_weight()) + quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) dense_result = model_copy(input) quantize_( model, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), ) if compile: model = torch.compile(model) @@ -97,7 +94,6 @@ def test_quant_semi_sparse(self, compile): torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse_marlin(self, compile): @@ -119,23 +115,82 @@ def test_sparse_marlin(self, compile): model_copy = copy.deepcopy(model) # Quantized - quantize_(model_copy.bfloat16(), int4_weight_only()) + quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1)) dense_result = model_copy(input.bfloat16()).half() # Sparse + quantized - quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) + quantize_(model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)) if compile: model = torch.compile(model) sparse_result = model(input) torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1) + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @common_utils.parametrize("compile", [True, False]) + def test_fp8_cutlass_sparse(self, compile): + input = torch.rand((256, 256)).half().cuda() + model = ( + nn.Sequential( + nn.Linear(256, 1024), + nn.Linear(1024, 256), + ) + .half() + .cuda() + .eval() + ) + + apply_fake_sparsity(model) + model_copy = copy.deepcopy(model) + + # Quantized + quantize_(model_copy.bfloat16(), Float8DynamicActivationFloat8WeightConfig()) + dense_result = model_copy(input.bfloat16()).half() + + # Sparse + quantized + quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig()) + if compile: + model = torch.compile(model) + sparse_result = model(input) + + torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1) + + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_fp8_cutlass_sparse_lowering_op_clone(self): + with torch.inference_mode(): + model = nn.Linear(256, 1024).half().cuda().eval() + apply_fake_sparsity(model) + quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig()) + + original = model.weight.original_weight_tensor.tensor_impl.get_plain() + cloned = model.weight.original_weight_tensor.tensor_impl.clone().get_plain() + + for o, c in zip(original, cloned): + torch.testing.assert_close(o, c, atol=0.0, rtol=0.0) + + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_fp8_cutlass_sparse_lowering_op_to(self): + # Need to run with inference mode to avoid dispatching to `aten.to_copy` + with torch.inference_mode(): + model = nn.Linear(256, 1024).half().cuda().eval() + apply_fake_sparsity(model) + model_copy = copy.deepcopy(model) + expected = model_copy.weight.to(dtype=torch.float) + + quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig()) + + original = torch.ops.aten.to.dtype_layout( + model.weight.original_weight_tensor.tensor_impl, + dtype=torch.float, + layout=torch.strided, + ) + torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1) + class TestBlockSparseWeight(common_utils.TestCase): - @unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_4, - "pytorch 2.4+ feature due to need for custom op support", - ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize("input_shape", [1, 1024]) @@ -170,7 +225,6 @@ def test_sparse(self, compile, input_shape): class TestQuantBlockSparseWeight(common_utils.TestCase): - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "pytorch 2.6+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) def test_sparse(self, compile): @@ -196,14 +250,16 @@ def test_sparse(self, compile): model_copy = copy.deepcopy(model) - quantize_(model_copy, int8_dynamic_activation_int8_weight()) + quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) reference = model_copy(input) from torchao.dtypes import BlockSparseLayout quantize_( model, - int8_dynamic_activation_int8_weight(layout=BlockSparseLayout(blocksize=64)), + Int8DynamicActivationInt8WeightConfig( + layout=BlockSparseLayout(blocksize=64) + ), ) if compile: model = torch.compile(model) diff --git a/test/test_ao_models.py b/test/test_ao_models.py index 79e4cc3ef5..a658216a7e 100644 --- a/test/test_ao_models.py +++ b/test/test_ao_models.py @@ -3,32 +3,53 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import pytest +import unittest + import torch +from torch.testing._internal import common_utils from torchao._models.llama.model import Transformer -_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) - def init_model(name="stories15M", device="cpu", precision=torch.bfloat16): + """Initialize and return a Transformer model with specified configuration.""" model = Transformer.from_name(name) model.to(device=device, dtype=precision) return model.eval() -@pytest.mark.parametrize("device", _AVAILABLE_DEVICES) -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("is_training", [True, False]) -def test_ao_llama_model_inference_mode(device, batch_size, is_training): - random_model = init_model(device=device) - seq_len = 16 - input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device) - input_pos = None if is_training else torch.arange(seq_len).to(device) - with torch.device(device): - random_model.setup_caches( - max_batch_size=batch_size, max_seq_length=seq_len, training=is_training - ) - for i in range(3): - out = random_model(input_ids, input_pos) - assert out is not None, "model failed to run" +class TorchAOBasicTestCase(unittest.TestCase): + """Test suite for basic Transformer inference functionality.""" + + @common_utils.parametrize( + "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] + ) + @common_utils.parametrize("batch_size", [1, 4]) + @common_utils.parametrize("is_training", [True, False]) + def test_ao_inference_mode(self, device, batch_size, is_training): + # Initialize model with specified device + random_model = init_model(device=device) + + # Set up test input parameters + seq_len = 16 + input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device) + + # input_pos is None for training mode, tensor for inference mode + input_pos = None if is_training else torch.arange(seq_len).to(device) + + # Setup model caches within the device context + with torch.device(device): + random_model.setup_caches( + max_batch_size=batch_size, max_seq_length=seq_len, training=is_training + ) + + # Run multiple inference iterations to ensure consistency + for i in range(3): + out = random_model(input_ids, input_pos) + self.assertIsNotNone(out, f"Model failed to run on iteration {i}") + + +common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py index 692a0d9e6c..b0edfc7fc5 100644 --- a/test/test_low_bit_optim.py +++ b/test/test_low_bit_optim.py @@ -16,6 +16,7 @@ OffloadPolicy, fully_shard, ) +from torch.testing._internal import common_utils from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( @@ -25,6 +26,9 @@ run_tests, ) +if common_utils.SEED is None: + common_utils.SEED = 1234 + from packaging.version import Version from torchao import optim from torchao.optim.quant_utils import ( @@ -37,9 +41,8 @@ from torchao.optim.subclass_fp8 import OptimStateFp8 from torchao.testing.utils import skip_if_rocm from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_7, get_available_devices, + torch_version_at_least, ) try: @@ -187,6 +190,27 @@ def test_optim_default_dtype_bf16(self, optim_name, device): finally: torch.set_default_dtype(old_dtype) + @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"]) + @parametrize("device", _DEVICES) + def test_param_groups(self, optim_name, device): + if optim_name.endswith("Fp8") and device == "cuda": + if torch.cuda.get_device_capability() < (8, 9): + pytest.skip("FP8 CUDA requires compute capability >= 8.9") + + model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32)) + model.to(device=device) + param_groups = [ + dict(params=list(model[0].parameters()), lr=1e-4), + dict(params=list(model[2].parameters()), lr=1e-5), + ] + optimizer = getattr(optim, optim_name)(param_groups) + + x = torch.randn(4, 32, device=device) + loss = model(x).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + # aten.slice is required for dcp.load() when world size changes i.e. re-sharding # however, it's cumbersome to test it directly, since we would need to run distributed # test 2 times with different world size, and persist checkpoint across the 2 runs. @@ -197,8 +221,6 @@ def test_optim_default_dtype_bf16(self, optim_name, device): @parametrize("device", _DEVICES) def test_subclass_slice(self, subclass, shape, device): if subclass == OptimStateFp8: - if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5: - pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5") if device == "cuda" and torch.cuda.get_device_capability() < (8, 9): pytest.skip("FP8 CUDA requires compute capability >= 8.9") @@ -220,7 +242,7 @@ def test_subclass_slice(self, subclass, shape, device): ) @skip_if_rocm("ROCm enablement in progress") @pytest.mark.skipif( - TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI" + torch_version_at_least("2.7.0"), reason="Failing in CI" ) # TODO: fix this @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): @@ -444,9 +466,6 @@ class TestFSDP2(FSDPTest): def world_size(self) -> int: return _FSDP_WORLD_SIZE - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." - ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) @skip_if_rocm("ROCm enablement in progress") def test_fsdp2(self): @@ -562,9 +581,6 @@ def _test_fsdp2(self, args): v2 = v2.dequantize() self.assertEqual(v1, v2) - @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." - ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) @skip_if_rocm("ROCm enablement in progress") def test_uneven_shard(self): diff --git a/test/test_ops.py b/test/test_ops.py index faec689a69..65015e68ba 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -28,9 +28,8 @@ ) from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_7, compute_max_diff, + torch_version_at_least, ) IS_CUDA = torch.cuda.is_available() and torch.version.cuda @@ -155,50 +154,101 @@ def _scaled_dot_product_int8_op_ref( out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255) return out.to(torch.uint8) + def _scaled_dot_product_fp8_op_ref( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0, + is_causal=False, + q_scale=1.0, + k_scale=1.0, + v_scale=1.0, + a_scale=1.0, + o_scale=1.0, + ): + q = q.to(torch.float) * q_scale + k = k.to(torch.float) * k_scale + v = v.to(torch.float) * v_scale + scale_factor = 1 / math.sqrt(q.size(-1)) + attn = q @ k.transpose(-2, -1) + + attn = attn * scale_factor + if attn_mask is not None: + attn = attn + attn_mask.to(torch.float) + attn_max = attn.max(dim=-1, keepdim=True).values + attn = attn - attn_max + attn = torch.exp(attn) + attn_sum = torch.sum(attn, dim=-1, keepdim=True) + attn = attn / attn_sum + attn = torch.clamp(attn / a_scale, min=-448, max=448) + attn = attn.to(torch.float8_e4m3fn).to(torch.float) + attn = attn * a_scale + out = attn @ v + out = torch.clamp(out / o_scale, min=-448, max=448) + return out.to(torch.float8_e4m3fn) + @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later" + not torch_version_at_least("2.7.0"), + reason="quantized sdpa requires torch 2.7 or later", ) @pytest.mark.skipif(not IS_LINUX, reason="only support on linux") @pytest.mark.skipif( "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"), reason="cpp kernels not built", ) + @parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn]) @parametrize("batch_size", [56, 120]) @parametrize("n_head", [2, 16]) @parametrize("q_seq_len", [18, 89]) @parametrize("kv_seq_len", [100, 253]) @parametrize("head_dim", [32, 64]) @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16]) - def test_scaled_dot_product_int8_op( - self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype + def test_quantized_scaled_dot_product_op( + self, + input_dtype, + batch_size, + n_head, + q_seq_len, + kv_seq_len, + head_dim, + mask_dtype, ): torch.manual_seed(1234) device = "cpu" - q_scale = float(1.7907238006591797) - q_zp = int(127) - k_scale = float(1.8039721250534058) - k_zp = int(125) - v_scale = float(1.839004635810852) - v_zp = int(127) - a_scale = float(0.003919653594493866) - a_zp = int(120) - o_scale = float(1.8191684484481812) - o_zp = int(128) + if input_dtype == torch.uint8: + q_scale = float(1.7907238006591797) + k_scale = float(1.8039721250534058) + v_scale = float(1.839004635810852) + a_scale = float(0.003919653594493866) + o_scale = float(1.8191684484481812) + q_zp = int(127) + k_zp = int(125) + v_zp = int(127) + a_zp = int(120) + o_zp = int(128) + atol, rtol = 1.0, 5e-6 + else: + q_scale = float(5.96875) + k_scale = float(5.78125) + v_scale = float(0.98046875) + a_scale = float(4.84375) + o_scale = float(3.171875) + atol, rtol = 0.125, 5e-6 q_shape = [batch_size, q_seq_len, n_head, head_dim] kv_shape = [batch_size, kv_seq_len, n_head, head_dim] mask_shape = [batch_size, 1, 1, kv_seq_len] - q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100 - k = ( - torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) - * 100 - ) - v = ( - torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) - * 100 - ) - q = q.to(torch.uint8) - k = k.to(torch.uint8) - v = v.to(torch.uint8) + q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) + k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2) + if input_dtype == torch.uint8: + q *= 100 + k *= 100 + v *= 100 + q = q.to(input_dtype) + k = k.to(input_dtype) + v = v.to(input_dtype) attn_mask = ( torch.randn(mask_shape, dtype=mask_dtype, device=device) if mask_dtype is not None @@ -211,44 +261,71 @@ def test_scaled_dot_product_int8_op( attn_mask.clone() if mask_dtype is not None else None, ) - math_ref = self._scaled_dot_product_int8_op_ref( - q2, - k2, - v2, - attn_mask=attn_mask, - dropout_p=0.0, - is_causal=False, - q_scale=q_scale, - q_zp=q_zp, - k_scale=k_scale, - k_zp=k_zp, - v_scale=v_scale, - v_zp=v_zp, - a_scale=a_scale, - a_zp=a_zp, - o_scale=o_scale, - o_zp=o_zp, - ) - actual = torch.ops.torchao.qscaled_dot_product( - q, - k, - v, - attn_mask=attn_mask_2, - dropout_p=0.0, - is_causal=False, - q_scale=q_scale, - q_zp=q_zp, - k_scale=k_scale, - k_zp=k_zp, - v_scale=v_scale, - v_zp=v_zp, - a_scale=a_scale, - a_zp=a_zp, - o_scale=o_scale, - o_zp=o_zp, - ) - - self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6) + if input_dtype == torch.uint8: + math_ref = self._scaled_dot_product_int8_op_ref( + q2, + k2, + v2, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + q_zp=q_zp, + k_scale=k_scale, + k_zp=k_zp, + v_scale=v_scale, + v_zp=v_zp, + a_scale=a_scale, + a_zp=a_zp, + o_scale=o_scale, + o_zp=o_zp, + ) + actual = torch.ops.torchao.qscaled_dot_product( + q, + k, + v, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + q_zp=q_zp, + k_scale=k_scale, + k_zp=k_zp, + v_scale=v_scale, + v_zp=v_zp, + a_scale=a_scale, + a_zp=a_zp, + o_scale=o_scale, + o_zp=o_zp, + ) + else: + math_ref = self._scaled_dot_product_fp8_op_ref( + q2, + k2, + v2, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + a_scale=a_scale, + o_scale=o_scale, + ) + actual = torch.ops.torchao.qscaled_dot_product( + q, + k, + v, + attn_mask=attn_mask_2, + dropout_p=0.0, + is_causal=False, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + a_scale=a_scale, + o_scale=o_scale, + ) + self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol) instantiate_parametrized_tests(TestOps) @@ -281,25 +358,21 @@ def make_test_id(param): @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): N, K = shape assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) - if TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) assert torch.equal(t, unpacked) # TODO: Fix "test_aot_dispatch_dynamic" test failure @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): test_utils = [ @@ -308,13 +381,10 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - if TORCH_VERSION_AT_LEAST_2_5: - t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) + t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8) packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) opcheck( @@ -345,7 +415,6 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ -413,7 +482,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ -438,8 +506,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( # Unpack and dequantize unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) - if TORCH_VERSION_AT_LEAST_2_5: - unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) + unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8) dq_ao = groupwise_affine_dequantize_tensor_from_qparams( unpacked, scales, zeros, n_bit=4, groupsize=group_size @@ -479,7 +546,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") -# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str ) @@ -488,8 +554,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size device = "cuda" q = torch.randint(0, 16, shape, dtype=torch.int, device=device) - if TORCH_VERSION_AT_LEAST_2_5: - q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) + q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) q_groups = k // group_size scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) @@ -501,9 +566,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size "test_autograd_registration", "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") opcheck( torch.ops.torchao.dequantize_tensor_core_tiled_layout, (packed_w, scales_and_zeros, group_size, inner_k_tiles), @@ -766,9 +829,7 @@ def test_swizzle_mm(): "test_faketensor", ] - # TODO: Figure out why test fails unless torch >= 2.5 - if TORCH_VERSION_AT_LEAST_2_5: - test_utils.append("test_aot_dispatch_dynamic") + test_utils.append("test_aot_dispatch_dynamic") mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda") mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda") @@ -780,5 +841,69 @@ def test_swizzle_mm(): ) +EMBEDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10] +EMBEDINGBAG_BAG_SIZES = [1, 2, 128, 1024] +EMBEDINGBAG_VECTOR_SIZES = [1, 128, 512] +EMBEDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32] + +EMBEDINGBAG_TEST_PARAMS = list( + itertools.product( + EMBEDINGBAG_MULTIHOT_SIZES, + EMBEDINGBAG_BAG_SIZES, + EMBEDINGBAG_VECTOR_SIZES, + EMBEDINGBAG_INDEX_DTYPES, + ) +) + + +@pytest.mark.skipif( + "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"), + reason="cpp kernels not built", +) +@pytest.mark.parametrize( + "multi_hot, batch_size, vector_size, index_type", + EMBEDINGBAG_TEST_PARAMS, + ids=str, +) +def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type): + qtype = torch.float8_e4m3fn + dtype = torch.float32 + weight_scale = torch.tensor([2.0]) + include_last_offset = True + mode = "sum" + + if mode == "sum": + mode_enum = 0 + elif mode == "mean": + mode_enum = 1 + elif mode == "max": + mode_enum = 2 + indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type) + offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(index_type) + + m = torch.nn.EmbeddingBag( + 1000, + vector_size, + mode=mode, + dtype=dtype, + include_last_offset=include_last_offset, + ) + fp8_weight = m.weight.data.to(qtype) + m.weight.data = fp8_weight.to(m.weight.dtype) + + with torch.no_grad(): + refe_out = m.forward(indices, offsets) * weight_scale + test_out = torch.ops.torchao._scaled_embedding_bag( + fp8_weight, + indices, + offsets, + weight_scale, + 1.0, + mode_enum, + include_last_offset, + ).to(dtype) + torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__": pytest.main(sys.argv) diff --git a/test/test_utils.py b/test/test_utils.py index d41168b5a7..f06835c932 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,24 +4,26 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import unittest +import warnings from unittest.mock import patch import torch +from torchao.testing.utils import skip_if_no_cuda from torchao.utils import TorchAOBaseTensor, torch_version_at_least -class TestTorchVersionAtLeast(unittest.TestCase): +class TestTorchVersion(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ - ("2.5.0a0+git9f17037", "2.5.0", True), - ("2.5.0a0+git9f17037", "2.4.0", True), - ("2.5.0.dev20240708+cu121", "2.5.0", True), - ("2.5.0.dev20240708+cu121", "2.4.0", True), - ("2.5.0", "2.4.0", True), - ("2.5.0", "2.5.0", True), - ("2.4.0", "2.4.0", True), - ("2.4.0", "2.5.0", False), + ("2.5.0a0+git9f17037", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0a0+git9f17037", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0.dev20240708+cu121", "2.5.0", False), # [2, 5, -1] < [2, 5, 0] + ("2.5.0.dev20240708+cu121", "2.4.0", True), # [2, 5, -1] > [2, 4, 0] + ("2.5.0", "2.4.0", True), # [2, 5, 0] > [2, 4, 0] + ("2.5.0", "2.5.0", True), # [2, 5, 0] >= [2, 5, 0] + ("2.4.0", "2.4.0", True), # [2, 4, 0] >= [2, 4, 0] + ("2.4.0", "2.5.0", False), # [2, 4, 0] < [2, 5, 0] ] for torch_version, compare_version, expected_result in test_cases: @@ -34,6 +36,55 @@ def test_torch_version_at_least(self): f"Failed for torch.__version__={torch_version}, comparing with {compare_version}", ) + def test_torch_version_deprecation(self): + """ + Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER* + trigger deprecation warnings on use, not on import. + """ + # Reset deprecation warning state, otherwise we won't log warnings here + warnings.resetwarnings() + + # Importing and referencing should not trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + from torchao.utils import ( + TORCH_VERSION_AFTER_2_2, + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + TORCH_VERSION_AFTER_2_5, + TORCH_VERSION_AT_LEAST_2_2, + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_7, + TORCH_VERSION_AT_LEAST_2_8, + ) + + deprecated_api_to_name = [ + (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"), + (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"), + (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), + (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), + (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), + (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"), + (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"), + (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"), + (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"), + (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"), + (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"), + ] + self.assertEqual(len(_warnings), 0) + + # Accessing the boolean value should trigger deprecation warning + with warnings.catch_warnings(record=True) as _warnings: + for api, name in deprecated_api_to_name: + num_warnings_before = len(_warnings) + if api: + pass + regex = f"{name} is deprecated and will be removed" + self.assertEqual(len(_warnings), num_warnings_before + 1) + self.assertIn(regex, str(_warnings[-1].message)) + class TestTorchAOBaseTensor(unittest.TestCase): def test_print_arg_types(self): @@ -46,9 +97,249 @@ def __init__(self, data): self.data = data l = torch.nn.Linear(10, 10) + # since we did not define `tensor_data_names` and `tensor_attribute_names` for MyTensor + # the following call will error out because `detach` is defined in `TorchAOBaseTensor` + # but would rely on `tensor_data_names` and `tensor_attribute_names` being defined for it to work + # user could either specify `tensor_data_names` and `tensor_attribute_names` or manually implement + # detach op with self.assertRaisesRegex(NotImplementedError, "arg_types"): l.weight = torch.nn.Parameter(MyTensor(l.weight)) + def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy): + # get `all_tensor_data_names` and `all_tensor_attribute_names` + all_tensor_data_names = lp_tensor.tensor_data_names.copy() + if hasattr(lp_tensor, "optional_tensor_data_names"): + for tensor_data_name in lp_tensor.optional_tensor_data_names: + if getattr(lp_tensor, tensor_data_name) is not None: + all_tensor_data_names.append(tensor_data_name) + all_tensor_attribute_names = lp_tensor.tensor_attribute_names.copy() + if hasattr(lp_tensor, "optional_tensor_attribute_names"): + for tensor_attribute_name in lp_tensor.optional_tensor_attribute_names: + if getattr(lp_tensor, tensor_attribute_name) is not None: + all_tensor_attribute_names.append(tensor_attribute_name) + + # test __tensor_flatten__ and __tensor_unflatten__ + tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__() + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_names + } + outer_size = lp_tensor.size() + outer_stride = lp_tensor.stride() + reconstructed = type(lp_tensor).__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) + for tensor_data_name in all_tensor_data_names: + self.assertTrue( + torch.equal( + getattr(lp_tensor, tensor_data_name), + getattr(reconstructed, tensor_data_name), + ) + ) + for tensor_attribute_name in all_tensor_attribute_names: + self.assertEqual( + getattr(lp_tensor, tensor_attribute_name), + getattr(reconstructed, tensor_attribute_name), + ) + + self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata)) + self.assertEqual(lp_tensor.attr, reconstructed.attr) + + # `to` / `_to_copy` + original_device = lp_tensor.device + lp_tensor = lp_tensor.to("cuda") + self.assertEqual(lp_tensor.device.type, "cuda") + lp_tensor = lp_tensor.to(original_device) + self.assertEqual(lp_tensor.device, original_device) + + # __repr__ + _ = str(lp_tensor) + + # op test: detach + lp_tensor = lp_tensor.detach() + # op test: alias + lp_tensor = torch.ops.aten.alias(lp_tensor) + + # op test: clone + lp_tensor_clone = lp_tensor.clone() + + for tensor_data_name in all_tensor_data_names: + self.assertTrue( + torch.equal( + getattr(lp_tensor_clone, tensor_data_name), + getattr(lp_tensor, tensor_data_name), + ) + ) + for tensor_attribute_name in all_tensor_attribute_names: + self.assertEqual( + getattr(lp_tensor_clone, tensor_attribute_name), + getattr(lp_tensor, tensor_attribute_name), + ) + + # op test: transpose + # non optional and valid optional tensors + + # for each of the tensor data, we try to + # make it non-contiguous and then use + # lp_tensor.contiguous() call to make sure + # contiguous() works + for tensor_data_name in all_tensor_data_names: + tensor = getattr(lp_tensor, tensor_data_name) + # making qdata not contiguous + tensor = tensor.transpose(0, 1).contiguous() + tensor = tensor.transpose(0, 1) + setattr(lp_tensor, tensor_data_name, tensor) + self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous()) + + lp_tensor_t = lp_tensor.contiguous() + + # making sure contiguous call works + for tensor_data_name in all_tensor_data_names: + self.assertTrue(getattr(lp_tensor_t, tensor_data_name).is_contiguous()) + + # making sure transpose does not change attributes + for tensor_attribute_name in all_tensor_attribute_names: + self.assertEqual( + getattr(lp_tensor_t, tensor_attribute_name), + getattr(lp_tensor, tensor_attribute_name), + ) + + # op test: copy_ + # making sure that initially tensor values are not the same so we can test copy_ + self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0]) + # copy_ requires the attributes to be the same + for tensor_attribute_name in all_tensor_attribute_names: + self.assertEqual( + getattr(lp_tensor_for_copy, tensor_attribute_name), + getattr(lp_tensor, tensor_attribute_name), + ) + + lp_tensor.copy_(lp_tensor_for_copy) + # after copy_, the tensor values should match + for tensor_data_name in all_tensor_data_names: + self.assertTrue( + torch.equal( + getattr(lp_tensor, tensor_data_name), + getattr(lp_tensor_for_copy, tensor_data_name), + ) + ) + # after copy_, the tensor attributes still matches + # copy_ requires the attributes to be the same + for tensor_attribute_name in all_tensor_attribute_names: + self.assertEqual( + getattr(lp_tensor_for_copy, tensor_attribute_name), + getattr(lp_tensor, tensor_attribute_name), + ) + + @skip_if_no_cuda() + def test_default_impls(self): + """Making sure some common functions has default implementations, such as + __tensor_unflatten__, __tensor_flatten__, _apply_fn_to_data, __repr__, to + """ + + class MyTensor(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr", "device"] + + def __new__(cls, qdata, attr, device): + shape = qdata.shape + if device is None: + device = qdata.device + kwargs = {"device": device} + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, qdata, attr, device): + self.qdata = qdata + self.attr = attr + + l = torch.nn.Linear(2, 3) + l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None)) + lp_tensor = l.weight + + another_tensor = torch.nn.Linear(2, 3).weight + # attribute has to be the same + lp_tensor_for_copy = MyTensor(another_tensor, "attr", None) + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + + @skip_if_no_cuda() + def test_default_impls_with_optional_data(self): + class MyTensorWithOptionalData(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr", "device"] + optional_tensor_data_names = ["zero_point"] + + def __new__(cls, qdata, attr, device, zero_point=None): + shape = qdata.shape + if device is None: + device = qdata.device + kwargs = {"device": device} + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, qdata, attr, device, zero_point=None): + self.qdata = qdata + self.attr = attr + self.zero_point = zero_point + + # test both the optional Tensor is None + # and not None + l = torch.nn.Linear(2, 3) + lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None) + l = torch.nn.Linear(2, 3) + lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None) + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + + l = torch.nn.Linear(2, 3) + lp_tensor = MyTensorWithOptionalData( + l.weight, "attr", None, torch.zeros_like(l.weight) + ) + l = torch.nn.Linear(2, 3) + lp_tensor_for_copy = MyTensorWithOptionalData( + l.weight, "attr", None, torch.zeros_like(l.weight) + ) + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + + @skip_if_no_cuda() + def test_default_impls_with_optional_attr(self): + class MyTensorWithOptionalData(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr", "device"] + optional_tensor_data_names = ["zero_point"] + optional_tensor_attribute_names = ["optional_attr"] + + def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None): + shape = qdata.shape + if device is None: + device = qdata.device + kwargs = {"device": device} + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, qdata, attr, device, zero_point=None, optional_attr=None + ): + self.qdata = qdata + self.attr = attr + self.zero_point = zero_point + self.optional_attr = optional_attr + + # test both the optional Tensor is None + # and not None + l = torch.nn.Linear(2, 3) + lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None) + l = torch.nn.Linear(2, 3) + lp_tensor_for_copy = MyTensorWithOptionalData( + l.weight, "attr", None, zero_point=None + ) + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + + l = torch.nn.Linear(2, 3) + lp_tensor = MyTensorWithOptionalData( + l.weight, "attr", None, zero_point=None, optional_attr="value" + ) + l = torch.nn.Linear(2, 3) + lp_tensor_for_copy = MyTensorWithOptionalData( + l.weight, "attr", None, zero_point=None, optional_attr="value" + ) + self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy) + if __name__ == "__main__": unittest.main() diff --git a/third_party/cutlass b/third_party/cutlass index ad7b2f5e84..e51efbfe18 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e +Subproject commit e51efbfe18fe4f4cbb66ab814c55bf4aa0185491 diff --git a/torchao/__init__.py b/torchao/__init__.py index e6e291309f..3a25a72114 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -20,23 +20,41 @@ except PackageNotFoundError: __version__ = "unknown" # In case this logic breaks don't break the build -try: - from pathlib import Path - - so_files = list(Path(__file__).parent.glob("_C*.so")) - if len(so_files) > 0: - for file in so_files: - torch.ops.load_library(str(file)) - from . import ops - - # The following library contains CPU kernels from torchao/experimental - # They are built automatically by ao/setup.py if on an ARM machine. - # They can also be built outside of the torchao install process by - # running the script `torchao/experimental/build_torchao_ops.sh ` - # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md - from torchao.experimental.op_lib import * # noqa: F403 -except Exception as e: - logging.debug(f"Skipping import of cpp extensions: {e}") +logger = logging.getLogger(__name__) + +skip_loading_so_files = False +# if torchao version has "+git", assume it's locally built and we don't know +# anything about the PyTorch version used to build it +# otherwise, assume it's prebuilt by torchao's build scripts and we can make +# assumptions about the PyTorch version used to build it. +if (not "+git" in __version__) and not ("unknown" in __version__): + # torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so + # files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+. + # The following code skips importing the .so files if PyTorch 2.9+ is + # detected, to avoid crashing the Python process with "Aborted (core + # dumped)". + # TODO(#2901, and before next torchao release): make this generic for + # future torchao and torch versions + if __version__.startswith("0.13.0") and str(torch.__version__) >= "2.9": + logger.warning( + f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}" + ) + skip_loading_so_files = True + +if not skip_loading_so_files: + try: + from pathlib import Path + + so_files = list(Path(__file__).parent.glob("_C*.so")) + if len(so_files) > 0: + for file in so_files: + torch.ops.load_library(str(file)) + from . import ops + + # The following registers meta kernels for some CPU kernels + from torchao.csrc_meta_ops import * # noqa: F403 + except Exception as e: + logger.debug(f"Skipping import of cpp extensions: {e}") from torchao.quantization import ( autoquant, diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 4b761ad725..5d680bcf82 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -12,37 +12,17 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.quantize_per_channel_group to mitigate availability issue until it can be supplanted by new quantize_affine function. - - torch.ops.quantized_decomposed.quantize_per_channel_group is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_channel_group( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs) def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later." + return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + *args, **kwargs ) @@ -50,50 +30,21 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.dequantize_per_channel_group to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.dequantize_per_channel_group is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_channel_group( - *args, **kwargs - ) - raise ImportError( - "Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs) def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.quantize_per_token to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.quantize_per_token is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) - raise ImportError( - "Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): """ Wrapper around torch.ops.quantized_decomposed.dequantize_per_token to mitigate availability issue until it can be supplanted by new choose_qparams_affine function. - - torch.ops.quantized_decomposed.dequantize_per_token is only available - in PyTorch 2.3+ and recently changed signatures. """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) - raise ImportError( - "Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later." - ) + return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) diff --git a/torchao/_models/README.md b/torchao/_models/README.md index 074adf884c..300f1ed7d3 100644 --- a/torchao/_models/README.md +++ b/torchao/_models/README.md @@ -1,4 +1,43 @@ -## SAM2 +# LLAMA + +## Eval on Llama 3.1 8B and Llama 3.2 3B + +We use lm-eval tasks for evaluating TorchAO Quantization APIs on HuggingFace models. The results are in the table below: + +| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Model Size (GB) | +|------------|---------------------------|-------|--------|----------------|-------------------| +| Llama 3.1 8B | None | 60.01 | 78.84 | 7.33 | 15.01 | +| Llama 3.1 8B | int4wo-128 | 58.10 | 77.06 | 8.25 | 4.76 | +| Llama 3.1 8B | int8wo | 59.92 | 78.95 | 7.34 | 8.04 | +| Llama 3.1 8B | int8dq | 60.01 | 78.82 | 7.45 | 8.03 | +| Llama 3.1 8B | float8wo | 59.83 | 78.61 | 7.37 | 8.03 | +| Llama 3.1 8B | float8dq (PerRow) | 59.86 | 78.57 | 7.41 | 8.04 | +| Llama 3.1 8B | float8dq (PerTensor) | 59.95 | 78.66 | 7.42 | 8.03 | +| Llama 3.1 8B | gemlite (gp=128) | 58.48 | 77.34 | 8.07 | 4.76 | + +| Model Name | Quantization Technique | Acc |Acc Norm| Word perplexity| Model Size (GB) | +|------------|---------------------------|-------|--------|----------------|-------------------| +| Llama 3.2 3B | None | 55.27 | 73.70 | 9.26 | 6.43 | +| Llama 3.2 3B | int4wo-128 | 53.13 | 71.31 | 10.36 | 2.29 | +| Llama 3.2 3B | int8wo | 55.15 | 73.44 | 9.28 | 3.61 | +| Llama 3.2 3B | int8dq | 55.00 | 73.29 | 9.43 | 3.61 | +| Llama 3.2 3B | float8wo | 55.18 | 73.58 | 9.31 | 3.61 | +| Llama 3.2 3B | float8dq (PerRow) | 55.18 | 73.37 | 9.33 | 3.61 | +| Llama 3.2 3B | float8dq (PerTensor) | 55.16 | 73.53 | 9.35 | 3.61 | +| Llama 3.2 3B | gemlite (gp=128) | 53.71 | 71.99 | 10.05 | 2.29 | + +To generate the above results run: +``` +sh benchmarks/_models/eval_hf_models.sh +``` + +To run lm-eval for a different hf-model with AO quantization technique, run: +``` +python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag +``` +Replace model id, quantization and tasks with your desired values Please refer to ([HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao)) integration docs for more details about the supported quantization techniques. + +# SAM2 sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc It includes diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py index faf059c400..de7f010035 100644 --- a/torchao/_models/_eval.py +++ b/torchao/_models/_eval.py @@ -57,8 +57,13 @@ def _model_call(self, inps): max_seq_length = min(max(inps.size()), self.max_length) with torch.device(self._device): - self._model.setup_caches(self.batch_size, max_seq_length) + if hasattr(self._model, "setup_caches"): + self._model.setup_caches(self.batch_size, max_seq_length) logits = self._model(*input) + from transformers.modeling_outputs import CausalLMOutputWithPast + + if isinstance(logits, CausalLMOutputWithPast): + logits = logits.logits return logits def run_eval(self, tasks, limit): @@ -84,7 +89,11 @@ def eot_token_id(self): try: return self.tokenizer.eos_id() except: - return self.tokenizer.eos_id + try: + return self.tokenizer.eos_id + except: + idx = self.tokenizer.all_special_tokens.index("<|endoftext|>") + return self.tokenizer.all_special_ids[idx] @property def max_length(self): @@ -102,8 +111,8 @@ def batch_size(self): def device(self): return self._device - def tok_decode(self, tokens): - decoded = self.tokenizer.decode(tokens) + def tok_decode(self, tokens, **kwargs): + decoded = self.tokenizer.decode(tokens, **kwargs) return decoded def tok_encode(self, string: str, **kwargs): @@ -115,9 +124,6 @@ def tok_encode(self, string: str, **kwargs): tokens = [self.tokenizer.bos_id] + tokens return tokens - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") - class LMEvalInputRecorder(TransformerEvalWrapper): def __init__( diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 8ee15f1fd3..fdd9792cb4 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -17,18 +17,17 @@ import torchao from torchao._models.llama.model import prepare_inputs_for_model from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, PerRow, PerTensor, - float8_dynamic_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - int4_weight_only, - int8_dynamic_activation_int8_weight, - int8_weight_only, + UIntXWeightOnlyConfig, quantize_, - uintx_weight_only, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass def run_evaluation( @@ -74,11 +73,11 @@ def run_evaluation( apply_spinquant(model) if "int8wo" in quantization: - quantize_(model, int8_weight_only()) + quantize_(model, Int8WeightOnlyConfig()) if "int8dq" in quantization: - quantize_(model, int8_dynamic_activation_int8_weight()) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model, FPXWeightOnlyConfig(3, 2)) if "int4wo" in quantization and not "gptq" in quantization: if "hqq" in quantization: use_hqq = True @@ -90,7 +89,7 @@ def run_evaluation( ) quantize_( model.to(device), - int4_weight_only(group_size=groupsize, use_hqq=use_hqq), + Int4WeightOnlyConfig(group_size=groupsize, use_hqq=use_hqq, version=1), ) if "uintx" in quantization: # uintx-nbits-groupsize @@ -113,11 +112,13 @@ def run_evaluation( } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) - quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + quantize_(model, UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)) if "marlin" in quantization: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) + quantize_( + model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1) + ) if "int4wo" in quantization and "gptq" in quantization: # avoid circular imports from torchao._models._eval import LMEvalInputRecorder @@ -151,11 +152,8 @@ def run_evaluation( model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) quantizer.quantize(model, *inputs) model = model.to(device) - else: - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) if "float8wo" in quantization: - quantize_(model, float8_weight_only()) + quantize_(model, Float8WeightOnlyConfig()) if "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) if granularity == "tensor": @@ -168,7 +166,8 @@ def run_evaluation( else: raise ValueError(f"Unknown granularity {granularity}") quantize_( - model, float8_dynamic_activation_float8_weight(granularity=granularity) + model, + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), ) if "autoround" in quantization: from transformers import AutoTokenizer @@ -237,6 +236,41 @@ def run_evaluation( quantize_( model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64) ) + elif quantization.startswith("awq-uintx"): + from torchao._models._eval import TransformerEvalWrapper + from torchao.prototype.awq import ( + AWQObservedLinear, + awq_uintx, + insert_awq_observer_, + ) + + quant_dtype = quantization.split("-")[1] + group_size = int(quantization.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.uint8) + model = model.to(device) + # get calibration data + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=256, + input_prep_func=prepare_inputs_for_model, + device=device, + ).run_eval( + tasks=["wikitext"], + limit=1, + ) + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quantization + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) if compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 8f02e83a99..da1b848bcb 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -20,11 +20,7 @@ write_json_result_ossci, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - get_model_size_in_bytes, -) +from torchao.utils import get_model_size_in_bytes torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False torch.backends.cuda.enable_cudnn_sdp(True) @@ -47,6 +43,8 @@ def elapsed_time(self, other_event): def device_timer(device): if "cuda" in device: return torch.cuda.Event(enable_timing=True) + elif "xpu" in device: + return torch.xpu.Event(enable_timing=True) elif ("cpu" in device) or ("mps" in device): return HostEvent() else: @@ -342,21 +340,20 @@ def ffn_or_attn_only(mod, fqn): if quantization: from torchao.quantization import ( Float8DynamicActivationFloat8SemiSparseWeightConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + UIntXWeightOnlyConfig, autoquant, - float8_dynamic_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, quantize_, - uintx_weight_only, ) from torchao.quantization.granularity import PerRow, PerTensor - from torchao.utils import unwrap_tensor_subclass if "spinquant" in quantization: from torchao.prototype.spinquant import apply_spinquant @@ -378,7 +375,7 @@ def ffn_or_attn_only(mod, fqn): quantize_( model, - gemlite_uintx_weight_only( + GemliteUIntXWeightOnlyConfig( bit_width=bit_width, group_size=group_size, mode=mode ), ) @@ -398,25 +395,28 @@ def ffn_or_attn_only(mod, fqn): gemlite.cache_config(config_file) if "int8wo" in quantization: - quantize_(model, int8_weight_only()) + quantize_(model, Int8WeightOnlyConfig()) if "int8dq" in quantization: if sparsity and "semi" in sparsity: from torchao.dtypes import SemiSparseLayout quantize_( model, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), filter_fn=ffn_only, ) quantize_( - model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only + model, + Int8DynamicActivationInt8WeightConfig(), + filter_fn=not_ffn_only, ) elif "int8dq_prefill_wo_decode" in quantization: quantize_( - model, int8_dynamic_activation_int8_weight(weight_only_decode=True) + model, + Int8DynamicActivationInt8WeightConfig(weight_only_decode=True), ) else: - quantize_(model, int8_dynamic_activation_int8_weight()) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) if "int4wo" in quantization: use_hqq = False if "hqq" in quantization: @@ -430,25 +430,9 @@ def ffn_or_attn_only(mod, fqn): ], ( f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" ) - quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) - elif "fbgemm" in quantization and "int4" in quantization: - from torchao.quantization import FbgemmConfig - - _, precision, group_size = quantization.split("-") - group_size = int(group_size) - block_size = [1, group_size] - assert precision == "int4", f"FbegemmConfig({precision=}) not supported yet" - quantize_( - model, - FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size), - ) - elif "fbgemm" in quantization and "fp8" in quantization: - from torchao.float8.config import e4m3_dtype - from torchao.quantization import FbgemmConfig - quantize_( model, - FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16), + Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1), ) elif "int4dq-" in quantization: from torchao.dtypes import CutlassInt4PackedLayout @@ -458,7 +442,7 @@ def ffn_or_attn_only(mod, fqn): if nbits == 4: quantize_( model, - int4_dynamic_activation_int4_weight( + Int4DynamicActivationInt4WeightConfig( mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, layout=CutlassInt4PackedLayout(), @@ -467,7 +451,7 @@ def ffn_or_attn_only(mod, fqn): elif nbits == 8: quantize_( model, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=None, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -480,7 +464,7 @@ def ffn_or_attn_only(mod, fqn): quantize_( model, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=128, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -492,24 +476,19 @@ def ffn_or_attn_only(mod, fqn): quantize_( model, - int4_weight_only(layout=MarlinSparseLayout()), + Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1), filter_fn=ffn_or_attn_only, ) if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model, FPXWeightOnlyConfig(3, 2)) elif "embed-int8wo" in quantization: quantize_( model, - int8_weight_only(group_size=64), + Int8WeightOnlyConfig(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), ) elif quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - - if not TORCH_VERSION_AT_LEAST_2_3: - print("Awq requires torch2.3+") - exit() from torchao.prototype.awq import ( AWQObservedLinear, awq_uintx, @@ -565,16 +544,12 @@ def ffn_or_attn_only(mod, fqn): } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) - quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + quantize_(model, UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)) elif "int8_dynamic_activation_intx_weight" in quantization: - assert TORCH_VERSION_AT_LEAST_2_6, ( - "int8_dynamic_activation_intx_weight requires torch2.6+" - ) assert precision == torch.float32, ( "int8_dynamic_activation_intx_weight requires using precision=torch.float32" ) - from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import ( Int8DynamicActivationIntxWeightConfig, @@ -594,12 +569,11 @@ def ffn_or_attn_only(mod, fqn): weight_mapping_type=MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC, - weight_scale_dtype=torch.bfloat16, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + intx_packing_format="opaque_torchao_auto", ), ) elif "float8wo" in quantization: - quantize_(model, float8_weight_only()) + quantize_(model, Float8WeightOnlyConfig()) elif "float8dq" in quantization: if sparsity and "semi" in sparsity: quantize_( @@ -617,7 +591,7 @@ def ffn_or_attn_only(mod, fqn): granularity = PerTensor() quantize_( model, - float8_dynamic_activation_float8_weight(granularity=granularity), + Float8DynamicActivationFloat8WeightConfig(granularity=granularity), ) elif "autoquant_v2" in quantization: from torchao._models._eval import LMEvalInputRecorder @@ -829,10 +803,6 @@ def ffn_or_attn_only(mod, fqn): model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64) ) - else: - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - # standalone sparsity elif sparsity: from torchao.sparsity import semi_sparse_weight, sparsify_ diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 45dd2e9f29..46d8c2484a 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -192,7 +192,7 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out -from torchao.quantization.utils import quantize_activation_per_token_absmax +from torchao.quantization.utils import _quantize_activation_per_token_absmax class AffineQuantizedKVCache(nn.Module): @@ -218,13 +218,13 @@ def __init__( def update(self, input_pos, k_val, v_val): # quantize current k_val and store it in the cache - q_k_val, k_scale = quantize_activation_per_token_absmax(k_val) + q_k_val, k_scale = _quantize_activation_per_token_absmax(k_val) self.k_cache[:, :, input_pos] = q_k_val self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1) k_out = self.k_cache * self.k_cache_scale k_out[:, :, input_pos] = k_val - q_v_val, v_scale = quantize_activation_per_token_absmax(v_val) + q_v_val, v_scale = _quantize_activation_per_token_absmax(v_val) self.v_cache[:, :, input_pos] = q_v_val self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1) v_out = self.v_cache * self.v_cache_scale diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py index 11a53043ad..39ee6a4dcb 100644 --- a/torchao/_models/mixtral-moe/generate.py +++ b/torchao/_models/mixtral-moe/generate.py @@ -248,7 +248,6 @@ def main( Int8DynamicActivationInt8WeightConfig, Int8DynamicActivationIntxWeightConfig, Int8WeightOnlyConfig, - PackedLinearInt8DynamicActivationIntxWeightLayout, PerRow, quantize_, ) @@ -275,11 +274,11 @@ def main( ) elif "int4wo-base" in moe_quant: - config = MoEQuantConfig(Int4WeightOnlyConfig()) + config = MoEQuantConfig(Int4WeightOnlyConfig(version=1)) elif "int4wo" in moe_quant: config = MoEQuantConfig( - Int4WeightOnlyConfig(), + Int4WeightOnlyConfig(version=1), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, ) @@ -306,7 +305,7 @@ def main( elif "intxdq" in moe_quant: config = MoEQuantConfig( Int8DynamicActivationIntxWeightConfig( - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + intx_packing_format="opaque_torchao_auto", ), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, ) diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index a0410fb734..467e24a9b6 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -22,13 +22,12 @@ from torchao.dtypes import SemiSparseLayout from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 from torchao.quantization import ( + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, autoquant, - int4_weight_only, - int8_dynamic_activation_int8_weight, quantize_, ) from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass torch._dynamo.config.cache_size_limit = 50000 @@ -363,11 +362,9 @@ def mlp_only(mod, name): return isinstance(mod, torch.nn.Linear) and "mlp" in name if compress == "int8_dynamic_quant": - quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) + quantize_( + predictor.model.image_encoder, Int8DynamicActivationInt8WeightConfig() + ) elif compress == "sparse_mlp_only": def mlp_only(mod, name): @@ -386,19 +383,15 @@ def mlp_only(mod, name): quantize_( predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(), + Int8DynamicActivationInt8WeightConfig(), attn_only, ) quantize_( predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress == "int4_weight_only_sparse": # apply sparsify first to set qparams apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) @@ -406,19 +399,15 @@ def mlp_only(mod, name): quantize_( predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(), + Int8DynamicActivationInt8WeightConfig(), attn_only, ) quantize_( predictor.model.image_encoder, - int4_weight_only(layout=MarlinSparseLayout()), + Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1), mlp_lin1_only, ) sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only) - if not TORCH_VERSION_AT_LEAST_2_5: - predictor.model.image_encoder = unwrap_tensor_subclass( - predictor.model.image_encoder - ) elif compress is not None and "autoquant_v2" in compress: example_input = torch.randn( diff --git a/torchao/core/config.py b/torchao/core/config.py index 3451b90c59..330e6a42af 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -8,10 +8,21 @@ import enum import importlib import json -from typing import Any, ClassVar, Dict +import warnings +from typing import Any, Dict import torch +__all__ = [ + "AOBaseConfig", + "config_from_dict", + "config_to_dict", + "ALLOWED_AO_MODULES", +] + +# the default version for all configs, should never change +_DEFAULT_VERSION = 1 + class AOBaseConfig(abc.ABC): """ @@ -38,22 +49,22 @@ def _transform( """ - # Base Version of a config - VERSION: ClassVar[int] = 1 + """ + Note: this is not the version of AOBaseConfig, but the default version for instances of + all child configs inheriting from AOBaseConfig, and it should be `_DEFAULT_VERSION` and never change + this is making sure all config instances has a version defined, when they need to bump the default + version they have to define a instance variable version for the child config to overwrite the default version + that's defined here. Different child config instances will maintain their own version. + Why version is instance variable instead of class variable? instance level version is needed becuase + when we have multiple versions co-exist, we need to be able to load objects with earlier versions, + class level version is global and can't achieve this goal so we have to use instance variable. -class VersionMismatchError(Exception): - """Raised when trying to deserialize a config with a different version""" + to overwrite this in subclasses, we need to define `version: int` (with type annotations) - def __init__(self, type_path, stored_version, current_version): - self.type_path = type_path - self.stored_version = stored_version - self.current_version = current_version - message = ( - f"Version mismatch for {type_path}: " - f"stored version {stored_version} != current version {current_version}" - ) - super().__init__(message) + default Version of a config, should never change + """ + version: int = _DEFAULT_VERSION class ConfigJSONEncoder(json.JSONEncoder): @@ -65,14 +76,14 @@ def default(self, o): data_dict = {} # Process each attribute to handle nested objects for k, v in o.__dict__.items(): - if not k.startswith("_") and k != "VERSION": + if not k.startswith("_") and k != "version": # Recursively encode each value (important for nested objects) data_dict[k] = self.encode_value(v) return { # Only store the class name, not the full module path "_type": o.__class__.__name__, - "_version": getattr(o.__class__, "VERSION", 1), + "_version": getattr(o, "version", _DEFAULT_VERSION), "_data": data_dict, } @@ -86,7 +97,7 @@ def default(self, o): return { "_type": o.__class__.__name__, - "_version": getattr(o.__class__, "VERSION", 1), + "_version": getattr(o, "version", _DEFAULT_VERSION), "_data": processed_data, } @@ -95,13 +106,13 @@ def default(self, o): data_dict = {} # Process each field to handle nested objects for f in dataclasses.fields(o): - if f.name != "VERSION": + if f.name != "version": data_dict[f.name] = self.encode_value(getattr(o, f.name)) return { # Only store the class name for dataclasses too "_type": o.__class__.__name__, - "_version": getattr(o.__class__, "VERSION", 1), + "_version": getattr(o, "version", _DEFAULT_VERSION), "_data": data_dict, } @@ -182,7 +193,12 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: "torchao.sparsity.sparse_api", "torchao.prototype.quantization", "torchao.prototype.mx_formats", + "torchao.prototype.parq", "torchao.dtypes", + "torchao.prototype.awq", + "torchao.prototype.parq.quant", + "torchao.quantization.quantize_.common", + "torchao.quantization.quantize_.workflows", } @@ -197,7 +213,6 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: An instance of the appropriate AOBaseConfig subclass Raises: - VersionMismatchError: If the stored version doesn't match the class version ValueError: If deserialization fails for other reasons """ if not isinstance(data, dict): @@ -207,7 +222,7 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: raise ValueError("Input dictionary missing required '_type' or '_data' fields") type_path = data["_type"] - stored_version = data.get("_version", 1) + stored_version = data.get("_version", _DEFAULT_VERSION) obj_data = data["_data"] # Handle torch.dtype @@ -232,10 +247,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}" ) - # Check version - require exact match - current_version = getattr(cls, "VERSION", 1) - if stored_version != current_version: - raise VersionMismatchError(type_path, stored_version, current_version) + current_default_version = getattr(cls, "version", _DEFAULT_VERSION) + if stored_version != current_default_version: + warnings.warn( + f"Stored version is not the same as current default version of the config: {stored_version=}, {current_default_version=}, please check the deprecation warning" + ) # Handle the case where obj_data is not a dictionary if not isinstance(obj_data, dict): @@ -250,7 +266,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: return obj_data # Process nested structures for dictionary obj_data - processed_data = {} + if stored_version != current_default_version: + processed_data = {"version": stored_version} + else: + processed_data = {} + for key, value in obj_data.items(): if isinstance(value, dict) and "_type" in value and "_data" in value: # Recursively handle nested configs diff --git a/torchao/csrc/cpu/CMakeLists.txt b/torchao/csrc/cpu/CMakeLists.txt new file mode 100644 index 0000000000..aaea27ec74 --- /dev/null +++ b/torchao/csrc/cpu/CMakeLists.txt @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +include(CMakeDependentOption) + +project(torchao) + +set(CMAKE_CXX_STANDARD 17) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# Platform options +option(TORCHAO_BUILD_ATEN_OPS "Building torchao ops for ATen." ON) +option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF) +option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF) +option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) +option(TORCHAO_ENABLE_ARM_NEON_DOT "Enable ARM Neon Dot Product extension" OFF) +option(TORCHAO_ENABLE_ARM_I8MM "Enable ARM 8-bit Integer Matrix Multiply instructions" OFF) +option(TORCHAO_BUILD_TESTS "Build tests" OFF) +option(TORCHAO_BUILD_BENCHMARKS "Build tests" OFF) + +# Set default compiler options +add_compile_options("-fPIC" "-Wall" "-Werror" "-Wno-deprecated") +if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + add_compile_options( + "-Wno-error=unknown-pragmas" + "-Wno-array-parameter" + "-Wno-maybe-uninitialized" + "-Wno-sign-compare" + ) +elseif (APPLE) + add_compile_options("-Wno-shorten-64-to-32") +endif() + + + +if (NOT TARGET cpuinfo) + cmake_policy(PUSH) + cmake_policy(VERSION 3.5) # cpuinfo requires CMake 3.5 + + # For some reason cpuinfo package has unused functions/variables + # TODO (T215533422): fix upstream + add_compile_options(-Wno-unused-function -Wno-unused-variable) + + # set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + include(FetchContent) + set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE) + set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "" FORCE) + set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) + FetchContent_Declare(cpuinfo + GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git + GIT_TAG c61fe919607bbc534d7a5a5707bdd7041e72c5ff + ) + FetchContent_MakeAvailable( + cpuinfo) + + cmake_policy(POP) +endif() + +if (TORCHAO_BUILD_TESTS) + include(FetchContent) + FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip + ) + FetchContent_MakeAvailable(googletest) +endif() + +if (TORCHAO_BUILD_BENCHMARKS) + include(FetchContent) + FetchContent_Declare(googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG main) # need main for benchmark::benchmark + + set(BENCHMARK_ENABLE_TESTING OFF) + FetchContent_MakeAvailable( + googlebenchmark) +endif() + +if(NOT TORCHAO_INCLUDE_DIRS) + set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +endif() + +if(NOT DEFINED TORCHAO_PARALLEL_BACKEND) + set(TORCHAO_PARALLEL_BACKEND aten_openmp) +endif() + +# Set default compiler options + +include(CMakePrintHelpers) +include(${CMAKE_CURRENT_SOURCE_DIR}/shared_kernels/Utils.cmake) + +message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") +include_directories(${TORCHAO_INCLUDE_DIRS}) + + +# Build fallback kernels +add_subdirectory(torch_free_kernels/fallback) + +# Build cpu/aarch64 kernels +if(TORCHAO_BUILD_CPU_AARCH64) + message(STATUS "Building with cpu/aarch64") + add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) + + if(TORCHAO_ENABLE_ARM_NEON_DOT) + message(STATUS "Building with ARM NEON dot product support") + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) + add_compile_options("-march=armv8.4-a+dotprod") + endif() + + if(TORCHAO_ENABLE_ARM_I8MM) + message(STATUS "Building with ARM I8MM support") + add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) + add_compile_options("-march=armv8.6-a") + endif() + + if(TORCHAO_BUILD_KLEIDIAI) + message(STATUS "Building with Arm KleidiAI library") + add_compile_definitions(TORCHAO_ENABLE_KLEIDI) + if (NOT TARGET kleidiai) + include(FetchContent) + # KleidiAI is an open-source library that provides optimized + # performance-critical routines, also known as micro-kernels, for artificial + # intelligence (AI) workloads tailored for Arm® CPUs. + set(KLEIDIAI_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(KLEIDIAI_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) + FetchContent_Declare(kleidiai + GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git + GIT_TAG v1.12.0 + ) + FetchContent_MakeAvailable(kleidiai) + endif() + endif() + + # Defines torchao_kernels_aarch64 + add_subdirectory(torch_free_kernels/aarch64) +endif() + +# Build ATen ops +if(TORCHAO_BUILD_ATEN_OPS) + find_package(Torch REQUIRED) + set(_torchao_op_srcs_aten) + list(APPEND _torchao_op_srcs_aten + shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp + shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp + shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp + shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp + shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp + ) + list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") + + # Use the Python extension name if provided + add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten}) + if(DEFINED TORCHAO_CMAKE_EXT_SO_NAME) + message(STATUS "Setting output name to: ${TORCHAO_CMAKE_EXT_SO_NAME}.so") + set_target_properties(torchao_ops_aten PROPERTIES + OUTPUT_NAME ${TORCHAO_CMAKE_EXT_SO_NAME} + PREFIX "" # Remove "lib" prefix for Python extensions + SUFFIX ".so" # Add ".so" suffix for Python extensions + ) + endif() + + target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}") + if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64) + if (TORCHAO_BUILD_KLEIDIAI) + target_link_libraries(torchao_ops_aten PRIVATE kleidiai) + endif() + endif() + target_link_libraries(torchao_ops_aten PRIVATE cpuinfo) + target_include_directories(torchao_ops_aten PRIVATE "${TORCH_INCLUDE_DIRS}") + target_link_libraries(torchao_ops_aten PRIVATE "${TORCH_LIBRARIES}") + target_compile_definitions(torchao_ops_aten PRIVATE TORCHAO_SHARED_KERNELS_BUILD_ATEN=1) + + if (TORCHAO_BUILD_TESTS) + add_subdirectory(shared_kernels/tests) + endif() + + if (TORCHAO_BUILD_BENCHMARKS) + add_subdirectory(shared_kernels/benchmarks) + endif() + + # Install ATen targets + install( + TARGETS torchao_ops_aten + EXPORT _targets + DESTINATION lib + ) +endif() + + +# Build ExecuTorch ops +if(TORCHAO_BUILD_EXECUTORCH_OPS) + # ExecuTorch package is not required, but EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES must + # be defined and EXECUTORCH_LIBRARIES must include the following libraries installed by ExecuTorch: + # libexecutorch.a + # libextension_threadpool.a + # libcpuinfo.a + # libpthreadpool.a + if(NOT DEFINED EXECUTORCH_INCLUDE_DIRS AND NOT DEFINED EXECUTORCH_LIBRARIES) + message(WARNING "EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES are not defined. Looking for ExecuTorch.") + find_package(ExecuTorch HINTS ${CMAKE_PREFIX_PATH}/executorch/share/cmake) + endif() + set(_torchao_op_srcs_executorch) + list(APPEND _torchao_op_srcs_executorch + shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp + shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp + shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp + shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp + shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp) + + list(TRANSFORM _torchao_op_srcs_executorch PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") + add_library(torchao_ops_executorch STATIC ${_torchao_op_srcs_executorch}) + + target_compile_definitions(torchao_ops_executorch PRIVATE TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH=1) + + # This links to ExecuTorch + target_link_torchao_parallel_backend(torchao_ops_executorch executorch) + if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_executorch PRIVATE torchao_kernels_aarch64) + if (TORCHAO_BUILD_KLEIDIAI) + target_link_libraries(torchao_ops_executorch PRIVATE kleidiai) + endif() + endif() + target_link_libraries(torchao_ops_executorch PRIVATE cpuinfo) +endif() diff --git a/torchao/csrc/cpu/README.md b/torchao/csrc/cpu/README.md new file mode 100644 index 0000000000..91cccd6978 --- /dev/null +++ b/torchao/csrc/cpu/README.md @@ -0,0 +1,11 @@ +# CPU kernels + +CPU kernels are contained in 3 directories: + +* torch_free_kernels: This directory contains CPU kernels written with raw pointers and do not use any PyTorch concepts like Tensor. + +* shared_kernels: This directory is for kernels that are shared between PyTorch/ATen and Executorch. They can be compiled with either platform using compile flags. Kernels in this directory often use torch_free_kernels in their implementation. + +* aten_kernels: This directory is for kernels written for PyTorch/ATen. + +If possible, we prefer contributors write a shared kernel when constributing new code. diff --git a/torchao/csrc/cpu/da8w4_linear.cpp b/torchao/csrc/cpu/aten_kernels/da8w4_linear.cpp similarity index 99% rename from torchao/csrc/cpu/da8w4_linear.cpp rename to torchao/csrc/cpu/aten_kernels/da8w4_linear.cpp index 537aa0fce9..7781ad7d47 100644 --- a/torchao/csrc/cpu/da8w4_linear.cpp +++ b/torchao/csrc/cpu/aten_kernels/da8w4_linear.cpp @@ -65,11 +65,13 @@ da8w4_linear_prepack_impl( at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); at::Tensor blocked_qzeros = new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); // Compensation = Σ(k)(W[k][n] - ZP[n]) for each block. + // Reorder compensation to [N/block_n, K/block_k, block_n] auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) - new_qzeros.view({Nc, block_n, G, -1}); weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, block_k}); at::Tensor compensation = weight_sub_qzero.sum(-1); compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt); +#if defined(CPU_CAPABILITY_AVX512) if (cpublas_could_pack()) { blocked_weight = at::empty({Nc, Kc, block_k, block_n / 2}, weight.options()); auto weight_ptr = weight_reordered.data_ptr(); @@ -105,7 +107,9 @@ da8w4_linear_prepack_impl( } } }); - } else { + } else +#endif + { // Pack weight: two int4 -> one int8 using namespace at::indexing; at::Tensor even_columns = @@ -619,9 +623,9 @@ void _da8w4_linear_impl( } else if (M < 64) { return 32; } else if (M < 96) { - return 48; - } else { return 64; + } else { + return 128; } }(); int64_t Mc = (M + block_m - 1) / block_m; diff --git a/torchao/csrc/cpu/int8_sdpa.cpp b/torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp similarity index 72% rename from torchao/csrc/cpu/int8_sdpa.cpp rename to torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp index a5928f6d9a..5abd3c66b9 100644 --- a/torchao/csrc/cpu/int8_sdpa.cpp +++ b/torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -82,12 +83,55 @@ inline void _store(scalar_t* dst, at::vec::Vectorized src, int size=at } template -inline typename std::enable_if_t || std::is_same_v, void> +inline typename std::enable_if_t || std::is_same_v || std::is_same_v, void> _store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { auto res = at::vec::convert(src); res.store(dst, size); } +/* +out = val * a + b +is_b_stride_zero: If the stride of b is 0 (mask broadcasting case), + take b as a scalar pointer. +*/ +template +inline void _scale_dequant_attn_mask_fusion_kernel( + T1* a, + T2* b, + const int& size, + T1* out, + const T1& val) { + const auto vec_size1 = at::vec::Vectorized::size(); + const auto vec_size2 = at::vec::Vectorized::size(); + constexpr int64_t T1_n = + (vec_size2 == vec_size1 * 2 && at::vec::is_reduced_floating_point_v) ? 2 : 1; + constexpr int64_t T2_n = 1; + auto vec_scale = at::vec::VectorizedN(val); + int64_t i = 0; + for (; i < size - (size % vec_size2); i += vec_size2) { + auto a_n = at::vec::VectorizedN::loadu(a + i); + at::vec::VectorizedN b_n; + if constexpr(is_b_stride_zero) { + b_n = at::vec::VectorizedN((T1)b[0]); + } else { + b_n = at::vec::VectorizedN::loadu(b + i); + } + auto b_n_convert = at::vec::convert(b_n); + auto res = a_n * vec_scale + b_n_convert; + res.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + T1 tmp1; + if constexpr(is_b_stride_zero) { + tmp1 = (T1)b[0]; + } else { + tmp1 = (T1)b[i]; + } + out[i] = tmp0 * val + tmp1; + } +} + /* 1. dequant 2. add mask @@ -618,7 +662,7 @@ inline void _int_sum_a_contiguous_kernel( // do the transpose: [in_rows, in_cols] -> [in_cols, in_rows] template inline void do_transpose( - scalar_t* src, + const scalar_t* src, scalar_t* dst, int64_t in_rows, int64_t in_cols, @@ -673,7 +717,7 @@ inline void pad_remain_row_col( // copy value_ptr to dst_ptr with padding: [rows, cols] -> [prows, pcols] template inline void copy_value_with_pad( - scalar_t* value_ptr, + const scalar_t* value_ptr, scalar_t* dst_ptr, int rows, int cols, @@ -725,13 +769,122 @@ inline void copy_value_with_pad( } +/* +1. out = a * scale +2. max = max(out) +*/ +template +inline void _mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + _store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + auto reduced_tmp_max = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max); + // Guard against Q*K^T being NaN + max = std::isnan(reduced_tmp_max) ? std::numeric_limits::quiet_NaN() + : std::max(tmp_max, reduced_tmp_max); +} + +/* +1. out = exp(a - val) +2. val = sum(out) +3. quant +*/ +inline void _fp8_exp_reduce_sum_quant_fusion_kernel( + float* a, + const int& size, + at::Float8_e4m3fn* out, + float& val, + const float& scale) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + float tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + float min_val = -448; + float max_val = 448; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_scale = at::vec::Vectorized(scale); + long i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + auto tmp3 = tmp2 * vec_scale; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(out + i, tmp4); + } + if (i < size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i, size - i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, size - i); + auto tmp3 = tmp2 * vec_scale; + auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); + _store(out + i, tmp4, size - i); + } + val = vec_tmp_sum.reduce_add(); +} + +/* +1. dequant +2. quant +*/ +inline void _fp8_dequant_quant_fusion_kernel( + float* a, + const int& size, + at::Float8_e4m3fn* out, + const float& scale) { + auto vec_size = at::vec::Vectorized::size(); + float min_val = -448; + float max_val = 448; + auto vec_min_val = at::vec::Vectorized(min_val); + auto vec_max_val = at::vec::Vectorized(max_val); + auto vec_scale = at::vec::Vectorized(scale); + long i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + auto tmp2 = at::vec::clamp(tmp1, vec_min_val, vec_max_val); + _store(out + i, tmp2); + } + if (i < size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i, size - i); + auto tmp1 = tmp0 * vec_scale; + auto tmp2 = at::vec::clamp(tmp1, vec_min_val, vec_max_val); + _store(out + i, tmp2, size - i); + } +} + // UINT8 - one parallel loop with u8u8s32 GEMM template = 0> inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( +int8_sdpa_fused_kernel_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -830,9 +983,9 @@ sdpa_int8_fused_kernel_impl( int av_gemm_K = kvSplitSize + av_gemm_K_padding; // Data ptrs - scalar_t* q_data = query.data_ptr(); - scalar_t* k_data = key.data_ptr(); - scalar_t* v_data = value.data_ptr(); + const scalar_t* q_data = query.data_ptr(); + const scalar_t* k_data = key.data_ptr(); + const scalar_t* v_data = value.data_ptr(); mask_t* mask_data = attention_mask.has_value() ? attention_mask.value().data_ptr() : nullptr; @@ -931,7 +1084,7 @@ sdpa_int8_fused_kernel_impl( bool istail = kvBlockSize - b < block_64; int64_t trans_rows = istail ? kvBlockSize - b : block_64; do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, + reinterpret_cast(k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN), B_blocked_xform_u8, trans_rows, headSize, @@ -1159,7 +1312,7 @@ template = 0> inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( +int8_sdpa_fused_kernel_impl( const at::Tensor& output, const at::Tensor& q, const at::Tensor& k, @@ -1622,10 +1775,373 @@ sdpa_int8_fused_kernel_impl( at::native::cpublas::brgemm_release(); } +#if defined(CPUBLAS_BRGEMM_F8F8F32) +// FP8 - kernel with f8f8f8 GEMM +template +inline typename std::enable_if_t, void> +fp8_sdpa_fused_kernel_impl( + const at::Tensor& output, + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + double dropout_p, + bool is_causal, + std::optional attn_mask, + std::optional scale, + float q_scale, + float k_scale, + float v_scale, + float a_scale, + float o_scale) { + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + at::Tensor query = q.transpose(1, 2); + at::Tensor key = k.transpose(1, 2); + at::Tensor value = v.transpose(1, 2); + + using accum_t = float; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = calculate_scale(query, scale).expect_float(); + + // Sizes + TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), + "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size"); + int64_t batchSize = query.size(0); + int64_t qSize = query.size(1); + int64_t kvSize = value.size(1); + int64_t num_head = query.size(2); + int64_t headSize = query.size(3); + + bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); + if (has_attn_mask) { + reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize); + } + + // Strides + int64_t qStrideB = query.stride(0); + int64_t qStrideM = query.stride(1); + int64_t qStrideH = query.stride(2); + int64_t kStrideB = key.stride(0); + int64_t kStrideN = key.stride(1); + int64_t kStrideH = key.stride(2); + int64_t vStrideB = value.stride(0); + int64_t vStrideN = value.stride(1); + int64_t vStrideH = value.stride(2); + int64_t oStrideB = output.stride(0); + int64_t oStrideM = output.stride(1); + int64_t oStrideH = output.stride(2); + int64_t mStrideB = + (has_attn_mask && attn_mask.value().size(0) > 1) + ? attn_mask.value().stride(0) + : 0; + int64_t mStrideH = + (has_attn_mask && attn_mask.value().size(1) > 1) + ? attn_mask.value().stride(1) + : 0; + int64_t mStrideM = + (has_attn_mask && attn_mask.value().size(2) > 1) + ? attn_mask.value().stride(2) + : 0; + int64_t mStrideN = + (has_attn_mask && attn_mask.value().size(3) > 1) + ? attn_mask.value().stride(3) + : 0; + + int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; + int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; + int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; + int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + int64_t num_thread = at::get_num_threads(); + + // Pad is needed for packing when K is not even + bool headSize_even = headSize % 4 == 0; + int64_t eheadSize = !headSize_even ? headSize + 4 - headSize % 4: headSize; + int64_t ekvSplitSize = (kvSplitSize % 4 != 0) ? kvSplitSize + 4 - kvSplitSize % 4 : kvSplitSize; + int64_t ekvTail = (kvTail % 4 != 0) ? kvTail + 4 - kvTail % 4 : kvTail; + + // Allocate per thread temp buf (accumulate type) + int64_t size_per_thread = + /* qk */ qSplitSize * kvSplitSize + + /* qk_max */ qSplitSize + + /* qk_sum */ qSplitSize + + /* dst */ qSplitSize * headSize; + + at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(at::kFloat)); + at::Tensor buf_reduced = at::empty( + {num_thread, + qSplitSize, + ekvSplitSize}, + query.options()); + + // Data ptrs + const scalar_t* q_data = query.const_data_ptr(); + const scalar_t* k_data = key.const_data_ptr(); + const scalar_t* v_data = value.const_data_ptr(); + mask_t* mask_data = has_attn_mask + ? attn_mask.value().data_ptr() + : nullptr; + scalar_t* out_data = output.data_ptr(); + // accum_t* lse_data = logsumexp.data_ptr(); + accum_t* buf_data = buf.data_ptr(); + scalar_t* buf_reduced_data = buf_reduced.data_ptr(); + + // Buffer to store padding query and packing key/value + int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; + at::Tensor key_t_reorder = at::empty( + {batchSize, num_head, eheadSize, kvSize}, + c10::CppTypeToScalarType::value); + at::Tensor value_t_reorder = at::empty( + {batchSize, num_head, kv_padding_size, headSize}, + c10::CppTypeToScalarType::value); + scalar_t* key_reorder_ptr = key_t_reorder.data_ptr(); + scalar_t* value_reorder_ptr = value_t_reorder.data_ptr(); + + scalar_t* query_padding_ptr = nullptr; + at::Tensor query_t_padding; + if (!headSize_even) { + query_t_padding = at::empty( + {num_thread, qSplitSize, eheadSize}, + c10::CppTypeToScalarType::value); + query_padding_ptr = query_t_padding.data_ptr(); + } + + // Reorder K, V + at::Tensor tranpose_t_reorder = at::empty( + {num_thread, kvSplitSize, headSize}, + c10::CppTypeToScalarType::value); + scalar_t* transpose_buffer_ptr = tranpose_t_reorder.data_ptr(); + at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int ompIdx = at::get_thread_num(); + int64_t i = 0, j = 0, l = 0, n = 0; + scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize; + at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice); + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + + // transpose [kvBlockSize, headSize] -> [headSize, kvBlockSize] + at::native::utils::transpose( + kvBlockSize, + headSize, + /* src */ reinterpret_cast(k_data + i * kStrideB + j * kStrideH + n * kStrideN), + /* ld_src */ kStrideN, + /* dst */ reinterpret_cast(transpose_ptr), + /* ld_dst */ kvBlockSize); + + // Pack [headSize, kvBlockSize] + at::vec::pack_vnni4( + /* src */ reinterpret_cast(transpose_ptr), + /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize), + /* ld_src */ kvBlockSize, + /* K */ headSize, + /* N */ kvBlockSize); + + // Pack [kvBlockSize, headSize] + at::vec::pack_vnni4( + /* src */ reinterpret_cast(v_data + i * vStrideB + j * vStrideH + n * vStrideN), + /* dst */ reinterpret_cast(value_reorder_ptr + + i * num_head * kv_padding_size * headSize + + j * kv_padding_size * headSize + n * headSize), + /* ld_src */ vStrideN, + /* K */ kvBlockSize, + /* N */ headSize); + + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + + at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + accum_t* buf_ptr = buf_data + ompIdx * size_per_thread; + accum_t* qk_data = buf_ptr; + accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; + accum_t* qk_sum_data = qk_max_data + qSplitSize; + accum_t* dst_data = qk_sum_data + qSplitSize; + scalar_t* qk_reduced_data = buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize; + scalar_t* query_t_padding_ptr = !headSize_even + ? query_padding_ptr + ompIdx * qSplitSize * eheadSize + : nullptr; + + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + int64_t m = k * qSplitSize; + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // Initialize max and sum + fill_stub(qk_max_data, + -std::numeric_limits::infinity(), qBlockSize); + fill_stub(qk_sum_data, + static_cast(0), qBlockSize); + int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + if (!headSize_even) { + // Pad query if headSize is not even + // [qBlockSize, headSize] -> [qBlockSize, eheadSize] + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + eheadSize, + qStrideM + ); + } + for (int64_t n = 0; n < num_keys; n += kvSplitSize) { + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + int64_t ekvBlockSize = (kvBlockSize % 4 != 0) ? kvBlockSize + 4 - kvBlockSize % 4 : kvBlockSize; + // Calculate scale * q @ k.T + at::native::cpublas::brgemm( + qBlockSize, + kvBlockSize, + eheadSize, + headSize_even ? qStrideM : eheadSize, + kvBlockSize, + kvBlockSize, + false, + !headSize_even + ? query_t_padding_ptr + : q_data + i * qStrideB + j * qStrideH + m * qStrideM, + key_reorder_ptr + i * num_head * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize, + qk_data); + // Apply causal mask, fill unused with -inf + if (is_causal && num_keys - n <= kvSplitSize) { + for (const auto row : c10::irange(qBlockSize)) { + int64_t last_col = m + row - n; + accum_t* row_ptr = qk_data + row * kvBlockSize; + fill_stub(row_ptr + last_col + 1, + -std::numeric_limits::infinity(), + kvBlockSize - last_col - 1); + } + } + // Update attention weights with attention mask + // And apply scaling factor + // qk <- qk * scaling + attn_mask + if (has_attn_mask) { + for (int64_t row = 0; row < qBlockSize; ++row) { + if (mStrideN == 0) { + _scale_dequant_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor * q_scale * k_scale); + } else { + _scale_dequant_attn_mask_fusion_kernel( + qk_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM + n, + kvBlockSize, + qk_data + row * kvBlockSize, + scaling_factor * q_scale * k_scale); + } + } + } + // Update coefficients with Softmax + accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0; + for (int64_t row = 0; row < qBlockSize; ++row) { + if (has_attn_mask) { + // max per row + tmp_max = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, + qk_data + row * kvBlockSize, + kvBlockSize); + } else { + // apply scaling factor and max per row in fusion + _mul_reduce_max_fusion_kernel( + qk_data + row * kvBlockSize, + scaling_factor * q_scale * k_scale, + kvBlockSize, + qk_data + row * kvBlockSize, + tmp_max); + } + tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max; + if (tmp_max == -std::numeric_limits::infinity()) { + // to avoid `nan = exp2f(-inf - (-inf))` + fill_stub(qk_reduced_data + row * ekvBlockSize, + static_cast(0), kvBlockSize); + } else { + tmp_sum = tmp_max; + // qk <- exp(qk - max) and sum per row + _fp8_exp_reduce_sum_quant_fusion_kernel( + qk_data + row * kvBlockSize, kvBlockSize, + qk_reduced_data + row * ekvBlockSize, + tmp_sum, + 1.0 / a_scale); + // exp_tmp <- exp(max[row] - max) + exp_tmp = std::exp(qk_max_data[row] - tmp_max); + // sum[row] <- sum + exp_tmp * sum[row] + qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row]; + // max[row] <- max + qk_max_data[row] = tmp_max; + // dst <- dst * exp_tmp + if (n > 0) { + at::vec::map( + [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, + dst_data + row * headSize, + dst_data + row * headSize, + headSize); + } + } + if (kvBlockSize % 4 != 0) { + // Pad: [qSplitSize, kvBlockSize] -> [qSplitSize, kvBlockSize + 4 - kvBlockSize / 4] + for (int64_t psize = kvBlockSize; psize < ekvBlockSize; ++psize) { + *(qk_reduced_data + row * ekvBlockSize + psize) = scalar_t(0); + } + } + } + // Calculate Softmax(q @ k.T) @ v + int64_t psize = n / kvSplitSize * ekvSplitSize; + at::native::cpublas::brgemm( + qBlockSize, + headSize, + ekvBlockSize, + ekvBlockSize, + headSize, + headSize, + n > 0, + qk_reduced_data, + value_reorder_ptr + + i * num_head * kv_padding_size * headSize + + j * kv_padding_size * headSize + psize * headSize, + dst_data); + } + + // dst <- dst / sum[row] + // reorder MHA output with strides + for (int64_t row = 0; row < qBlockSize; ++row) { + // Row sums for full masked out rows are 0, we set them to 1 + // in order to avoid NaNs in the output and instead set fully + // masked out rows to 0 + qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row]; + qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row]; + accum_t sum_reciprocal = 1 / qk_sum_data[row]; + _fp8_dequant_quant_fusion_kernel( + dst_data + row * headSize, + headSize, + out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, + sum_reciprocal * a_scale * v_scale / o_scale); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); + } + at::native::cpublas::brgemm_release(); + }); +} +#endif // CPUBLAS_BRGEMM_F8F8F32 template inline typename std::enable_if_t, void> -sdpa_int8_fused_kernel_impl( +int8_sdpa_fused_kernel_impl( bool use_one_parallel_loop, const at::Tensor& output, const at::Tensor& query, @@ -1646,7 +2162,7 @@ sdpa_int8_fused_kernel_impl( float o_scale, int32_t o_zp) { if (use_one_parallel_loop) { - sdpa_int8_fused_kernel_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1656,7 +2172,7 @@ sdpa_int8_fused_kernel_impl( a_scale, a_zp, o_scale, o_zp); } else { - sdpa_int8_fused_kernel_impl( output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1668,7 +2184,6 @@ sdpa_int8_fused_kernel_impl( } } - #define AT_DISPATCH_MASK_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, \ @@ -1684,7 +2199,7 @@ sdpa_int8_fused_kernel_impl( AT_PRIVATE_CASE_TYPE_USING_HINT( \ at::ScalarType::Half, mask_t, __VA_ARGS__)) -void sdpa_int8_fused_kernel( +void int8_sdpa_fused_kernel( const at::Tensor& output, const at::Tensor& query, const at::Tensor& key, @@ -1724,7 +2239,7 @@ void sdpa_int8_fused_kernel( (attn_size > 1.5 * l2_cache_size); if (!attn_mask.has_value()) { if (q_split_size == 256) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1734,7 +2249,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else if (q_split_size == 64) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1744,7 +2259,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1757,7 +2272,7 @@ void sdpa_int8_fused_kernel( } else { AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { if (q_split_size == 256) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1767,7 +2282,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else if (q_split_size == 64) { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1777,7 +2292,7 @@ void sdpa_int8_fused_kernel( a_scale, a_zp, o_scale, o_zp); } else { - sdpa_int8_fused_kernel_impl( + int8_sdpa_fused_kernel_impl( use_one_parallel_loop, output, query, key, value, dropout_p, is_causal, attn_mask, scale, @@ -1790,9 +2305,88 @@ void sdpa_int8_fused_kernel( }); } } + +#if defined(CPUBLAS_BRGEMM_F8F8F32) +void fp8_sdpa_fused_kernel( + const at::Tensor& output, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + std::optional scale, + float q_scale, + float k_scale, + float v_scale, + float a_scale, + float o_scale) { + TORCH_CHECK(query.scalar_type() == c10::kFloat8_e4m3fn); + int64_t batchSize = query.size(0); + int64_t num_head = query.size(1); + int64_t q_seq_len = query.size(2); + int64_t kv_seq_len = key.size(2); + int64_t q_split_size = 32; + if (q_seq_len >= 768) { + q_split_size = 256; + } else if (q_seq_len >= 192) { + q_split_size = 64; + } + + if (!attn_mask.has_value()) { + if (q_split_size == 256) { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else if (q_split_size == 64) { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } + } else { + AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "sdpa_mask", [&]() { + if (q_split_size == 256) { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else if (q_split_size == 64) { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } else { + fp8_sdpa_fused_kernel_impl( + output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + } + }); + } +} +#endif // CPUBLAS_BRGEMM_F8F8F32 #endif // CPU_CAPABILITY_AVX512 -at::Tensor sdpa_int8_math_kernel( +at::Tensor int8_sdpa_math_kernel( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, @@ -1834,6 +2428,43 @@ at::Tensor sdpa_int8_math_kernel( return output; } +at::Tensor fp8_sdpa_math_kernel( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + double dropout_p, + bool is_causal, + std::optional attn_mask, + std::optional scale, + float q_scale, + float k_scale, + float v_scale, + float a_scale, + float o_scale) { + // dequant q/k/v + auto q = query.to(at::kFloat) * q_scale; + auto k = key.to(at::kFloat) * k_scale; + auto v = value.to(at::kFloat) * v_scale; + const auto scaling_factor = calculate_scale(q, scale); + auto attn = at::matmul(q, k.transpose(-2, -1)) * scaling_factor; + if (attn_mask.has_value() && attn_mask.value().numel()) { + attn = attn.add(attn_mask.value().to(at::kFloat)); + } + attn = at::softmax(attn, -1); + // quant attn + attn = at::clamp_max( + at::clamp_min(attn / a_scale, -448), 448 + ); + attn = attn.to(at::kFloat8_e4m3fn).to(at::kFloat); + // dequant attn + attn = attn * a_scale; + auto output = at::matmul(attn, v); + // quant output + output = at::clamp_max( + at::clamp_min(output / o_scale, -448), 448 + ).to(at::kFloat8_e4m3fn); + return output; +} at::Tensor _qscaled_dot_product_cpu( const at::Tensor& query, @@ -1858,8 +2489,8 @@ at::Tensor _qscaled_dot_product_cpu( "_qscaled_dot_product_cpu: Only accept plain inputs"); TORCH_CHECK(!is_causal, "_qscaled_dot_product_cpu: is_causal not supported."); - TORCH_CHECK(dtype == at::ScalarType::Byte, - "_qscaled_dot_product_cpu: Expected data type be U8, but got ", dtype, " instead."); + TORCH_CHECK(dtype == at::ScalarType::Byte || dtype == at::ScalarType::Float8_e4m3fn, + "_qscaled_dot_product_cpu: Expected data type be U8 or Float8_e4m3, but got ", dtype, " instead."); TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, "_qscaled_dot_product_cpu: Accept only 4 dims inputs shape of {B, H, T, K}"); TORCH_CHECK(dropout_p == 0.0, @@ -1873,30 +2504,59 @@ at::Tensor _qscaled_dot_product_cpu( TORCH_CHECK(!attn_mask.has_value() || (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4), "_qscaled_dot_product_cpu: Attention mask dim in {2, 4}"); + if (dtype == at::ScalarType::Float8_e4m3fn) { + TORCH_CHECK(q_zp == 0 && k_zp == 0 && v_zp == 0 && a_zp == 0 && o_zp == 0, + "_qscaled_dot_product_cpu: Don't accept zero point for Float8_e4m3"); + } - #ifdef CPU_CAPABILITY_AVX512 - if (at::native::cpublas::could_pack(dtype)) { - at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); - sdpa_int8_fused_kernel(output, query, key, value, - dropout_p, is_causal, attn_mask, scale, - q_scale, q_zp, - k_scale, k_zp, - v_scale, v_zp, - a_scale, a_zp, - o_scale, o_zp); - return output.transpose(1, 2); - } else { - #endif // CPU_CAPABILITY_AVX512 - return sdpa_int8_math_kernel(query, key, value, + if (dtype == at::ScalarType::Byte) { +#ifdef CPU_CAPABILITY_AVX512 + if (at::native::cpublas::could_pack(dtype)) { + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + int8_sdpa_fused_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, q_zp, + k_scale, k_zp, + v_scale, v_zp, + a_scale, a_zp, + o_scale, o_zp); + return output.transpose(1, 2); + } else { +#endif // CPU_CAPABILITY_AVX512 + return int8_sdpa_math_kernel(query, key, value, dropout_p, is_causal, attn_mask, scale, q_scale, q_zp, k_scale, k_zp, v_scale, v_zp, a_scale, a_zp, o_scale, o_zp).transpose(1, 2).contiguous().transpose(1, 2); - #ifdef CPU_CAPABILITY_AVX512 - } - #endif // CPU_CAPABILITY_AVX512 +#ifdef CPU_CAPABILITY_AVX512 + } +#endif // CPU_CAPABILITY_AVX512 + } else if (dtype == at::ScalarType::Float8_e4m3fn) { +#if defined(CPUBLAS_BRGEMM_F8F8F32) && defined(CPU_CAPABILITY_AVX512) +// CPUBLAS_BRGEMM_F8F8F32 is defined if FP8 BRGEMM is supported in PyTorch CPUBlas. + if (at::native::cpublas::could_pack(dtype)) { + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); + fp8_sdpa_fused_kernel(output, query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale); + return output.transpose(1, 2); + } else { +#endif // CPU_CAPABILITY_AVX512 && CPUBLAS_BRGEMM_F8F8F32 + return fp8_sdpa_math_kernel(query, key, value, + dropout_p, is_causal, attn_mask, scale, + q_scale, k_scale, + v_scale, a_scale, + o_scale).transpose(1, 2).contiguous().transpose(1, 2); +#if defined(CPUBLAS_BRGEMM_F8F8F32) && defined(CPU_CAPABILITY_AVX512) + } +#endif // CPU_CAPABILITY_AVX512 && CPUBLAS_BRGEMM_F8F8F32 + } else { + TORCH_CHECK(false, "_qscaled_dot_product_cpu: Unsupported data type ", dtype); + } } diff --git a/torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp b/torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp new file mode 100644 index 0000000000..a83100d2ea --- /dev/null +++ b/torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp @@ -0,0 +1,182 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace torchao { + +namespace { + +#if defined(CPU_CAPABILITY_AVX512) +static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) { + __m512 o; + __m128i v = _mm_loadu_si128(reinterpret_cast(x)); + at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o); + return o; +} +#endif + +template +inline void _scaled_embedding_bag_krnl( + const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb, + const int64_t emb_dim, const index_t last_offset, const index_t *indices, + const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale, + float *result, const int64_t num_batch) { +#if defined(CPU_CAPABILITY_AVX512) + if (emb_dim % 128 == 0) { + constexpr int64_t block_dim = 128; + const int64_t num_blocks = emb_dim / block_dim; + __m512 scale_v = _mm512_set1_ps(scale); + for (int64_t b = bs_begin; b < bs_end; ++b) { + __m512 x0, x1, x2, x3, x4, x5, x6, x7; + int64_t start_idx = offsets[b]; + int64_t end_idx = ((b + 1) == num_batch && last_offset != -1) + ? last_offset + : offsets[b + 1]; + for (int64_t block_id = 0; block_id < num_blocks; block_id++) { + // load first indices + int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id; + float *block_result = result + block_dim * block_id; + x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]); + x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]); + x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]); + x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]); + x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]); + x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]); + x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]); + x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]); + for (int64_t j = start_idx + 1; j < end_idx; ++j) { + // add following idx + idx = indices[j] * emb_dim + block_dim * block_id; + x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx])); + x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16])); + x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32])); + x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48])); + x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64])); + x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80])); + x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96])); + x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112])); + } + x0 = _mm512_mul_ps(x0, scale_v); + x1 = _mm512_mul_ps(x1, scale_v); + x2 = _mm512_mul_ps(x2, scale_v); + x3 = _mm512_mul_ps(x3, scale_v); + x4 = _mm512_mul_ps(x4, scale_v); + x5 = _mm512_mul_ps(x5, scale_v); + x6 = _mm512_mul_ps(x6, scale_v); + x7 = _mm512_mul_ps(x7, scale_v); + // store + _mm512_store_ps(block_result, x0); + _mm512_store_ps(block_result + 16, x1); + _mm512_store_ps(block_result + 32, x2); + _mm512_store_ps(block_result + 48, x3); + _mm512_store_ps(block_result + 64, x4); + _mm512_store_ps(block_result + 80, x5); + _mm512_store_ps(block_result + 96, x6); + _mm512_store_ps(block_result + 112, x7); + } + result += num_emb * emb_dim; + } + return; + } +#endif + for (int64_t b = bs_begin; b < bs_end; ++b) { + int64_t start_idx = offsets[b]; + int64_t end_idx = ((b + 1) == num_batch && last_offset != -1) + ? last_offset + : offsets[b + 1]; + for (int64_t d = 0; d < emb_dim; d++) { + int64_t idx = indices[start_idx] * emb_dim; + float value = float(weight[idx + d]); + for (int64_t j = start_idx + 1; j < end_idx; ++j) { + idx = indices[j] * emb_dim; + value += float(weight[idx + d]); + } + value = value * scale; + result[d] = value; + } + result += num_emb * emb_dim; + } +} + +template +void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr, + index_t *offsets_ptr, int64_t num_batch, + int64_t emb_dim, index_t last_offset, double w_scale, + double o_scale) { + constexpr int64_t b_block = 512; + const int64_t n_b_blocks = (num_batch - 1) / b_block + 1; + w_scale /= o_scale; + const int64_t num_emb = 1; +#pragma omp parallel for collapse(2) + for (int64_t b = 0; b < n_b_blocks; ++b) { + for (int64_t n = 0; n < num_emb; ++n) { + const int64_t bs_begin = b * b_block; + const int64_t bs_end = std::min(num_batch, (b + 1) * b_block); + float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim]; + // avoid offsets not include last batch + _scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim, + last_offset, indices_ptr, offsets_ptr, w_ptr, + w_scale, r, num_batch); + } + } +} + +at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight, + const at::Tensor &indices, + const at::Tensor &offsets, + const at::Tensor &w_scales, + double o_scale, const int64_t mode, + bool include_last_offset) { + // Only support include_last_offset == True and mode == + // at::native::EmbeddingBagMode::SUM + // TODO: Support more case + TORCH_CHECK(include_last_offset, + "_scaled_embedding_bag: only suppport include_last_offset"); + TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM, + "_scaled_embedding_bag: only suppport sum mode"); + int64_t batch_size = + include_last_offset ? offsets.size(0) - 1 : offsets.size(0); + int64_t emb_dim = qweight.size(1); + + auto index_type = indices.scalar_type(); + float w_scale = w_scales.data_ptr()[0]; + + TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(), + "_scaled_embedding_bag: only accept contiguous input"); + TORCH_CHECK( + offsets.scalar_type() == index_type, + "_scaled_embedding_bag: index and offset must be of the same type"); + TORCH_CHECK(qweight.is_contiguous(), + "_scaled_embedding_bag: only accept contiguous weight"); + TORCH_CHECK(qweight.dim() == 2, + "_scaled_embedding_bag: only accept weight with dim == 2"); + TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn, + "_scaled_embedding_bag: only support e4m3fn weight") + // handle last offsets + int64_t last_offset = indices.numel(); + + at::Tensor output = + at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat)); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] { + at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr(); + index_t *indices_ptr = indices.data_ptr(); + index_t *offsets_ptr = offsets.data_ptr(); + float *output_ptr = output.data_ptr(); + _scaled_embedding_bag( + output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim, + last_offset, w_scale, o_scale); + }); + return output; +} + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::_scaled_embedding_bag", &_scaled_embedding_bag_impl); +} + +} // namespace torchao diff --git a/torchao/csrc/cpu/build_and_run_benchmarks.sh b/torchao/csrc/cpu/build_and_run_benchmarks.sh new file mode 100644 index 0000000000..964fe9e5bf --- /dev/null +++ b/torchao/csrc/cpu/build_and_run_benchmarks.sh @@ -0,0 +1,38 @@ +set -eu + +if [[ $# -ne 1 ]]; then + echo "Usage: $0 "; + exit 1; +fi + +BENCHMARK_TYPE="${1}" + +export CMAKE_OUT=cmake-out + +export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') +echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" + +# Build +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -DTORCHAO_BUILD_EXECUTORCH_OPS=OFF \ + -DTORCHAO_BUILD_CPU_AARCH64=ON \ + -DTORCHAO_ENABLE_ARM_NEON_DOT=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DTORCHAO_BUILD_TESTS=OFF \ + -DTORCHAO_BUILD_BENCHMARKS=ON \ + -DOpenMP_ROOT=$(brew --prefix libomp) \ + -S . \ + -B ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} -j 16 --config Release + + +# Run +TARGET_PREFIX="${CMAKE_OUT}/torch_free_kernels/aarch64/benchmarks/torchao_benchmarks_torch_free_kernels_aarch64_" +case "${BENCHMARK_TYPE}" in + build_only) echo "Build only"; exit 0; ;; + quantization) ${TARGET_PREFIX}benchmark_quantization; ;; + bitpacking) ${TARGET_PREFIX}benchmark_bitpacking; ;; + linear) ${TARGET_PREFIX}benchmark_linear; ;; + *) echo "Unknown benchmark: $1. Please specify quantization, bitpacking, or linear."; exit 1; ;; +esac diff --git a/torchao/csrc/cpu/build_and_run_tests.sh b/torchao/csrc/cpu/build_and_run_tests.sh new file mode 100644 index 0000000000..6d92a81d98 --- /dev/null +++ b/torchao/csrc/cpu/build_and_run_tests.sh @@ -0,0 +1,87 @@ +#!/bin/bash -eu +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +set -eu + + +target=${1:-"native"} +export CMAKE_OUT=cmake-out + +EXTRA_ARGS="" +if [[ "${target}" == "android" ]]; then + if [[ -z ${ANDROID_NDK} ]]; then + echo "Need to set ANDROID_NDK env variable to build for Android"; + exit 1; + fi + android_abi=arm64-v8a + android_platform=28 # must be >=28 for aligned_alloc + IS_ARM64=1 + BUILD_ARM_I8MM=1 # Hardcoded for now + CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} + toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" + if [[ -z ${toolchain_file} ]]; then + echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" + exit 1; + fi + EXTRA_ARGS="\ + -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ + -DANDROID_ABI=${android_abi} \ + -DANDROID_PLATFORM=${android_platform} + " + echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" +fi + + + + +export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') +echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" + + +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -DTORCHAO_BUILD_EXECUTORCH_OPS=OFF \ + -DTORCHAO_BUILD_CPU_AARCH64=ON \ + -DTORCHAO_ENABLE_ARM_NEON_DOT=ON \ + -DTORCHAO_BUILD_KLEIDIAI=ON \ + -DCMAKE_BUILD_TYPE=Debug \ + -DTORCHAO_BUILD_TESTS=ON \ + -S . \ + -B ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} -j 16 --config Debug + + + +echo "Successfully built tests." + +if [[ "${target}" != "native" ]]; then + echo "Skip running tests when cross compiling."; + exit 0; +fi + +# Torch-free aarch64 +TEST_TARGET_PREFIX="${CMAKE_OUT}/torch_free_kernels/aarch64/tests/torchao_tests_torch_free_kernels_aarch64_" +${TEST_TARGET_PREFIX}test_quantization +${TEST_TARGET_PREFIX}test_reduction +${TEST_TARGET_PREFIX}test_reduction +${TEST_TARGET_PREFIX}test_bitpacking +${TEST_TARGET_PREFIX}test_linear +${TEST_TARGET_PREFIX}test_embedding +${TEST_TARGET_PREFIX}test_weight_packing +${TEST_TARGET_PREFIX}test_qmatmul +${TEST_TARGET_PREFIX}test_lut +${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility +${TEST_TARGET_PREFIX}test_embedding_lut + +# Torch-free fallback +TEST_TARGET_PREFIX="${CMAKE_OUT}/torch_free_kernels/fallback/tests/torchao_tests_torch_free_kernels_fallback_" +${TEST_TARGET_PREFIX}test_bitpacking + +# Shared kernels +TEST_TARGET_PREFIX="${CMAKE_OUT}/shared_kernels/tests/torchao_tests_shared_kernels_" +${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight +${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut diff --git a/torchao/experimental/build_torchao_ops.sh b/torchao/csrc/cpu/build_shared_kernels.sh similarity index 93% rename from torchao/experimental/build_torchao_ops.sh rename to torchao/csrc/cpu/build_shared_kernels.sh index 1bcc1a9658..bfa9a55eef 100644 --- a/torchao/experimental/build_torchao_ops.sh +++ b/torchao/csrc/cpu/build_shared_kernels.sh @@ -23,6 +23,8 @@ cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ -DTORCHAO_BUILD_EXECUTORCH_OPS="${TORCHAO_BUILD_EXECUTORCH_OPS}" \ -DTORCHAO_BUILD_CPU_AARCH64=ON \ -DTORCHAO_ENABLE_ARM_NEON_DOT=ON \ + -DTORCHAO_BUILD_TESTS=OFF \ + -DTORCHAO_BUILD_BENCHMARKS=OFF \ -S . \ -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} -j 16 --target install --config Release diff --git a/torchao/csrc/cpu/shared_kernels/README.md b/torchao/csrc/cpu/shared_kernels/README.md new file mode 100644 index 0000000000..37b4be6c7c --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/README.md @@ -0,0 +1,5 @@ +# Shared kernels + +This directory is for kernels that are shared between PyTorch/ATen and Executorch. +Shared kernels are written with abstractions in internal/library.h. +These are compiled to either an ATen or ExecuTorch kernel based on compile flags. diff --git a/torchao/experimental/Utils.cmake b/torchao/csrc/cpu/shared_kernels/Utils.cmake similarity index 97% rename from torchao/experimental/Utils.cmake rename to torchao/csrc/cpu/shared_kernels/Utils.cmake index 984c90006b..be70047844 100644 --- a/torchao/experimental/Utils.cmake +++ b/torchao/csrc/cpu/shared_kernels/Utils.cmake @@ -28,7 +28,7 @@ function(target_link_torchao_parallel_backend target_name torchao_parallel_backe message(STATUS "EXECUTORCH_INCLUDE_DIRS: ${EXECUTORCH_INCLUDE_DIRS}") message(STATUS "EXECUTORCH_LIBRARIES: ${EXECUTORCH_LIBRARIES}") target_include_directories(${target_name} PRIVATE "${EXECUTORCH_INCLUDE_DIRS}") - target_link_libraries(${target_name} PRIVATE "${EXECUTORCH_LIBRARIES}") + target_link_libraries(${target_name} PRIVATE executorch_core) target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_EXECUTORCH=1) elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "OPENMP") diff --git a/torchao/csrc/cpu/shared_kernels/benchmarks/CMakeLists.txt b/torchao/csrc/cpu/shared_kernels/benchmarks/CMakeLists.txt new file mode 100644 index 0000000000..b5fd251a1f --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/benchmarks/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_benchmarks) +set(CMAKE_BUILD_TYPE Release) + +set(TARGET_PREFIX "torchao_benchmarks_shared_kernels_") + + +# TODO: fix benchmark. Got broken from refactor + +# add_executable(${TARGET_PREFIX}benchmark_linear_8bit_act_xbit_weight +# benchmark_linear_8bit_act_xbit_weight.cpp +# ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +# ) + +# target_link_torchao_parallel_backend(${TARGET_PREFIX}benchmark_linear_8bit_act_xbit_weight openmp) +# target_link_libraries( +# ${TARGET_PREFIX}benchmark_linear_8bit_act_xbit_weight +# PRIVATE +# benchmark::benchmark +# torchao_kernels_aarch64 +# ) diff --git a/torchao/experimental/ops/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp b/torchao/csrc/cpu/shared_kernels/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp similarity index 92% rename from torchao/experimental/ops/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp rename to torchao/csrc/cpu/shared_kernels/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp index 2efd425175..caf03acf21 100644 --- a/torchao/experimental/ops/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp +++ b/torchao/csrc/cpu/shared_kernels/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp @@ -5,11 +5,11 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include using namespace torchao::ops::linear_8bit_act_xbit_weight; diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit-impl.h similarity index 87% rename from torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit-impl.h index 8113a0566b..6c1181873b 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit-impl.h @@ -7,14 +7,14 @@ #pragma once #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // TORCHAO_BUILD_CPU_AARCH64 -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include template void check_embedding_inputs( @@ -27,11 +27,11 @@ void check_embedding_inputs( int& group_size) { TORCHAO_CHECK( packed_weight_qvals.dim() == 1, "packed_weight_qvals must be 1D"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weight_qvals.dtype() == torch::kInt8, "packed_weight_qvals must be byte"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( (embedding_dim * weight_nbit) % 8 == 0, "embedding_dim * weight_nbit must be a multiple of 8"); @@ -53,11 +53,11 @@ void check_embedding_inputs( /*max_value_chunk_size=*/128), "packed_weights are not compatible with the kernel"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( weight_scales.dtype() == torch::kFloat32, "weight_scales must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(weight_scales.dim() == 2, "weight_scales must be 2D"); TORCHAO_CHECK( weight_scales.size(0) == num_embeddings, @@ -71,10 +71,10 @@ void check_embedding_inputs( group_size = embedding_dim / num_groups; TORCHAO_CHECK(group_size % 32 == 0, "group_size must be a multiple of 32"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(weight_zeros.dim() == 2, "weight_zeros must be 2D"); TORCHAO_CHECK( weight_zeros.size(0) == weight_scales.size(0) && @@ -88,7 +88,7 @@ void check_embedding_inputs( "indices must be int32 or int64"); } -#if defined(USE_ATEN) || defined(USE_EXECUTORCH) +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) template Tensor embedding_out_cpu( const Tensor& packed_weight_qvals, @@ -149,9 +149,9 @@ Tensor embedding_out_cpu( return out; } -#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH) +#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor embedding_cpu( const Tensor& packed_weight_qvals, @@ -171,9 +171,9 @@ Tensor embedding_cpu( output_tensor); return output_tensor; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_embedding_cpu(const Tensor& weight_qvals) { TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D"); @@ -213,9 +213,9 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) { return out; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_embedding_meta(const Tensor& weight_qvals) { TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D"); @@ -229,9 +229,9 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) { torchao::ops::PackedWeightsHeader::size() + (num_embeddings * packed_embedding_dim), options); } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#if defined(USE_ATEN) || defined(USE_EXECUTORCH) +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) template Tensor shared_embedding_out_cpu( const Tensor& packed_weights, @@ -242,10 +242,10 @@ Tensor shared_embedding_out_cpu( Tensor& out) { // Check packed_weights are from linear op TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), "packed_weights is not big enough to read the header."); @@ -308,7 +308,7 @@ Tensor shared_embedding_out_cpu( return out; } -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor shared_embedding_cpu( const Tensor& packed_weights, @@ -321,6 +321,6 @@ Tensor shared_embedding_cpu( packed_weights, group_size, n, k, indices, output_tensor); return output_tensor; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH) +#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp similarity index 98% rename from torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp index 318e648977..7129cd61c3 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#include +#include #define DEFINE_OP(weight_nbit) \ m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \ diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp similarity index 96% rename from torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp index 2ffcba7e6b..0227f23327 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#include +#include #define DEFINE_OP(weight_nbit) \ Tensor _op_out_##weight_nbit( \ diff --git a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h b/torchao/csrc/cpu/shared_kernels/embedding_xbit/packed_weights_header.h similarity index 85% rename from torchao/experimental/ops/embedding_xbit/packed_weights_header.h rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/packed_weights_header.h index 8e47c2d1c0..addcd4181e 100644 --- a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/packed_weights_header.h @@ -5,8 +5,8 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include -#include +#include +#include namespace torchao::ops::embedding_xbit { diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp new file mode 100644 index 0000000000..d6ffbc79e1 --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include +#include +#include +#include +#include +#include + +namespace torchao::ops::groupwise_lowbit_weight_lut { + +void pack_weights_operator( + const UKernelConfig& uk, + // Outputs + void* packed_weights_ptr, + // Inputs + int n, + int k, + int scale_group_size, + int lut_group_size, + const uint8_t* weight_qval_indices, + const float* weight_scales, + const float* weight_luts, + const float* bias) { + if (uk.has_scales) { + TORCHAO_CHECK( + lut_group_size % scale_group_size == 0, + "scale_group_size must devide lut_group_size"); + TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k"); + } + TORCHAO_CHECK( + lut_group_size % (k * uk.nr) == 0, + "lut_group_size must be a multiple of k*nr"); + TORCHAO_CHECK(k % uk.kr == 0, "kr must divide k"); + + // 1. Define the block size for parallel work. + int n_step = uk.n_step; + int nc = std::min(n, n_step); + const int num_nc_panels = (n + nc - 1) / nc; + + torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { + const int n_idx = idx * nc; + const int nc_tile_size = std::min(nc, n - n_idx); + + auto packed_weights_offset = uk.packed_weights_offset( + n_idx, + k, + uk.weight_nbit, + scale_group_size, + uk.has_scales, + uk.has_bias, + uk.nr, + uk.kr, + uk.sr); + + // Calculate offsets for all input pointers + int weight_qval_indices_offset = n_idx * k; + // Scales are packed in groups of nr + int scales_offset = weight_qval_indices_offset / scale_group_size; + int luts_offset = + (weight_qval_indices_offset / lut_group_size) * (1 << uk.weight_nbit); + + // 2. Call pack_weights with chunk arguments + uk.pack_weights( + static_cast(packed_weights_ptr) + packed_weights_offset, + weight_qval_indices + weight_qval_indices_offset, + uk.has_scales ? weight_scales + scales_offset : nullptr, + weight_luts + luts_offset, + nc_tile_size, + k, + scale_group_size, + lut_group_size, + uk.has_scales, + uk.has_bias, + uk.has_bias ? bias + n_idx : nullptr, + uk.nr, + uk.kr, + uk.sr); + }); +} + +GroupwiseTilingParams GroupwiseTilingParams::from_target_tiles_per_thread( + int m, + int m_step, + int n, + int n_step, + int target_tiles_per_thread) { + TORCHAO_CHECK(m >= 1, "m must be >= 1"); + TORCHAO_CHECK(m_step >= 1, "m_step must be >= 1"); + + TORCHAO_CHECK(n >= 1, "n must be >= 1"); + TORCHAO_CHECK(n_step >= 1, "n_step must be >= 1"); + TORCHAO_CHECK( + target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1"); + auto num_threads = torchao::get_num_threads(); + TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); + + int mc = m_step; + int num_mc_panels = (m + mc - 1) / mc; + + int numerator = n * num_mc_panels; + int denominator = num_threads * target_tiles_per_thread; + + // Set nc = ceil(numerator / denominator) + int nc = (numerator + denominator - 1) / denominator; + assert(nc >= 1); + + // Replace nc with next number n_step divides + nc = ((nc + n_step - 1) / n_step) * n_step; + + // Clamp mc, nc to be no larger than m, n + mc = std::min(m, mc); + nc = std::min(n, nc); + + assert((mc == m) || (mc % m_step == 0)); + assert((nc == n) || (nc % n_step == 0)); + + GroupwiseTilingParams tiling_params; + tiling_params.mc = mc; + tiling_params.nc = nc; + return tiling_params; +} + +void groupwise_lowbit_weight_lut_parallel_operator( + const UKernelConfig& uk, + const std::optional& tiling_params, + float* output, + int m, + int n, + int k, + int scale_group_size, + int lut_group_size, + const void* packed_weights, + const float* activations, + bool has_clamp, + float clamp_min, + float clamp_max) { + if (uk.has_scales) { + TORCHAO_CHECK( + lut_group_size % scale_group_size == 0, + "scale_group_size must divide lut_group_size"); + TORCHAO_CHECK(k % scale_group_size == 0, "scale_group_size must divide k"); + TORCHAO_CHECK( + scale_group_size % uk.kr == 0, "kr must divide scale_group_size"); + } + + TORCHAO_CHECK( + lut_group_size % (k * uk.nr) == 0, "(k * nr) must divide lut_group_size"); + int config_idx = uk.select_config_idx(m); + auto& kernel_config = uk.configs[config_idx]; + int n_step = uk.n_step; + int m_step = kernel_config.m_step; + + int mc, nc; + if (tiling_params.has_value()) { + mc = tiling_params->mc; + nc = tiling_params->nc; + } else { + // If no params are provided, calculate them to balance the workload. + auto params = GroupwiseTilingParams::from_target_tiles_per_thread( + m_step, m_step, n, n_step, /*target_tiles_per_thread=*/5); + mc = params.mc; + nc = params.nc; + } + TORCHAO_CHECK(mc >= 1, "mc must be >= 1"); + TORCHAO_CHECK(nc >= 1, "nc must be >= 1"); + TORCHAO_CHECK( + (mc == m) || (mc % m_step == 0), + "mc from tiling_params must be m or a multiple of m_step"); + TORCHAO_CHECK( + (nc == n) || (nc % n_step == 0), + "nc from tiling_params must be n or a multiple of n_step"); + + const int num_mc_tiles = (m + mc - 1) / mc; + const int num_nc_tiles = (n + nc - 1) / nc; + + const size_t packed_activations_size = kernel_config.packed_activations_size( + mc, k, kernel_config.mr, uk.kr, uk.sr); + auto packed_activations = torchao::make_aligned_byte_ptr( + uk.preferred_alignment, packed_activations_size); + + // Outer loop over M blocks + for (int mc_tile_idx = 0; mc_tile_idx < num_mc_tiles; ++mc_tile_idx) { + const int mc_tile_start = mc_tile_idx * mc; + const int mc_tile_size = std::min(mc, m - mc_tile_start); + const float* activation_row_ptr = activations + mc_tile_start * k; + + kernel_config.pack_activations( + (float*)packed_activations.get(), + mc_tile_size, + k, + activation_row_ptr, + uk.nr, + uk.kr, + uk.sr); + + // Parallelize the work over the larger NC-tiles + torchao::parallel_1d(0, num_nc_tiles, [&](int64_t n_tile_idx) { + const int nc_tile_start = n_tile_idx * nc; + const int nc_tile_size = std::min(nc, n - nc_tile_start); + float* output_tile_ptr = output + mc_tile_start * n + nc_tile_start; + + const size_t packed_weights_offset = uk.packed_weights_offset( + nc_tile_start, + k, + uk.weight_nbit, + scale_group_size, + uk.has_scales, + uk.has_bias, + uk.nr, + uk.kr, + uk.sr); + const void* packed_weights_for_tile = + static_cast(packed_weights) + packed_weights_offset; + + kernel_config.kernel( + output_tile_ptr, + /*output_m_stride=*/n, + /*m=*/mc_tile_size, + /*n=*/nc_tile_size, + k, + scale_group_size, + lut_group_size, + packed_weights_for_tile, + packed_activations.get(), + clamp_min, + clamp_max, + uk.has_bias, + has_clamp); + }); + } +} + +} // namespace torchao::ops::groupwise_lowbit_weight_lut diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h new file mode 100644 index 0000000000..bb5624033b --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h @@ -0,0 +1,126 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include + +namespace torchao::ops::groupwise_lowbit_weight_lut { + +/** + * @brief Orchestrates the packing of quantized weights into a kernel-specific + * memory layout. + * + * @details This function acts as a high-level operator that parallelizes the + * weight packing process across the N dimension. It partitions the work into + * tiles, calculates the correct memory offsets for each tile's source and + * destination pointers, and then invokes the low-level `pack_weights` function + * provided by the kernel configuration (`uk`). + * + * @param uk The kernel configuration, providing layout details, function + * pointers, and dimension constraints (nr, kr). + * @param packed_weights_ptr [out] The destination buffer for the packed weight + * data. + * @param n The N dimension of the weight matrix (e.g., output channels). + * @param k The K dimension of the weight matrix (e.g., input channels). + * @param scale_group_size The group size for weight quantization scales. + * @param lut_group_size The group size for weight lookup tables (LUTs). + * @param weight_qval_indices [in] Pointer to the raw quantized weight indices. + * @param weight_scales [in] Pointer to the raw weight quantization scales. + * @param weight_luts [in] Pointer to the raw weight lookup tables. + * @param bias [in] Pointer to the raw bias values; can be nullptr if the kernel + * configuration indicates no bias is used. + */ +void pack_weights_operator( + const UKernelConfig& uk, + // Outputs + void* packed_weights_ptr, + // Inputs + int n, + int k, + int scale_group_size, + int lut_group_size, + const uint8_t* weight_qval_indices, + const float* weight_scales, + const float* weight_luts, + const float* bias); + +struct GroupwiseTilingParams { + int mc; + int nc; + + /** + * @brief Calculates groupwise tiling parameters based on a target number of + * tiles per thread. + * + * @details This function implements a heuristic to determine optimal tile + * sizes (`mc`, `nc`) for balancing a computational workload across multiple + * threads. It calculates the number of tiles needed to cover the M dimension + * and uses this, along with the target number of tiles per thread, to derive + * a suitable tile count in the N dimension. This count is then scaled by + * `n_step` to get the final `nc` value. The resulting tile sizes are clamped + * to not exceed the original problem dimensions. + * + * @param m The total size of the M dimension (e.g., rows). + * @param m_step The required step size for tiling in the M dimension. + * @param n The total size of the N dimension (e.g., columns). + * @param n_step The required step size for tiling in the N dimension. + * @param target_tiles_per_thread A tuning parameter that suggests how many + * tiles each thread should ideally process, influencing the calculated tile + * sizes. + * @return A `GroupwiseTilingParams` struct containing the computed `mc` and + * `nc`. + */ + static GroupwiseTilingParams from_target_tiles_per_thread( + int m, + int m_step, + int n, + int n_step, + int target_tiles_per_thread); +}; + +/** + * @brief Executes a parallel linear operation using a groupwise low-bit LUT + * kernel. + * + * @details This function acts as a high-level operator for performing a linear + * operation (GEMM-like) with quantized weights. + * + * @param uk The kernel configuration, providing layout details and function + * pointers. + * @param tiling_params [in] Optional. User-provided tiling parameters (mc, nc). + * If not provided, the operator will calculate them dynamically. + * @param output [out] The destination buffer for the output matrix. + * @param m The M dimension of the output matrix (e.g., rows). + * @param n The N dimension of the output matrix (e.g., columns). + * @param k The K dimension, shared between the weights and activations. + * @param scale_group_size The group size for weight quantization scales. + * @param lut_group_size The group size for weight lookup tables (LUTs). + * @param packed_weights [in] Pointer to the pre-packed weight data. + * @param activations [in] Pointer to the raw activation data. + * @param has_clamp A boolean flag indicating whether to apply clamping to the + * output. + * @param clamp_min The minimum value for output clamping. + * @param clamp_max The maximum value for output clamping. + */ +void groupwise_lowbit_weight_lut_parallel_operator( + const UKernelConfig& uk, + const std::optional& tiling_params, + // Outputs + float* output, + // Inputs + int m, + int n, + int k, + int scale_group_size, + int lut_group_size, + const void* packed_weights, + const float* activations, + bool has_clamp, + float clamp_min, + float clamp_max); +} // namespace torchao::ops::groupwise_lowbit_weight_lut diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_config.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_config.h new file mode 100644 index 0000000000..1110e740e2 --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_config.h @@ -0,0 +1,234 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include + +namespace torchao::ops::groupwise_lowbit_weight_lut { + +constexpr int kMaxConfigs = 4; + +/** + * @brief Defines the configuration for a Universal Kernel (UKernel) for the + * groupwise low-bit LUT-based kernel. + */ +struct UKernelConfig { + // Calculates the required size for the packed activation. + using packed_activations_size_fn_type = + size_t (*)(int m, int k, int mr, int kr, int sr); + + // Calculates the required size for the packed weights buffer. + using packed_weights_size_fn_type = size_t (*)( + int n, + int k, + int weight_nbit, + int scale_group_size, + bool has_scales, + bool has_bias, + int nr, + int kr, + int sr); + + // Packs activations into a kernel-friendly layout. + using pack_activations_fn_type = void (*)( + float* packed_activations, + int m, + int k, + const float* activations, + int mr, + int kr, + int sr); + + // Packs weights, scales, and LUTs into the target buffer. + using pack_weights_fn_type = void (*)( + void* packed_weights_ptr, + const uint8_t* weight_qvals_indices, + const float* weight_scales, + const float* weight_luts, + int n, + int k, + int scale_group_size, + int lut_group_size, + bool has_scales, + bool has_bias, + const float* bias, + int nr, + int kr, + int sr); + + // Offset in packed_activation buffer for multithread. + using packed_activations_offset_fn_type = + size_t (*)(int m_idx, int k, int mr, int kr, int sr); + + // Offset in packed_weight buffer for multithread. + using packed_weights_offset_fn_type = size_t (*)( + int n_idx, + int k, + int weight_nbit, + int scale_group_size, + bool has_scales, + bool has_bias, + int nr, + int kr, + int sr); + + // The main computation kernel. + using kernel_fn_type = void (*)( + float* output, + int output_m_stride, + int m, + int n, + int k, + int scale_group_size, + int lut_group_size, + const void* packed_weights, + const void* packed_activations, + float clamp_min, + float clamp_max, + bool has_bias, + bool has_clamp); + + // Configuration for a single kernel. + struct config_type { + int m_step{0}; + int mr{0}; + packed_activations_size_fn_type packed_activations_size{nullptr}; + packed_activations_offset_fn_type packed_activations_offset{nullptr}; + pack_activations_fn_type pack_activations{nullptr}; + kernel_fn_type kernel{nullptr}; + }; + + // Preferred memory alignment for buffers. + size_t preferred_alignment{0}; + int n_step{0}; + int nr{0}; + int kr{0}; + int sr{0}; + int weight_nbit{0}; + bool has_scales{false}; + bool has_bias{false}; + + packed_weights_size_fn_type packed_weights_size{nullptr}; + packed_weights_offset_fn_type packed_weights_offset{nullptr}; + pack_weights_fn_type pack_weights{nullptr}; + + std::array configs; + + static UKernelConfig make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_scales, + bool has_bias, + packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_offset, + pack_weights_fn_type pack_weights, + std::array configs); + + // Validation function to ensure all pointers are properly initialized. + inline void validate() const { + // 1. Validate Top-Level UKernelConfig Parameters + TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1"); + TORCHAO_CHECK(nr >= 1, "nr must be >= 1"); + TORCHAO_CHECK(kr >= 1, "kr must be >= 1"); + TORCHAO_CHECK(sr >= 1, "sr must be >= 1"); + TORCHAO_CHECK(weight_nbit >= 1, "weight_nbit must be >= 1"); + TORCHAO_CHECK(weight_nbit <= 4, "weight_nbit must be <= 4"); + TORCHAO_CHECK( + packed_weights_size != nullptr, + "packed_weights_size_fn_type must be set"); + TORCHAO_CHECK( + packed_weights_offset != nullptr, + "packed_weights_offset_fn_type must be set"); + TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set"); + // 2. Validate the Array of Linear Configurations + // At least one configuration must be defined. + TORCHAO_CHECK( + !configs.empty(), + "At least one valid kernel configuration must be provided."); + + bool configs_set = true; // first linear config must be set + for (size_t i = 0; i < configs.size(); ++i) { + if (configs_set) { + const auto& config = configs[i]; + + TORCHAO_CHECK( + config.packed_activations_size != nullptr, + "config.packed_activations_size must be set"); + TORCHAO_CHECK( + config.pack_activations != nullptr, + "config.pack_activations must be set"); + TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set"); + + if (i > 0) { + const auto& prev_config = configs[i - 1]; + TORCHAO_CHECK( + prev_config.m_step > 0, + "There cannot be a gap in configurations (m_step=0 followed by m_step>0)"); + TORCHAO_CHECK( + prev_config.m_step < config.m_step, + "m_step values in configs must be strictly increasing."); + } + if (i + 1 < configs.size()) { + configs_set = (configs[i + 1].m_step >= 1); + } + } + } + } + + // Selects the appropriate configuration based on m. + inline int select_config_idx(int m) const { + assert(m >= 1); + assert(configs[0].m_step >= 1); + + size_t i = 0; + while (i + 1 < configs.size() && configs[i + 1].m_step >= 1 && + configs[i + 1].m_step <= m) { + assert(configs[i].m_step < configs[i + 1].m_step); + i++; + } + + assert(i < configs.size()); + assert(configs[i].m_step >= 1); + assert(i == 0 || configs[i].m_step <= m); + return static_cast(i); + } +}; + +inline UKernelConfig UKernelConfig::make( + size_t preferred_alignment, + int n_step, + int nr, + int kr, + int sr, + int weight_nbit, + bool has_scales, + bool has_bias, + packed_weights_size_fn_type packed_weights_size, + packed_weights_offset_fn_type packed_weights_with_lut_offset, + pack_weights_fn_type pack_weights, + std::array configs) { + return UKernelConfig{ + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_scales, + has_bias, + packed_weights_size, + packed_weights_with_lut_offset, + pack_weights, + std::move(configs)}; +} +} // namespace torchao::ops::groupwise_lowbit_weight_lut diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_selector.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_selector.h new file mode 100644 index 0000000000..f8bdc4cafb --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_selector.h @@ -0,0 +1,246 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include +#include +#include +#include + +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#include +#endif // TORCHAO_BUILD_CPU_AARCH64 + +namespace torchao::ops::groupwise_lowbit_weight_lut { + +/** + * @brief A thread-unsafe registration table for kernel configurations. + * + * This table maps a combination of a weight format (header) and a CPU + * microarchitecture to a specific UKernelConfig. + */ +struct UKernelConfigRegistrationTable { + private: + using Key = std::pair; + struct KeyHasher { + std::size_t operator()(const Key& k) const { + return std::hash()(k.first) ^ + std::hash()(static_cast(k.second)); + } + }; + std::unordered_map registration_table_; + inline Key make_key( + torchao::ops::PackedWeightsHeader header, + cpuinfo_uarch uarch) const { + return std::make_pair(header, uarch); + } + + public: + // resgist a kernel config for a given format and uarch. + void register_ukernel_config( + PackedWeightsFormat format, + cpuinfo_uarch uarch, + UKernelConfig config) { + auto header = format.to_packed_weights_header(); + auto key = make_key(header, uarch); + if (registration_table_.find(key) != registration_table_.end()) { + throw std::runtime_error( + "UKernelConfig is already registered for this format"); + } + config.validate(); + registration_table_[key] = config; + } + // get the kernel config for a given format and uarch. + std::optional get_ukernel_config( + torchao::ops::PackedWeightsHeader header, + cpuinfo_uarch uarch) const { + auto key = make_key(header, uarch); + auto it = registration_table_.find(key); + if (it == registration_table_.end()) { + return std::nullopt; + } + return it->second; + } +}; + +void log_registration(PackedWeightsFormat format, std::string description) { + // Logging is only supported in ATen mode +#ifdef USE_ATEN + LOG(INFO) << "Registering ukernel config for groupwise_lowbit_weight_lut" + << std::endl + << "\tDescription: " << description << std::endl + << "\tformat.type=" << static_cast(format.type) << std::endl + << "\tformat.weight_nbit=" << format.weight_nbit << std::endl + << "\tformat.has_bias=" << format.has_bias << std::endl + << "\tformat.has_scales=" << format.has_scales << std::endl + << "\tformat.lut_group_size=" << format.lut_group_size << std::endl + << "\tformat.scale_group_size=" << format.scale_group_size + << "\tformat.nr=" << format.nr << std::endl + << "\tformat.kr=" << format.kr << std::endl + << "\tformat.sr=" << format.sr << std::endl + << std::endl; +#endif // USE_ATEN +} + +#if defined(TORCHAO_BUILD_CPU_AARCH64) +/** + * @brief Registers all available AArch64 kernels for a given format. + * + * @tparam weight_nbit The bit-width of the weights. + * @tparam has_scales Whether the packed buffer contains scale factors. + * @param table The registration table to add the kernel config to. + * @param format The format header describing the weights. + * @param uarch The target CPU microarchitecture. + */ +template +void register_ukernel_config( + UKernelConfigRegistrationTable& table, + PackedWeightsFormat format, + cpuinfo_uarch uarch) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + if (!cpuinfo_has_arm_v8()) { + // This CPU doesn't support the kernel, so do nothing. + return; + } + + check_format( + format, + torchao::ops::PackedWeightsType::groupwise_lowbit_weight_lut, + weight_nbit); + int preferred_alignment = 16; + + namespace kernel_api = + torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut; + + using kernel_fn_ptr_t = + decltype(&kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32< + weight_nbit, + true>); + kernel_fn_ptr_t kernel_dispatcher; + + if (format.has_scales) { + kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32< + weight_nbit, + /*has_scales=*/true>; + } else { + kernel_dispatcher = &kernel_api::groupwise_lowbit_weight_lut_kernel_1x4x32< + weight_nbit, + /*has_scales=*/false>; + } + if (format.nr == 4 && format.kr == 32 && format.sr == 8) { + log_registration(format, "lut: groupwise_lowbit_weight_lut_kernel_1x4x32"); + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr int mr = 1; + constexpr int m_step = 1; + constexpr int n_step = 4; + + auto uk = UKernelConfig::make( + /*preferred_alignment=*/preferred_alignment, + /*n_step=*/n_step, + /*nr=*/format.nr, + /*kr=*/format.kr, + /*sr=*/format.sr, + /*weight_nbit=*/format.weight_nbit, + /*has_scales=*/format.has_scales, + /*has_bias=*/format.has_bias, + /*packed_weights_size_fn_type=*/ + &kernel_api::packed_weights_size, + /*packed_weights_offset_fn_type=*/ + &kernel_api::packed_weights_offset, + /*pack_weights_fn_type=*/ + &kernel_api:: + pack_weights, + /*configs=*/{}); + + uk.configs[0] = UKernelConfig::config_type + {m_step, + mr, + &kernel_api::packed_activations_size, + &kernel_api::packed_activations_offset, + &kernel_api::pack_activations, + kernel_dispatcher}; + + // Resgister the kernel config. + table.register_ukernel_config(format, uarch, std::move(uk)); + return; + } +} +#endif // TORCHAO_BUILD_CPU_AARCH64 + +/** + * @brief Selects the best UKernelConfig for the given format header. + * + * This function is the main entry point for the op. It manages a static + * registration table and, if a kernel is not already registered for the + * current CPU, it will perform the registration. + * + * @tparam weight_nbit The bit-width of the weights. + * @param header A header describing the packed weight format. + * @return The appropriate UKernelConfig for the current environment. + */ +template +UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) { +#if defined(TORCHAO_BUILD_CPU_AARCH64) + // Static table ensures we only register kernels once per session. + static UKernelConfigRegistrationTable table; + + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + + auto uarch = cpuinfo_uarch_unknown; + + auto ukernel = table.get_ukernel_config(header, uarch); + if (ukernel.has_value()) { + return ukernel.value(); + } + + // Create a new format object from the header. + auto format = PackedWeightsFormat::from_packed_weights_header(header); + + register_ukernel_config(table, format, uarch); + + ukernel = table.get_ukernel_config(header, uarch); + assert( + ukernel.has_value() && + "Kernel registration failed for the current CPU microarchitecture."); + return ukernel.value(); +#else + throw std::runtime_error( + "select_ukernel_config for groupwise_lowbit_weight_lut is only supported " + "when TORCHAO_BUILD_CPU_AARCH64 is defined."); +#endif +} + +template +PackedWeightsFormat select_packed_weights_format( + std::optional target, + int scale_group_size, + int lut_group_size, + bool has_scales, + bool has_bias) { + if (!target) { + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::groupwise_lowbit_weight_lut, + weight_nbit, + scale_group_size, + lut_group_size, + has_scales, + has_bias, + /*nr*/ 4, + /*kr*/ 32, + /*sr*/ 8); + } + throw std::runtime_error("No packed_weights_format was selected"); +} + +} // namespace torchao::ops::groupwise_lowbit_weight_lut diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h new file mode 100644 index 0000000000..e3aca77844 --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h @@ -0,0 +1,240 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) +template +Tensor linear_out_cpu( + const Tensor& activations, + const Tensor& packed_weights, + const int64_t& scale_group_size, + const int64_t& lut_group_size, + const int64_t& n, + const int64_t& k, + Tensor& out) { + TORCHAO_CHECK(n >= 1, "n must be >= 1"); + TORCHAO_CHECK(k >= 1, "k must be >= 1"); + TORCHAO_CHECK(lut_group_size >= 1, "lut_group_size must be >= 1"); + +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN + TORCHAO_CHECK( + activations.dtype() == torch::kFloat32, "activations must be float32"); +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN + + TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D"); + int m = activations.size(0); + int k_ = activations.size(1); + TORCHAO_CHECK( + k == k_, "activation shape is incompatible with packed weights."); + +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN + TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32"); +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN + + // Explicit cast from int64_t to int is required for Executorch + TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n}); + + TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN + TORCHAO_CHECK( + packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN + TORCHAO_CHECK( + packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), + "packed_weights is not big enough to read the header."); + auto header = + torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); + + auto uk = torchao::ops::groupwise_lowbit_weight_lut::select_ukernel_config< + weight_nbit>(header); + + torchao::ops::groupwise_lowbit_weight_lut:: + groupwise_lowbit_weight_lut_parallel_operator( + uk, + std::nullopt, + out.mutable_data_ptr(), + m, + n, + k, + scale_group_size, + lut_group_size, + packed_weights.const_data_ptr() + + torchao::ops::PackedWeightsHeader::size(), + activations.const_data_ptr(), + /*has_clamp=*/false, + /*clamp_min=*/0.0, + /*clamp_max=*/0.0); + + return out; +} +#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) + +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN +template +Tensor linear_cpu( + const Tensor& activations, + const Tensor& packed_weights, + const int64_t& scale_group_size, + const int64_t& lut_group_size, + const int64_t& n, + const int64_t& k) { + Tensor output_tensor = torch::empty({}, torch::kFloat32); + linear_out_cpu( + activations, + packed_weights, + scale_group_size, + lut_group_size, + n, + k, + output_tensor); + return output_tensor; +} +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN + +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN +template +Tensor pack_weights_with_lut_cpu( + const Tensor& weight_qval_idxs, + const Tensor& luts, + int64_t scale_group_size, + int64_t lut_group_size, + const std::optional& weight_scales, + const std::optional& bias, + const std::optional& target) { + bool has_scales = weight_scales.has_value(); + bool has_bias = bias.has_value(); + + TORCHAO_CHECK( + weight_qval_idxs.dtype() == torch::kUInt8, + "weight_qval_idxs must be uint8"); + TORCHAO_CHECK(weight_qval_idxs.dim() == 2, "weight_qval_idxs must be 2D"); + int n = weight_qval_idxs.size(0); + int k = weight_qval_idxs.size(1); + TORCHAO_CHECK(lut_group_size >= 1, "lut_group_size must be >= 1"); + + TORCHAO_CHECK( + luts.dtype() == torch::kFloat32, + "luts must be float32"); // Changed to kFloat32 + TORCHAO_CHECK(lut_group_size % k == 0, "the number of luts must divide k"); + + TORCHAO_CHECK( + luts.size(1) == (1 << weight_nbit), + "luts must have 1 entry per quantization level"); + const float* scales_ptr = nullptr; + + if (has_scales) { + TORCHAO_CHECK(scale_group_size >= 1, "scale_group_size must be >= 1"); + TORCHAO_CHECK( + weight_scales->dtype() == torch::kFloat32, + "weight_scales must be float32"); + TORCHAO_CHECK(weight_scales->dim() == 1, "weight_scales must be 1D"); + scales_ptr = weight_scales.value().const_data_ptr(); + } + + const float* bias_ptr = nullptr; + if (has_bias) { + TORCHAO_CHECK( + bias.value().dtype() == torch::kFloat32, "bias must be float32"); + TORCHAO_CHECK(bias.value().dim() == 1, "bias must be 1D"); + TORCHAO_CHECK(bias.value().size(0) == n, "expected 1 bias per row"); + bias_ptr = bias.value().const_data_ptr(); + } + + TORCHAO_CHECK( + !target.has_value(), "target is not currently supported in pack_weights"); + + auto packed_weights_format = + torchao::ops::groupwise_lowbit_weight_lut::select_packed_weights_format< + weight_nbit>( + target, scale_group_size, lut_group_size, has_scales, has_bias); + + auto packed_weights_header = packed_weights_format.to_packed_weights_header(); + auto uk = torchao::ops::groupwise_lowbit_weight_lut::select_ukernel_config< + weight_nbit>(packed_weights_header); + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + weight_nbit, + scale_group_size, + has_scales, + has_bias, + uk.nr, + uk.kr, + uk.sr); + + Tensor packed_weights = torch::empty( + {static_cast(packed_weight_data_size)}, torch::kInt8); + packed_weights_header.write(packed_weights.mutable_data_ptr()); + + torchao::ops::groupwise_lowbit_weight_lut::pack_weights_operator( + uk, + packed_weights.mutable_data_ptr() + + torchao::ops::PackedWeightsHeader::size(), + n, + k, + scale_group_size, + lut_group_size, + weight_qval_idxs.const_data_ptr(), + scales_ptr, + luts.const_data_ptr(), + bias_ptr); + + return packed_weights; +} +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN + +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN +template +Tensor pack_weights_with_lut_meta( + const Tensor& weight_qval_idxs, + const Tensor& luts, + int64_t scale_group_size, + int64_t lut_group_size, + const std::optional& weight_scales, + const std::optional& bias, + const std::optional& target) { + bool has_bias = bias.has_value(); + bool has_scales = weight_scales.has_value(); + int n = weight_qval_idxs.size(0); + int k = weight_qval_idxs.size(1); + auto packed_weights_format = + torchao::ops::groupwise_lowbit_weight_lut::select_packed_weights_format< + weight_nbit>( + target, scale_group_size, lut_group_size, has_scales, has_bias); + auto packed_weights_header = packed_weights_format.to_packed_weights_header(); + auto uk = torchao::ops::groupwise_lowbit_weight_lut::select_ukernel_config< + weight_nbit>(packed_weights_header); + + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + uk.packed_weights_size( + n, + k, + weight_nbit, + scale_group_size, + has_scales, + has_bias, + uk.nr, + uk.kr, + uk.sr); + + auto options = + torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); + return torch::empty({static_cast(packed_weight_data_size)}, options); +} +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN + +} // namespace diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp new file mode 100644 index 0000000000..c9b65f2152 --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#define DEFINE_PACK_OP(weight_nbit) \ + m.def( \ + "_pack_groupwise_" #weight_nbit \ + "bit_weight_with_lut(Tensor weight_qval_idxs, Tensor luts, int scale_group_size, int lut_group_size, Tensor? weight_scales, Tensor? bias, str? target) -> Tensor"); + +#define DEFINE_LINEAR_OP(weight_nbit) \ + m.def( \ + "_linear_groupwise_" #weight_nbit \ + "bit_weight_with_lut(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k) -> Tensor"); \ + m.def( \ + "_linear_groupwise_" #weight_nbit \ + "bit_weight_with_lut.out(Tensor activations, Tensor packed_weights, int scale_group_size, int lut_group_size, int n, int k, *, Tensor(a!) out) -> Tensor(a!)"); + +#define DEFINE_PACK_CPU_IMPL(weight_nbit) \ + m.impl( \ + "_pack_groupwise_" #weight_nbit "bit_weight_with_lut", \ + &pack_weights_with_lut_cpu); + +#define DEFINE_PACK_META_IMPL(weight_nbit) \ + m.impl( \ + "_pack_groupwise_" #weight_nbit "bit_weight_with_lut", \ + &pack_weights_with_lut_meta); + +#define DEFINE_LINEAR_CPU_IMPL(weight_nbit) \ + m.impl( \ + "_linear_groupwise_" #weight_nbit "bit_weight_with_lut", \ + &linear_cpu); \ + m.impl( \ + "_linear_groupwise_" #weight_nbit "bit_weight_with_lut.out", \ + &linear_out_cpu); + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + DEFINE_PACK_OP(1); + DEFINE_PACK_OP(2); + DEFINE_PACK_OP(3); + DEFINE_PACK_OP(4); + + DEFINE_LINEAR_OP(1); + DEFINE_LINEAR_OP(2); + DEFINE_LINEAR_OP(3); + DEFINE_LINEAR_OP(4); +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + DEFINE_PACK_CPU_IMPL(1); + DEFINE_PACK_CPU_IMPL(2); + DEFINE_PACK_CPU_IMPL(3); + DEFINE_PACK_CPU_IMPL(4); + + DEFINE_LINEAR_CPU_IMPL(1); + DEFINE_LINEAR_CPU_IMPL(2); + DEFINE_LINEAR_CPU_IMPL(3); + DEFINE_LINEAR_CPU_IMPL(4); +} + +TORCH_LIBRARY_IMPL(torchao, Meta, m) { + DEFINE_PACK_META_IMPL(1); + DEFINE_PACK_META_IMPL(2); + DEFINE_PACK_META_IMPL(3); + DEFINE_PACK_META_IMPL(4); +} diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp new file mode 100644 index 0000000000..d3e06dd538 --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp @@ -0,0 +1,32 @@ +#include + +#define DEFINE_OP(weight_nbit) \ + Tensor _op_out_##weight_nbit( \ + RuntimeContext& ctx, \ + const Tensor& activations, \ + const Tensor& packed_weights, \ + const int64_t& scale_group_size, \ + const int64_t& lut_group_size, \ + const int64_t& n, \ + const int64_t& k, \ + Tensor& out) { \ + (void)ctx; \ + linear_out_cpu( \ + activations, \ + packed_weights, \ + scale_group_size, \ + lut_group_size, \ + n, \ + k, \ + out); \ + return out; \ + } \ + EXECUTORCH_LIBRARY( \ + torchao, \ + "_linear_groupwise_" #weight_nbit "bit_weight_with_lut.out", \ + _op_out_##weight_nbit) + +DEFINE_OP(1); +DEFINE_OP(2); +DEFINE_OP(3); +DEFINE_OP(4); diff --git a/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/packed_weights_format.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/packed_weights_format.h new file mode 100644 index 0000000000..d7c64fbebd --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/packed_weights_format.h @@ -0,0 +1,110 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace torchao::ops::groupwise_lowbit_weight_lut { + +/** + * @brief Defines the format parameters for the packed weights of the + * groupwise LUT kernel. + */ +struct PackedWeightsFormat { + torchao::ops::PackedWeightsType type; + int weight_nbit; + int scale_group_size; + int lut_group_size; + bool has_scales; + bool has_bias; + int nr; + int kr; + int sr; + + PackedWeightsFormat( + torchao::ops::PackedWeightsType type, + int weight_nbit, + int scale_group_size, + int lut_group_size, + bool has_scales, + bool has_bias, + int nr, + int kr, + int sr) + : type{type}, + weight_nbit{weight_nbit}, + scale_group_size{scale_group_size}, + lut_group_size{lut_group_size}, + has_scales{has_scales}, + has_bias{has_bias}, + nr{nr}, + kr{kr}, + sr{sr} {} + + /** + * @brief Converts a generic PackedWeightsHeader into this specific format. + * + * This assumes the generic header's `params` array is populated in the + * correct order. + */ + static PackedWeightsFormat from_packed_weights_header( + const torchao::ops::PackedWeightsHeader& header) { + return PackedWeightsFormat( + header.type, + header.params[0], // weight_nbit + header.params[1], // scale_group_size + header.params[2], // lut_group_size + static_cast(header.params[3]), // has_scales + static_cast(header.params[4]), // has_bias + header.params[5], // nr + header.params[6], // kr + header.params[7] // sr + ); + } + + /** + * @brief Converts this specific format into a generic PackedWeightsHeader. + */ + inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const { + return torchao::ops::PackedWeightsHeader( + type, + {weight_nbit, + scale_group_size, + lut_group_size, + has_scales, + has_bias, + nr, + kr, + sr}); + } +}; + +/** + * @brief Helper function to validate that the provided format matches the + * expectations of a specific kernel. + */ +inline void check_format( + const PackedWeightsFormat& format, + torchao::ops::PackedWeightsType expected_type, + int expected_weight_nbit) { + if (format.type != expected_type) { + throw std::runtime_error( + "Kernel expects packed_weights type=" + + std::to_string(static_cast(expected_type)) + + ", but got packed_weights with type=" + + std::to_string(static_cast(format.type))); + } + if (format.weight_nbit != expected_weight_nbit) { + throw std::runtime_error( + "Kernel expects weight_nbit=" + std::to_string(expected_weight_nbit) + + ", but got packed_weights with weight_nbit=" + + std::to_string(format.weight_nbit)); + } +} + +} // namespace torchao::ops::groupwise_lowbit_weight_lut diff --git a/torchao/experimental/ops/library.h b/torchao/csrc/cpu/shared_kernels/internal/library.h similarity index 67% rename from torchao/experimental/ops/library.h rename to torchao/csrc/cpu/shared_kernels/internal/library.h index c518b31aee..204d97f5a7 100644 --- a/torchao/experimental/ops/library.h +++ b/torchao/csrc/cpu/shared_kernels/internal/library.h @@ -4,8 +4,8 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#if defined(USE_ATEN) && !defined(USE_EXECUTORCH) -#pragma message("USE_ATEN") +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) && !defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) +#pragma message("TORCHAO_SHARED_KERNELS_BUILD_ATEN") #include #include #include @@ -15,8 +15,8 @@ using Tensor = at::Tensor; #define TORCHAO_CHECK(cond, msg) TORCH_CHECK(cond, msg) #define TORCHAO_RESIZE_TENSOR(tensor, ...) tensor.resize_({__VA_ARGS__}) -#elif defined(USE_EXECUTORCH) && !defined(USE_ATEN) -#pragma message("USE_EXECUTORCH") +#elif defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) && !defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) +#pragma message("TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH") #include #include #include @@ -28,8 +28,8 @@ using RuntimeContext = torch::executor::KernelRuntimeContext; #define TORCHAO_RESIZE_TENSOR(tensor, ...) \ ET_CHECK_MSG(torch::executor::resize_tensor(tensor, {__VA_ARGS__}) == torch::executor::Error::Ok, "resize failed") -#elif !defined(USE_EXECUTORCH) && !defined(USE_ATEN) -#pragma message("Neither USE_ATEN or USE_EXECUTORCH defined") +#elif !defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) && !defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) +#pragma message("Neither TORCHAO_SHARED_KERNELS_BUILD_ATEN or TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH defined") #include #define TORCHAO_CHECK(cond, message) \ @@ -38,5 +38,5 @@ using RuntimeContext = torch::executor::KernelRuntimeContext; } #else -#error "Cannot define both USE_ATEN or USE_EXECUTORCH" +#error "Cannot define both TORCHAO_SHARED_KERNELS_BUILD_ATEN or TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH" #endif diff --git a/torchao/experimental/ops/memory.h b/torchao/csrc/cpu/shared_kernels/internal/memory.h similarity index 100% rename from torchao/experimental/ops/memory.h rename to torchao/csrc/cpu/shared_kernels/internal/memory.h diff --git a/torchao/experimental/ops/packed_weights_header.h b/torchao/csrc/cpu/shared_kernels/internal/packed_weights_header.h similarity index 98% rename from torchao/experimental/ops/packed_weights_header.h rename to torchao/csrc/cpu/shared_kernels/internal/packed_weights_header.h index 90f77beae2..c3121b6056 100644 --- a/torchao/experimental/ops/packed_weights_header.h +++ b/torchao/csrc/cpu/shared_kernels/internal/packed_weights_header.h @@ -18,6 +18,7 @@ enum class PackedWeightsType : uint32_t { embedding_xbit_universal = 2, linear_8bit_act_xbit_weight_kleidi_ai = 3, linear_8bit_act_xbit_weight_lut = 4, + groupwise_lowbit_weight_lut = 5, }; class PackedWeightsHeader { diff --git a/torchao/experimental/ops/parallel-aten-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-aten-impl.h similarity index 87% rename from torchao/experimental/ops/parallel-aten-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-aten-impl.h index c2eb0b8498..9c825e48e5 100644 --- a/torchao/experimental/ops/parallel-aten-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-aten-impl.h @@ -19,10 +19,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { }); } -inline void torchao::set_num_threads(int num_threads) { - torch::set_num_threads(num_threads); -} - inline int torchao::get_num_threads() { return torch::get_num_threads(); } diff --git a/torchao/experimental/ops/parallel-executorch-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-executorch-impl.h similarity index 80% rename from torchao/experimental/ops/parallel-executorch-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-executorch-impl.h index 233f7250d4..01c8eb766f 100644 --- a/torchao/experimental/ops/parallel-executorch-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-executorch-impl.h @@ -18,11 +18,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { end - begin); } -inline void torchao::set_num_threads(int num_threads) { - torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool( - num_threads); -} - inline int torchao::get_num_threads() { return torch::executorch::threadpool::get_threadpool()->get_thread_count(); } diff --git a/torchao/experimental/ops/parallel-openmp-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-openmp-impl.h similarity index 87% rename from torchao/experimental/ops/parallel-openmp-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-openmp-impl.h index 236bb4e25f..e9b43653d2 100644 --- a/torchao/experimental/ops/parallel-openmp-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-openmp-impl.h @@ -18,9 +18,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { } } -inline void torchao::set_num_threads(int num_threads) { - omp_set_num_threads(num_threads); -} inline int torchao::get_num_threads() { // omp_get_num_threads returns the number of threads // in the current code section, which will be 1 in the routines diff --git a/torchao/experimental/ops/parallel-pthreadpool-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-pthreadpool-impl.h similarity index 83% rename from torchao/experimental/ops/parallel-pthreadpool-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-pthreadpool-impl.h index 9906cf4f3a..704349b59d 100644 --- a/torchao/experimental/ops/parallel-pthreadpool-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-pthreadpool-impl.h @@ -33,13 +33,6 @@ class Threadpool { } return pthreadpool_get_threads_count(pthreadpool_); } - void set_num_threads(size_t num_threads) { - if (num_threads == get_num_threads()) { - return; - } - pthreadpool_destroy(pthreadpool_); - pthreadpool_ = pthreadpool_create(num_threads); - } }; template @@ -62,10 +55,6 @@ inline int torchao::get_num_threads() { return torchao::parallel::internal::threadpool.get_num_threads(); } -inline void torchao::set_num_threads(int num_threads) { - torchao::parallel::internal::threadpool.set_num_threads(num_threads); -} - template void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { auto context = torchao::parallel::internal::Context(f, begin); diff --git a/torchao/experimental/ops/parallel-single_threaded-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-single_threaded-impl.h similarity index 88% rename from torchao/experimental/ops/parallel-single_threaded-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-single_threaded-impl.h index d9706829c2..74f067e39a 100644 --- a/torchao/experimental/ops/parallel-single_threaded-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-single_threaded-impl.h @@ -13,7 +13,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { } } -inline void torchao::set_num_threads(int num_threads) {} inline int torchao::get_num_threads() { return 1; } diff --git a/torchao/experimental/ops/parallel-test_dummy-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-test_dummy-impl.h similarity index 86% rename from torchao/experimental/ops/parallel-test_dummy-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-test_dummy-impl.h index de5a5f63ad..4a82cbd504 100644 --- a/torchao/experimental/ops/parallel-test_dummy-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-test_dummy-impl.h @@ -15,9 +15,13 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { } } -inline void torchao::set_num_threads(int num_threads) { - torchao::parallel::internal::num_threads_test_dummy_ = num_threads; -} inline int torchao::get_num_threads() { return torchao::parallel::internal::num_threads_test_dummy_; } + + +namespace torchao::parallel { +inline void set_num_threads_in_test_dummy(int num_threads) { + torchao::parallel::internal::num_threads_test_dummy_ = num_threads; +} +} diff --git a/torchao/experimental/ops/parallel.h b/torchao/csrc/cpu/shared_kernels/internal/parallel.h similarity index 80% rename from torchao/experimental/ops/parallel.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel.h index 5372c5a2dd..81f98b92c7 100644 --- a/torchao/experimental/ops/parallel.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel.h @@ -12,8 +12,6 @@ namespace torchao { template void parallel_1d(const int64_t begin, const int64_t end, const F& f); -void set_num_threads(int num_threads); - int get_num_threads(); } // namespace torchao @@ -28,37 +26,37 @@ int get_num_threads(); #pragma message( \ "AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif -#include +#include #else #ifdef TORCHAO_PARALLEL_EXECUTORCH #pragma message( \ "TORCHAO_PARALLEL_EXECUTORCH is set. Using ExecuTorch parallel backend.") -#include +#include #else #ifdef TORCHAO_PARALLEL_PTHREADPOOL #pragma message( \ "TORCHAO_PARALLEL_PTHREADPOOL is set. Using pthreadpool parallel backend.") -#include +#include #else #ifdef TORCHAO_PARALLEL_OPENMP #pragma message( \ "TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") -#include +#include #else #if defined TORCHAO_PARALLEL_SINGLE_THREADED #pragma message( \ "TORCHAO_PARALLEL_SINGLE_THREADED is set. Using single-threaded parallel backend.") -#include +#include #else #if defined TORCHAO_PARALLEL_TEST_DUMMY #pragma message( \ "TORCHAO_PARALLEL_TEST_DUMMY is set. Using test dummy parallel backend.") -#include +#include #else #error \ diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_config.h similarity index 98% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_config.h index b699bdd3d3..c54b8af090 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_config.h @@ -5,8 +5,8 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include -#include +#include +#include #include #include diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_selector.h similarity index 95% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_selector.h index 958b9c08e5..88b27f4217 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_selector.h @@ -6,19 +6,19 @@ #pragma once #include -#include -#include +#include +#include #include #include #include #if defined(TORCHAO_BUILD_CPU_AARCH64) #if defined(TORCHAO_ENABLE_ARM_NEON_DOT) -#include +#include #endif // TORCHAO_ENABLE_ARM_NEON_DOT #if defined(TORCHAO_ENABLE_KLEIDI) -#include +#include #endif // TORCHAO_ENABLE_KLEIDI #endif // TORCHAO_BUILD_CPU_AARCH64 @@ -66,9 +66,9 @@ struct UKernelConfigRegistrationTable { } }; -void log_registration(PackedWeightsFormat format, std::string description) { +void inline log_registration(PackedWeightsFormat format, std::string description) { // Logging is only supported in ATen mode -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN LOG(INFO) << "Registering ukernel config for linear_8bit_act_xbit_weight" << std::endl << "\tDescription: " << description << std::endl @@ -80,7 +80,7 @@ void log_registration(PackedWeightsFormat format, std::string description) { << "\tformat.nr=" << format.nr << std::endl << "\tformat.kr=" << format.kr << std::endl << "\tformat.sr=" << format.sr << std::endl; -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN } template @@ -184,7 +184,7 @@ void register_ukernel_config_lut( namespace kernel = torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; - if (cpuinfo_has_arm_neon_dot()) { + if (!cpuinfo_has_arm_neon_dot()) { return; } if (format.has_weight_zeros) { diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp similarity index 95% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 96bfe17b5a..e95191d925 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -5,10 +5,10 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include @@ -203,7 +203,8 @@ void linear_operator( nc = tiling_params->nc; } else { auto params = LinearTilingParams::from_target_tiles_per_thread( - m, + // We process m sequentially, so m_step is the "m" for the purpose of computing tiling params + m_step, m_step, n, n_step, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h similarity index 91% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index 95e1640ad9..a148d3aa31 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -7,8 +7,8 @@ #pragma once #include #include -#include -#include +#include +#include #include #include diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h similarity index 88% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 8a72cbd00a..94df29d669 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -6,16 +6,16 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include #include #include namespace { -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_cpu( const Tensor& weight_qvals, @@ -106,9 +106,9 @@ Tensor pack_weights_cpu( return packed_weights; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_meta( const Tensor& weight_qvals, @@ -146,9 +146,9 @@ Tensor pack_weights_meta( torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#if defined(USE_ATEN) || defined(USE_EXECUTORCH) +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) template Tensor linear_out_cpu( const Tensor& activations, @@ -161,10 +161,10 @@ Tensor linear_out_cpu( TORCHAO_CHECK(k >= 1, "k must be >= 1"); TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( activations.dtype() == torch::kFloat32, "activations must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D"); int m = activations.size(0); @@ -172,18 +172,18 @@ Tensor linear_out_cpu( TORCHAO_CHECK( k == k_, "activation shape is incompatible with packed weights."); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN // Explicit cast from int64_t to int is required for Executorch TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n}); TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), "packed_weights is not big enough to read the header."); @@ -210,9 +210,9 @@ Tensor linear_out_cpu( return out; } -#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH) +#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor linear_cpu( const Tensor& activations, @@ -225,9 +225,9 @@ Tensor linear_cpu( activations, packed_weights, group_size, n, k, output_tensor); return output_tensor; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_with_lut_cpu( const Tensor& weight_qval_idxs, @@ -251,6 +251,7 @@ Tensor pack_weights_with_lut_cpu( "weight_scales must be float32"); TORCHAO_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D"); TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1"); + TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); TORCHAO_CHECK( weight_scales.size(0) == ((n * k) / group_size), "expected 1 scale per group"); @@ -285,8 +286,8 @@ Tensor pack_weights_with_lut_cpu( weight_nbit>(target, has_weight_zeros, has_bias); TORCHAO_CHECK(packed_weights_format.nr == 8, "nr must be 8"); TORCHAO_CHECK( - lut_channel_group_size % 8 == 0, - "the lut_channel_group_size must be a multiple of nr (8)"); + lut_channel_group_size == n || lut_channel_group_size % 8 == 0, + "the lut_channel_group_size must be n or a multiple of nr (8)"); auto packed_weights_header = packed_weights_format.to_packed_weights_header(); auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config< @@ -323,9 +324,9 @@ Tensor pack_weights_with_lut_cpu( return packed_weights; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_with_lut_meta( const Tensor& weight_qval_idxs, @@ -360,6 +361,6 @@ Tensor pack_weights_with_lut_meta( torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN } // namespace diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp similarity index 97% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp index 7e5799b5fd..466fd2567f 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#include +#include #define DEFINE_OP(weight_nbit) \ m.def( \ diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp similarity index 91% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp index 1275accbaa..78ccefecb7 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp @@ -1,4 +1,4 @@ -#include +#include #define DEFINE_OP(weight_nbit) \ Tensor _op_out_##weight_nbit( \ diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/packed_weights_format.h similarity index 96% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/packed_weights_format.h index e22082f9f1..e95593c13b 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/packed_weights_format.h @@ -6,7 +6,7 @@ #pragma once -#include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { diff --git a/torchao/csrc/cpu/shared_kernels/tests/CMakeLists.txt b/torchao/csrc/cpu/shared_kernels/tests/CMakeLists.txt new file mode 100644 index 0000000000..28bda6a1b8 --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/tests/CMakeLists.txt @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_tests) + +set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TEST_TARGET_PREFIX "torchao_tests_shared_kernels_") + +add_executable( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + test_linear_8bit_act_xbit_weight.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + PRIVATE + GTest::gtest_main +) +if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + PRIVATE + torchao_kernels_aarch64 + ) +endif() +if (TORCHAO_BUILD_KLEIDIAI) + target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + PRIVATE + kleidiai + ) +endif() +target_link_torchao_parallel_backend( ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight test_dummy) + +add_executable( + ${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut + test_groupwise_lowbit_weight_lut.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp +) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut + PRIVATE + GTest::gtest_main +) +if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries( + ${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut + PRIVATE + torchao_kernels_aarch64 + ) +endif() +target_link_torchao_parallel_backend(${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut test_dummy) + +include(GoogleTest) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight) diff --git a/torchao/experimental/ops/tests/generate_tests.py b/torchao/csrc/cpu/shared_kernels/tests/generate_tests.py similarity index 100% rename from torchao/experimental/ops/tests/generate_tests.py rename to torchao/csrc/cpu/shared_kernels/tests/generate_tests.py diff --git a/torchao/csrc/cpu/shared_kernels/tests/test_groupwise_lowbit_weight_lut.cpp b/torchao/csrc/cpu/shared_kernels/tests/test_groupwise_lowbit_weight_lut.cpp new file mode 100644 index 0000000000..10bf9bcd3c --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/tests/test_groupwise_lowbit_weight_lut.cpp @@ -0,0 +1,342 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#if defined(TORCHAO_BUILD_CPU_AARCH64) +#include +#endif // TORCHAO_BUILD_CPU_AARCH64 +#include +#include +#include +#include + +const float kTol = 1.0e-5; +using namespace torchao::ops::groupwise_lowbit_weight_lut; + +template +UKernelConfig get_ukernel_config(bool has_bias) { + namespace kernel = + torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut; + + int preferred_alignment = 16; + int n_step = 8; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr int mr = 1; + int m_step = 1; + + auto uk = UKernelConfig::make( + preferred_alignment, + n_step, + nr, + kr, + sr, + weight_nbit, + has_scales, + has_bias, + &kernel::packed_weights_size, + &kernel::packed_weights_offset, + &kernel::pack_weights, + /*configs*/ {}); + + uk.configs[0] = UKernelConfig::config_type{ + m_step, + mr, + &kernel::packed_activations_size, + &kernel::packed_activations_offset, + &kernel::pack_activations, + &kernel:: + groupwise_lowbit_weight_lut_kernel_1x4x32}; + return uk; +} + +template +void test_groupwise_lowbit_weight_lut( + int m, + int k, + int n, + int scale_group_size, + int lut_group_size, + bool has_bias, + bool has_clamp, + const UKernelConfig* ukernel_config_arg = nullptr) { + UKernelConfig ukernel_config; + if (ukernel_config_arg != nullptr) { + ukernel_config = *ukernel_config_arg; + } else { + ukernel_config = get_ukernel_config(has_bias); + } + + auto test_case = torchao::groupwise_lowbit_weight_lut_test_case:: + generate_with_decoupled_grouping( + m, + k, + n, + scale_group_size, + lut_group_size, + weight_nbit, + has_scales, + has_bias, + has_clamp); + + auto output = std::vector(m * n); + + for (auto num_threads : {1, 4, 500}) { + torchao::parallel::set_num_threads_in_test_dummy(num_threads); + EXPECT_EQ(torchao::get_num_threads(), num_threads); + auto packed_weight_data_size = ukernel_config.packed_weights_size( + n, + k, + weight_nbit, + scale_group_size, + has_scales, + has_bias, + ukernel_config.nr, + ukernel_config.kr, + ukernel_config.sr); + auto preferred_packed_weight_data_alignment = + ukernel_config.preferred_alignment; + auto packed_weights = torchao::make_aligned_byte_ptr( + preferred_packed_weight_data_alignment, packed_weight_data_size); + + pack_weights_operator( + ukernel_config, + // Outputs + packed_weights.get(), + // Inputs + n, + k, + scale_group_size, + lut_group_size, + test_case.weight_qval_indices.data(), + test_case.weight_scales.data(), + test_case.weight_luts.data(), + test_case.bias.data()); + + groupwise_lowbit_weight_lut_parallel_operator( + ukernel_config, + std::nullopt, + output.data(), + m, + n, + k, + scale_group_size, + lut_group_size, + packed_weights.get(), + test_case.activations.data(), + has_clamp, + test_case.clamp_min, + test_case.clamp_max); + + float tol = kTol; + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], tol); + } + } +} + +struct KernelTestParams { + int m; + int k; + int n; + int scale_group_size; + int lut_group_size; + bool has_bias; + bool has_clamp; +}; + +class ComprehensiveKernelTest + : public ::testing::TestWithParam {}; + +TEST_P(ComprehensiveKernelTest, kernel_test_has_scales_true) { + const KernelTestParams& params = GetParam(); + + constexpr bool has_scales = true; + + for (int weight_nbit : {1, 2, 3, 4}) { + switch (weight_nbit) { + case 1: + test_groupwise_lowbit_weight_lut<1, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 2: + test_groupwise_lowbit_weight_lut<2, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 3: + test_groupwise_lowbit_weight_lut<3, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 4: + test_groupwise_lowbit_weight_lut<4, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + default: + FAIL() << "Unsupported weight_nbit value: " << weight_nbit; + } + } +} + +TEST_P(ComprehensiveKernelTest, kernel_test_has_scales_false) { + const KernelTestParams& params = GetParam(); + + constexpr bool has_scales = false; + + for (int weight_nbit : {1, 2, 3, 4}) { + switch (weight_nbit) { + case 1: + test_groupwise_lowbit_weight_lut<1, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 2: + test_groupwise_lowbit_weight_lut<2, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 3: + test_groupwise_lowbit_weight_lut<3, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 4: + test_groupwise_lowbit_weight_lut<4, has_scales>( + params.m, + params.k, + params.n, + params.scale_group_size, + params.lut_group_size, + params.has_bias, + params.has_clamp); + break; + default: + FAIL() << "Unsupported weight_nbit value: " << weight_nbit; + } + } +} + +INSTANTIATE_TEST_SUITE_P( + KernelEdgeCases, + ComprehensiveKernelTest, + ::testing::Values( + // Flag-specific tests + KernelTestParams{ + 8, + 64, + 16, + 32, + 256, + /*has_bias=*/true, + /*has_clamp=*/true}, + KernelTestParams{ + 8, + 64, + 16, + 32, + 256, + /*has_bias=*/true, + /*has_clamp=*/false}, + KernelTestParams{ + 8, + 64, + 16, + 32, + 256, + /*has_bias=*/false, + /*has_clamp=*/true}, + KernelTestParams{ + 8, + 64, + 16, + 32, + 256, + /*has_bias=*/false, + /*has_clamp=*/false}, + + // Prime number dimensions for m and n + KernelTestParams{ + 7, + 64, + 13, + 32, + 256, + /*has_bias=*/true, + /*has_clamp=*/true}, + KernelTestParams{ + 13, + 128, + 17, + 64, + 512, + /*has_bias=*/false, + /*has_clamp=*/false}, + KernelTestParams{ + 1, + 32, + 5, + 32, + 128, + /*has_bias=*/true, + /*has_clamp=*/false}, + + // Varying Dimensions and Group Sizes + KernelTestParams{8, 64, 16, 32, 256, true, true}, + KernelTestParams{8, 64, 12, 32, 256, true, false}, + KernelTestParams{7, 128, 24, 64, 512, false, true}, + KernelTestParams{1, 32, 4, 32, 128, true, true}, + + // Unaligned M + KernelTestParams{7, 64, 16, 32, 256, true, false}, + KernelTestParams{5, 64, 16, 32, 256, false, true}, + KernelTestParams{1, 64, 16, 32, 256, true, true})); + +void PrintTo(const KernelTestParams& params, std::ostream* os) { + *os << "KernelTestParams(m=" << params.m << ", k=" << params.k + << ", n=" << params.n << ", scale_gs=" << params.scale_group_size + << ", lut_gs=" << params.lut_group_size + << ", has_bias=" << (params.has_bias ? "true" : "false") + << ", has_clamp=" << (params.has_clamp ? "true" : "false") << ")"; +} diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/csrc/cpu/shared_kernels/tests/test_linear_8bit_act_xbit_weight.cpp similarity index 99% rename from torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp rename to torchao/csrc/cpu/shared_kernels/tests/test_linear_8bit_act_xbit_weight.cpp index 16c38aa8d3..7631d34a03 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/csrc/cpu/shared_kernels/tests/test_linear_8bit_act_xbit_weight.cpp @@ -7,15 +7,15 @@ #include // TODO: move test_utils.h out of aarch64 #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // TORCHAO_BUILD_CPU_AARCH64 -#include -#include -#include -#include +#include +#include +#include +#include #if defined(TORCHAO_ENABLE_KLEIDI) -#include +#include using namespace torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; #endif // TORCHAO_ENABLE_KLEIDI @@ -111,7 +111,7 @@ void test_linear_8bit_act_xbit_weight( auto output = std::vector(m * n); for (auto num_threads : {1, 4, 500}) { - torchao::set_num_threads(num_threads); + torchao::parallel::set_num_threads_in_test_dummy(num_threads); EXPECT_EQ(torchao::get_num_threads(), num_threads); // Pack weights diff --git a/torchao/csrc/cpu/torch_free_kernels/README.md b/torchao/csrc/cpu/torch_free_kernels/README.md new file mode 100644 index 0000000000..e1787bd980 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/README.md @@ -0,0 +1,8 @@ +# Torch free kernels + +Kernels in this directory do not depend on Torch. Rather than use Tensor, they are written with raw pointers. These raw kernels are used by ATen/ExecuTorch kernels in torchao/csrc/cpu/shared_kernels. + +Code is organized into subdirectories by CPU architecture: +* aarch64 (Arm) +* fallback (architecture-independent / generic C++) +* interface (high-level interface for fallback and architecture-specific code) diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/aarch64/CMakeLists.txt new file mode 100644 index 0000000000..42f9cc82b7 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if (TORCHAO_BUILD_CPU_AARCH64) + add_library( + torchao_kernels_aarch64 + ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduction/compute_sum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantization/quantize.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/valpacking/interleave.cpp + ) +endif() + +if (TORCHAO_BUILD_TESTS) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tests) +endif() + +if (TORCHAO_BUILD_BENCHMARKS) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/benchmarks) +endif() diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/CMakeLists.txt new file mode 100644 index 0000000000..d9d0480dfb --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_benchmarks) +set(CMAKE_BUILD_TYPE Release) + +set(TARGET_PREFIX "torchao_benchmarks_torch_free_kernels_aarch64_") + +add_library( + ${TARGET_PREFIX}dep + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp +) + +add_executable(${TARGET_PREFIX}benchmark_quantization benchmark_quantization.cpp) +target_link_libraries( + ${TARGET_PREFIX}benchmark_quantization + PRIVATE + benchmark::benchmark + ${TARGET_PREFIX}dep +) + +add_executable(${TARGET_PREFIX}benchmark_bitpacking benchmark_bitpacking.cpp) +target_link_libraries( + ${TARGET_PREFIX}benchmark_bitpacking + PRIVATE + benchmark::benchmark + ${TARGET_PREFIX}dep +) + +# TODO: fix this, it's not working right now because of code refactors +# add_executable(${TARGET_PREFIX}benchmark_linear benchmark_linear.cpp) +# target_link_libraries( +# ${TARGET_PREFIX}benchmark_linear +# PRIVATE +# benchmark::benchmark +# ${TARGET_PREFIX}dep +# ) diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_bitpacking.cpp similarity index 96% rename from torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_bitpacking.cpp index a6bb8b478f..d31233b09b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -9,15 +9,15 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include namespace { diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_linear.cpp similarity index 95% rename from torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_linear.cpp index 4e9759ab2e..26abe6918a 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_linear.cpp @@ -5,9 +5,9 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include -#include +#include +#include +#include #include template @@ -92,7 +92,7 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( int group_size = state.range(3); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( @@ -164,7 +164,7 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( int group_size = state.range(3); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_quantization.cpp similarity index 84% rename from torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_quantization.cpp index 7c81b963dc..d877b905d0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_quantization.cpp @@ -7,9 +7,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include +#include +#include +#include static void benchmark_quantize(benchmark::State& state) { int nbit = state.range(0); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/bitpack.h similarity index 62% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/bitpack.h index f3b5c1be77..01e8b85e1d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/bitpack.h @@ -9,14 +9,14 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include #include namespace torchao { @@ -328,6 +328,60 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( assert(false); } } +template +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uintx_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + static_assert(nbit < 9); + static_assert(nbit >= 1); + + // No shifting is needed because the data is already unsigned. + + switch (nbit) { + case 1: + // The internal functions are already designed to take uint8x16_t + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed, unpacked0, unpacked1, unpacked2, unpacked3); + break; + case 2: + torchao::bitpacking::internal::vec_pack_64_uint2_values( + packed, unpacked0, unpacked1, unpacked2, unpacked3); + break; + case 3: + torchao::bitpacking::internal::vec_pack_64_uint3_values( + packed, unpacked0, unpacked1, unpacked2, unpacked3); + break; + case 4: + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed, unpacked0, unpacked1); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed + 16, unpacked2, unpacked3); + break; + case 5: + torchao::bitpacking::internal::vec_pack_64_uint5_values( + packed, unpacked0, unpacked1, unpacked2, unpacked3); + break; + case 6: + torchao::bitpacking::internal::vec_pack_64_uint6_values( + packed, unpacked0, unpacked1, unpacked2, unpacked3); + break; + case 7: + torchao::bitpacking::internal::vec_pack_64_uint7_values( + packed, unpacked0, unpacked1, unpacked2, unpacked3); + break; + case 8: + vst1q_u8(packed, unpacked0); + vst1q_u8(packed + 16, unpacked1); + vst1q_u8(packed + 32, unpacked2); + vst1q_u8(packed + 48, unpacked3); + break; + default: + assert(false); + } +} template TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( @@ -396,6 +450,107 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( } } +template +TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uintx_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1) { + // Ensure nbit is within the valid range [1, 8] + static_assert(nbit < 9); + static_assert(nbit >= 1); + + switch (nbit) { + case 1: { + // For 1-bit, we store the 32 values into a temporary buffer + // and then pack them in 8-value chunks. + uint8_t buffer[32]; + vst1q_u8(buffer, unpacked0); + vst1q_u8(buffer + 16, unpacked1); + + torchao::bitpacking::internal::pack_8_uint1_values(packed, buffer); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 1, buffer + 8); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 2, buffer + 16); + torchao::bitpacking::internal::pack_8_uint1_values( + packed + 3, buffer + 24); + break; + } + case 2: + // Use the existing vectorized implementation for 2-bit packing. + torchao::bitpacking::internal::vec_pack_32_uint2_values( + packed, + vget_low_u8(unpacked0), + vget_high_u8(unpacked0), + vget_low_u8(unpacked1), + vget_high_u8(unpacked1)); + break; + case 3: { + // For 3-bit, we store to a buffer and pack in 8-value chunks. + uint8_t buffer[32]; + vst1q_u8(buffer, unpacked0); + vst1q_u8(buffer + 16, unpacked1); + + torchao::bitpacking::internal::pack_8_uint3_values(packed, buffer); + torchao::bitpacking::internal::pack_8_uint3_values( + packed + 3, buffer + 8); + torchao::bitpacking::internal::pack_8_uint3_values( + packed + 6, buffer + 16); + torchao::bitpacking::internal::pack_8_uint3_values( + packed + 9, buffer + 24); + break; + } + case 4: + // Use the existing vectorized implementation for 4-bit packing. + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed, unpacked0, unpacked1); + break; + case 5: { + // For 5-bit, we store to a buffer and pack in 8-value chunks. + uint8_t buffer[32]; + vst1q_u8(buffer, unpacked0); + vst1q_u8(buffer + 16, unpacked1); + + torchao::bitpacking::internal::pack_8_uint5_values(packed, buffer); + torchao::bitpacking::internal::pack_8_uint5_values( + packed + 5, buffer + 8); + torchao::bitpacking::internal::pack_8_uint5_values( + packed + 10, buffer + 16); + torchao::bitpacking::internal::pack_8_uint5_values( + packed + 15, buffer + 24); + break; + } + case 6: + // Use the existing vectorized implementation for 6-bit packing. + torchao::bitpacking::internal::vec_pack_32_uint6_values( + packed, unpacked0, unpacked1); + break; + case 7: { + // For 7-bit, we store to a buffer and pack in 8-value chunks. + uint8_t buffer[32]; + vst1q_u8(buffer, unpacked0); + vst1q_u8(buffer + 16, unpacked1); + + torchao::bitpacking::internal::pack_8_uint7_values(packed, buffer); + torchao::bitpacking::internal::pack_8_uint7_values( + packed + 7, buffer + 8); + torchao::bitpacking::internal::pack_8_uint7_values( + packed + 14, buffer + 16); + torchao::bitpacking::internal::pack_8_uint7_values( + packed + 21, buffer + 24); + break; + } + case 8: + // For 8-bit, it's a direct memory store of the two vectors. + vst1q_u8(packed, unpacked0); + vst1q_u8(packed + 16, unpacked1); + break; + default: + // This should be unreachable due to the static_asserts + assert(false); + } +} + template TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uintx_values( uint8_t* packed, @@ -726,6 +881,258 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values_with_lut( unpacked7 = vqtbl1q_s8(lut, idx7); } +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uintx_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + static_assert(nbit < 9); + static_assert(nbit >= 1); + + switch (nbit) { + case 1: + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed); + break; + case 2: + torchao::bitpacking::internal::vec_unpack_64_uint2_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed); + break; + case 3: + torchao::bitpacking::internal::vec_unpack_64_uint3_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed); + break; + case 4: + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + unpacked0, unpacked1, packed); + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + unpacked2, unpacked3, packed + 16); + break; + case 5: + torchao::bitpacking::internal::vec_unpack_64_uint5_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed); + break; + case 6: + torchao::bitpacking::internal::vec_unpack_64_uint6_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed); + break; + case 7: + torchao::bitpacking::internal::vec_unpack_64_uint7_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed); + break; + case 8: + unpacked0 = vld1q_u8(packed); + unpacked1 = vld1q_u8(packed + 16); + unpacked2 = vld1q_u8(packed + 32); + unpacked3 = vld1q_u8(packed + 48); + break; + default: + assert(false); + } +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lut_indices( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + static_assert(nbit <= 8); + static_assert(nbit >= 1); + + if constexpr (nbit == 8) { + unpacked0 = vld1q_u8(packed + 0); + unpacked1 = vld1q_u8(packed + 16); + unpacked2 = vld1q_u8(packed + 32); + unpacked3 = vld1q_u8(packed + 48); + return; + } + + vec_unpack_64_uintx_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed); + + const uint8_t mask = (1 << nbit) - 1; + uint8x16_t mask_vec = vdupq_n_u8(mask); + + unpacked0 = vandq_u8(unpacked0, mask_vec); + unpacked1 = vandq_u8(unpacked1, mask_vec); + unpacked2 = vandq_u8(unpacked2, mask_vec); + unpacked3 = vandq_u8(unpacked3, mask_vec); +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uintx_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + const uint8_t* packed) { + static_assert(nbit < 9); + static_assert(nbit >= 1); + + uint8x16_t shifted0 = vdupq_n_u8(0); + uint8x16_t shifted1 = vdupq_n_u8(0); + + switch (nbit) { + case 1: + uint8_t buffer1[32]; + torchao::bitpacking::internal::unpack_8_uint1_values(buffer1, packed); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 8, packed + 1); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 16, packed + 2); + torchao::bitpacking::internal::unpack_8_uint1_values( + buffer1 + 24, packed + 3); + shifted0 = vld1q_u8(buffer1); + shifted1 = vld1q_u8(buffer1 + 16); + break; + case 2: + uint8x8_t shifted0_low; + uint8x8_t shifted0_high; + uint8x8_t shifted1_low; + uint8x8_t shifted1_high; + torchao::bitpacking::internal::vec_unpack_32_uint2_values( + shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed); + shifted0 = vcombine_u8(shifted0_low, shifted0_high); + shifted1 = vcombine_u8(shifted1_low, shifted1_high); + break; + case 3: + uint8_t buffer3[32]; + torchao::bitpacking::internal::unpack_8_uint3_values(buffer3, packed); + torchao::bitpacking::internal::unpack_8_uint3_values( + buffer3 + 8, packed + 3); + torchao::bitpacking::internal::unpack_8_uint3_values( + buffer3 + 16, packed + 6); + torchao::bitpacking::internal::unpack_8_uint3_values( + buffer3 + 24, packed + 9); + shifted0 = vld1q_u8(buffer3); + shifted1 = vld1q_u8(buffer3 + 16); + break; + case 4: + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted0, shifted1, packed); + break; + case 5: + uint8_t buffer5[32]; + torchao::bitpacking::internal::unpack_8_uint5_values(buffer5, packed); + torchao::bitpacking::internal::unpack_8_uint5_values( + buffer5 + 8, packed + 5); + torchao::bitpacking::internal::unpack_8_uint5_values( + buffer5 + 16, packed + 10); + torchao::bitpacking::internal::unpack_8_uint5_values( + buffer5 + 24, packed + 15); + shifted0 = vld1q_u8(buffer5); + shifted1 = vld1q_u8(buffer5 + 16); + break; + case 6: + torchao::bitpacking::internal::vec_unpack_32_uint6_values( + shifted0, shifted1, packed); + break; + case 7: + uint8_t buffer7[32]; + torchao::bitpacking::internal::unpack_8_uint7_values(buffer7, packed); + torchao::bitpacking::internal::unpack_8_uint7_values( + buffer7 + 8, packed + 7); + torchao::bitpacking::internal::unpack_8_uint7_values( + buffer7 + 16, packed + 14); + torchao::bitpacking::internal::unpack_8_uint7_values( + buffer7 + 24, packed + 21); + shifted0 = vld1q_u8(buffer7); + shifted1 = vld1q_u8(buffer7 + 16); + break; + case 8: + shifted0 = vld1q_u8(packed); + shifted1 = vld1q_u8(packed + 16); + break; + default: + assert(false); + } + unpacked0 = shifted0; + unpacked1 = shifted1; +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lut_indices( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + const uint8_t* packed) { + static_assert(nbit <= 8); + static_assert(nbit >= 1); + + // For 8-bit, the data is already unpacked. Just load directly. + if constexpr (nbit == 8) { + unpacked0 = vld1q_u8(packed + 0); + unpacked1 = vld1q_u8(packed + 16); + return; + } + + // 1. Call the internal helper to get the raw unpacked values. + vec_unpack_32_uintx_values(unpacked0, unpacked1, packed); + + // 2. Apply the bitmask to get the final, correct indices for a LUT. + const uint8_t mask = (1 << nbit) - 1; + uint8x16_t mask_vec = vdupq_n_u8(mask); + + unpacked0 = vandq_u8(unpacked0, mask_vec); + unpacked1 = vandq_u8(unpacked1, mask_vec); +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lut_indices( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + uint8x16_t& unpacked4, + uint8x16_t& unpacked5, + uint8x16_t& unpacked6, + uint8x16_t& unpacked7, + const uint8_t* packed) { + // Unpacks 128 tightly packed n-bit values into 8-bit LUT indices using ARM + // NEON. For n-bit < 8, this function first spreads the bits into bytes and + // then applies a mask to zero out the unused upper bits, ensuring each index + // is valid. For the n-bit == 8 case, it's a direct memory load, as no + // unpacking is needed. + + static_assert(nbit <= 8); + static_assert(nbit >= 1); + + // For 8-bit, the data is already unpacked. Just load directly. + if constexpr (nbit == 8) { + unpacked0 = vld1q_u8(packed + 0); + unpacked1 = vld1q_u8(packed + 16); + unpacked2 = vld1q_u8(packed + 32); + unpacked3 = vld1q_u8(packed + 48); + unpacked4 = vld1q_u8(packed + 64); + unpacked5 = vld1q_u8(packed + 80); + unpacked6 = vld1q_u8(packed + 96); + unpacked7 = vld1q_u8(packed + 112); + return; + } + + vec_unpack_128_uintx_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed); + const uint8_t mask = (1 << nbit) - 1; + uint8x16_t mask_vec = vdupq_n_u8(mask); + + unpacked0 = vandq_u8(unpacked0, mask_vec); + unpacked1 = vandq_u8(unpacked1, mask_vec); + unpacked2 = vandq_u8(unpacked2, mask_vec); + unpacked3 = vandq_u8(unpacked3, mask_vec); + unpacked4 = vandq_u8(unpacked4, mask_vec); + unpacked5 = vandq_u8(unpacked5, mask_vec); + unpacked6 = vandq_u8(unpacked6, mask_vec); + unpacked7 = vandq_u8(unpacked7, mask_vec); +} } // namespace bitpacking } // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint1.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint1.h index de999a53d6..d24425745e 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint1.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint1. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint2.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint2.h index 630bc22798..b4874154e1 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint2.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint4. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint3.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint3.h index a808ee3a27..6063c12008 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint3.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint3. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint4.h similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint4.h index fba626ea57..2a36f3c429 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint4.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint4. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint5.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint5.h index 456706b76a..4771bab584 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint5.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint5. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint6.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint6.h index d15094ddfb..3ae83fab09 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint6.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint5. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint7.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint7.h index 1fc2a8d5cb..f1130c89bd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint7.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint7. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding.h similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding.h index c750b6d534..0f6d8a2339 100644 --- a/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding.h @@ -9,9 +9,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include +#include +#include +#include #include #include diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding_lut.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding_lut.h new file mode 100644 index 0000000000..573fc8020d --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding_lut.h @@ -0,0 +1,382 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::embedding { + +/** + * @brief Calculates the size in bytes for a single row of packed embeddings. + * + * This function computes the memory stride for one row, accounting for three + * components: + * 1. Bit-packed weight indices. + * 2. Optional, group-quantized scales. + * 3. Padded look-up tables (LUTs). + * + * @param weight_nbit The number of bits for each weight index (e.g., 2, 4). + * @param embedding_dim The dimension of the embedding vector (i.e., number of + * weights per row). + * @param scale_group_size The number of weights that share a single + * quantization scale. + * @param lut_group_size The number of weights that share a single look-up + * table. + * @param has_scales A flag indicating whether quantization scales are stored. + * @return The total size in bytes (stride) for one packed row. + */ +inline size_t packed_embedding_size_per_row( + int weight_nbit, + int embedding_dim, + int scale_group_size, + int lut_group_size, + bool has_scales) { + // We need to account for the padding of the LUTs. The LUTs are padded to 16 + // floats (64 bytes) for alignment. + constexpr int kLutPaddedSize = 16; + // Number of LUTs per row, it could be 1 or more LUTs per row. + const int lut_per_row = (embedding_dim + lut_group_size - 1) / lut_group_size; + // LUT size in bytes + const int lut_bytes = lut_per_row * kLutPaddedSize * sizeof(float); + + // Scales are packed if has_scales is true. + const int scales_per_row = + (embedding_dim + scale_group_size - 1) / scale_group_size; + const int scale_bytes = has_scales ? (scales_per_row * sizeof(float)) : 0; + + // The indices are bit-packed. + const int index_bytes = (embedding_dim * weight_nbit + 7) / 8; + + const size_t packed_row_stride = lut_bytes + scale_bytes + index_bytes; + return packed_row_stride; +} + +/** + * @brief Calculates the total size in bytes for an entire table of packed + * embeddings. + * + * This is a convenience function that multiplies the size of a single packed + * row by the total number of embeddings (rows) to find the total memory + * required. + * + * @param weight_nbit The number of bits for each weight index. + * @param num_embeddings The total number of rows (embeddings) in the weight + * table. + * @param embedding_dim The dimension of the embedding vector. + * @param scale_group_size The number of weights sharing a single scale. + * @param lut_group_size The number of weights sharing a single LUT. + * @param has_scales A flag indicating if scales are present. + * @return The total size in bytes required for the entire packed weight table. + */ +inline size_t packed_embedding_size( + int weight_nbit, + int num_embeddings, + int embedding_dim, + int scale_group_size, + int lut_group_size, + bool has_scales) { + // Pass the correct arguments to the helper function. + return num_embeddings * + packed_embedding_size_per_row( + weight_nbit, + embedding_dim, + scale_group_size, + lut_group_size, + has_scales); +} + +template +inline void pack_embedding_row_at_index_lut( + // Destination + void* packed_table, + int index, + // Source Tables + const uint8_t* source_indices_table, + const float* source_scales_table, + const float* source_luts_table, + // Dimensions + int num_embeddings, + int embedding_dim, + int scale_group_size, + int lut_group_size, + bool has_scales) { + assert(index >= 0 && index < num_embeddings); + assert(embedding_dim > 0 && embedding_dim % 32 == 0); + + // 1. Calculate the stride of one packed row (for the destination table) + constexpr int kLutPaddedSize = 16; + const int lut_size = 1 << weight_nbit; + const int lut_per_row = (embedding_dim + lut_group_size - 1) / lut_group_size; + const int scales_per_row = + (embedding_dim + scale_group_size - 1) / scale_group_size; + + const size_t packed_row_stride = packed_embedding_size_per_row( + weight_nbit, embedding_dim, scale_group_size, lut_group_size, has_scales); + + constexpr int bytes_per_packed_128_values = (128 * weight_nbit) / 8; + constexpr int bytes_per_packed_64_values = (64 * weight_nbit) / 8; + constexpr int bytes_per_packed_32_values = (32 * weight_nbit) / 8; + // 2. Calculate the starting pointer for the destination row + uint8_t* output_ptr = reinterpret_cast(packed_table) + + (static_cast(index) * packed_row_stride); + + // --- 3. Calculate the starting pointers for the SOURCE data row --- + // This is the key change to support 1D indexing. + const size_t linear_idx_start_of_row = + static_cast(index) * embedding_dim; + + // Find the global group index for the start of our row. + const size_t start_lut_group_idx = linear_idx_start_of_row / lut_group_size; + const size_t start_scale_group_idx = + linear_idx_start_of_row / scale_group_size; + + const uint8_t* source_indices_for_row = + source_indices_table + linear_idx_start_of_row; + const float* source_scales_for_row = + source_scales_table + start_scale_group_idx; + const float* source_luts_for_row = + source_luts_table + start_lut_group_idx * lut_size; + + // 4. Pack LUTs + std::vector lut_buffer(kLutPaddedSize, 0.0f); + for (int i = 0; i < lut_per_row; i++) { + std::memcpy( + lut_buffer.data(), + source_luts_for_row + i * lut_size, + lut_size * sizeof(float)); + std::memcpy(output_ptr, lut_buffer.data(), kLutPaddedSize * sizeof(float)); + output_ptr += kLutPaddedSize * sizeof(float); + } + + // 5. Pack Scales + if (has_scales) { + std::memcpy( + output_ptr, source_scales_for_row, scales_per_row * sizeof(float)); + output_ptr += scales_per_row * sizeof(float); + } + + // 6. Pack Weight Indices (Quantized Values) + int i = 0; + // Process in chunks of 128 + for (; i + 128 <= embedding_dim; i += 128) { + uint8x16_t qvals0 = vld1q_u8(source_indices_for_row + i); + uint8x16_t qvals1 = vld1q_u8(source_indices_for_row + i + 16); + uint8x16_t qvals2 = vld1q_u8(source_indices_for_row + i + 32); + uint8x16_t qvals3 = vld1q_u8(source_indices_for_row + i + 48); + uint8x16_t qvals4 = vld1q_u8(source_indices_for_row + i + 64); + uint8x16_t qvals5 = vld1q_u8(source_indices_for_row + i + 80); + uint8x16_t qvals6 = vld1q_u8(source_indices_for_row + i + 96); + uint8x16_t qvals7 = vld1q_u8(source_indices_for_row + i + 112); + + torchao::bitpacking::vec_pack_128_uintx_values( + output_ptr, + qvals0, + qvals1, + qvals2, + qvals3, + qvals4, + qvals5, + qvals6, + qvals7); + output_ptr += bytes_per_packed_128_values; + } + + // Process in chunks of 64 + if (i + 64 <= embedding_dim) { + uint8x16_t qvals0 = vld1q_u8(source_indices_for_row + i); + uint8x16_t qvals1 = vld1q_u8(source_indices_for_row + i + 16); + uint8x16_t qvals2 = vld1q_u8(source_indices_for_row + i + 32); + uint8x16_t qvals3 = vld1q_u8(source_indices_for_row + i + 48); + + torchao::bitpacking::vec_pack_64_uintx_values( + output_ptr, qvals0, qvals1, qvals2, qvals3); + output_ptr += bytes_per_packed_64_values; + i += 64; + } + + // Process in chunks of 32 + if (i + 32 <= embedding_dim) { + uint8x16_t qvals0 = vld1q_u8(source_indices_for_row + i); + uint8x16_t qvals1 = vld1q_u8(source_indices_for_row + i + 16); + torchao::bitpacking::vec_pack_32_uintx_values( + output_ptr, qvals0, qvals1); + output_ptr += bytes_per_packed_32_values; + i += 32; + } + + assert(i == embedding_dim); // Final check: Ensure all elements were processed +} + +/** + * @brief Reads a single embedding vector from the packed format and dequantizes + * it. + * + * @tparam weight_nbit The number of bits used for the quantized weights (e.g., + * 2, 4). + * @param out Pointer to the output buffer for the dequantized float vector. + * Must have space for `embedding_dim` floats. + * @param packed_data Pointer to the beginning of the entire packed embedding + * table. + * @param index The row index of the embedding vector to retrieve. + * @param num_embeddings The total number of embeddings in the table (for + * boundary checks). + * @param embedding_dim The dimension of a single embedding vector. + * @param scale_group_size The number of values sharing a single scale. + * @param lut_group_size The number of values sharing a single LUT. + * @param has_scales A flag indicating if scales were packed. + */ +template +inline void dequantize_embedding_row_at_idx_lut( + // Output + float* out, + // Inputs + const void* packed_data, + int index, + int num_embeddings, + int embedding_dim, + int scale_group_size, + int lut_group_size, + bool has_scales) { + assert(index >= 0 && index < num_embeddings); + assert(embedding_dim > 0 && embedding_dim % 32 == 0); + + // 1. Calculate the total size (stride) of one packed embedding row + + // LUTs are padded to 16 floats (64 bytes) for alignment. + constexpr int kLutPaddedSize = 16; + const int lut_per_row = (embedding_dim + lut_group_size - 1) / lut_group_size; + const int lut_bytes = lut_per_row * kLutPaddedSize * sizeof(float); + + // Scales are packed if has_scales is true. + const int scales_per_row = + (embedding_dim + scale_group_size - 1) / scale_group_size; + const int scale_bytes = has_scales ? (scales_per_row * sizeof(float)) : 0; + + // The indices are bit-packed. + const int index_bytes = (embedding_dim * weight_nbit) / 8; + + const size_t total_row_stride = lut_bytes + scale_bytes + index_bytes; + + // 2. Calculate the memory offset to the start of the desired row + const uint8_t* row_start_ptr = reinterpret_cast(packed_data) + + (static_cast(index) * total_row_stride); + + // 3. Get pointers to the LUTs, scales, and packed indices for this row + const float* luts_ptr = reinterpret_cast(row_start_ptr); + const float* scales_ptr = has_scales + ? reinterpret_cast(row_start_ptr + lut_bytes) + : nullptr; + const uint8_t* packed_indices_ptr = row_start_ptr + lut_bytes + scale_bytes; + + // 4. Unpack the n-bit indices into a temporary 8-bit buffer + std::vector unpacked_indices(embedding_dim); + const uint8_t* read_ptr = packed_indices_ptr; + uint8_t* write_ptr = unpacked_indices.data(); + int i = 0; + + constexpr int bytes_per_packed_128_values = (128 * weight_nbit) / 8; + constexpr int bytes_per_packed_64_values = (64 * weight_nbit) / 8; + constexpr int bytes_per_packed_32_values = (32 * weight_nbit) / 8; + + // Process in chunks of 128 + for (; i + 128 <= embedding_dim; i += 128) { + // 1. Declare NEON registers for the output + uint8x16_t u0, u1, u2, u3, u4, u5, u6, u7; + // 2. Unpack directly into the registers + torchao::bitpacking::vec_unpack_128_lut_indices( + u0, u1, u2, u3, u4, u5, u6, u7, read_ptr); + // 3. Store the results from registers to memory + vst1q_u8(write_ptr + 0, u0); + vst1q_u8(write_ptr + 16, u1); + vst1q_u8(write_ptr + 32, u2); + vst1q_u8(write_ptr + 48, u3); + vst1q_u8(write_ptr + 64, u4); + vst1q_u8(write_ptr + 80, u5); + vst1q_u8(write_ptr + 96, u6); + vst1q_u8(write_ptr + 112, u7); + + write_ptr += 128; + read_ptr += bytes_per_packed_128_values; + } + + // Process in chunks of 64 + if (i + 64 <= embedding_dim) { + uint8x16_t u0, u1, u2, u3; + torchao::bitpacking::vec_unpack_64_lut_indices( + u0, u1, u2, u3, read_ptr); + vst1q_u8(write_ptr + 0, u0); + vst1q_u8(write_ptr + 16, u1); + vst1q_u8(write_ptr + 32, u2); + vst1q_u8(write_ptr + 48, u3); + + write_ptr += 64; + read_ptr += bytes_per_packed_64_values; + i += 64; + } + + // Process in chunks of 32 + if (i + 32 <= embedding_dim) { + uint8x16_t u0, u1; + torchao::bitpacking::vec_unpack_32_lut_indices( + u0, u1, read_ptr); + vst1q_u8(write_ptr + 0, u0); + vst1q_u8(write_ptr + 16, u1); + + write_ptr += 32; + read_ptr += bytes_per_packed_32_values; + i += 32; + } + + assert(i == embedding_dim); + // Dequantize using vectorized LUT lookup + for (int j = 0; j < embedding_dim; j += 16) { + // Identify and load the LUT for this 16-element chunk. + // Since lut_group_size % 16 == 0, all 16 elements use the same LUT. + const int lut_group_idx = j / lut_group_size; + const float* current_lut_ptr = luts_ptr + lut_group_idx * kLutPaddedSize; + uint8x16x4_t lut_neon; + torchao::lut::load_fp32_lut(lut_neon, current_lut_ptr); + + // Load the 16 indices to be looked up. + uint8x16_t indices_neon = vld1q_u8(unpacked_indices.data() + j); + + // Perform the vectorized lookup. The results are in out0..3. + float32x4_t out0, out1, out2, out3; + torchao::lut::lookup_from_fp32_lut( + out0, out1, out2, out3, lut_neon, indices_neon); + float scale_val = 1.0f; + // Apply scales vectorially. + if (has_scales) { + // Since scale_group_size % 16 == 0, all 16 elements use the same scale. + const int scale_group_idx = j / scale_group_size; + scale_val = scales_ptr[scale_group_idx]; + // Load the single scale value into all 4 lanes of a vector register. + float32x4_t scale_vec = vdupq_n_f32(scale_val); + + // Multiply the looked-up values by the scale. + out0 = vmulq_f32(out0, scale_vec); + out1 = vmulq_f32(out1, scale_vec); + out2 = vmulq_f32(out2, scale_vec); + out3 = vmulq_f32(out3, scale_vec); + } + + // Store the final 16 float results back to the output buffer. + vst1q_f32(out + j + 0, out0); + vst1q_f32(out + j + 4, out1); + vst1q_f32(out + j + 8, out2); + vst1q_f32(out + j + 12, out3); + } +} +} // namespace torchao::kernels::cpu::aarch64::embedding + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index aa338fc165..777d73cebc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -28,7 +28,7 @@ #include #endif // TORCHAO_ENABLE_ARM_I8MM -#include +#include namespace torchao::kernels::cpu::aarch64::kleidi { diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/pack.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/pack.h diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h similarity index 90% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h index ce0ac804c9..849d99cb8a 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -10,12 +10,12 @@ #include #include -#include -#include +#include +#include -#include -#include -#include +#include +#include +#include namespace torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h index 1d48f6f2b0..535bf7a084 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include namespace torchao::kernels::cpu::aarch64::linear:: diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h index e2bb78d385..40be2c5231 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h index 7a53c7302c..78246e211d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h similarity index 95% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h index 5967c5b14e..d7558dd4ce 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h @@ -8,8 +8,8 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include +#include +#include #include namespace torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight::activation_packing { diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h similarity index 90% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h index aece38b435..133c4a7f25 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h @@ -2,9 +2,10 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include -#include +#include +#include +#include +#include #include #include @@ -125,61 +126,6 @@ TORCHAO_ALWAYS_INLINE inline void unpack_buffer( assert(false); } -// Packs nr * kr values for GEMM with packing params (nr, kr, sr) -// It takes (kr / sr) values from each of nr columns and writes to packed_values -// This is repeated sr times -template -void pack_values( - // Output - T* packed_values, - // Inputs - const T* values, - int nr, - int kr, - int sr) { - assert(kr % sr == 0); - int kr_per_sr = kr / sr; - int dst_idx = 0; - for (int sr_idx = 0; sr_idx < sr; sr_idx++) { - for (int n_idx = 0; n_idx < nr; n_idx++) { - // Take kr_per_sr values from column n_idx - std::memcpy( - packed_values + dst_idx, - values + n_idx * kr + sr_idx * kr_per_sr, - sizeof(T) * kr_per_sr); - dst_idx += kr_per_sr; - } - } -} - -// Undoes pack_values -template -void unpack_values( - // Output - T* values, - // Inputs - const T* packed_values, - int nr, - int kr, - int sr) { - // packed_values and values should have size nr * kr - // This function takes (kr / sr) from each column of nr columns and writes to - // output This is repeated sr times - assert(kr % sr == 0); - int kr_per_sr = kr / sr; - int dst_idx = 0; - for (int sr_idx = 0; sr_idx < sr; sr_idx++) { - for (int n_idx = 0; n_idx < nr; n_idx++) { - // Take kr_per_sr values from column n_idx - std::memcpy( - values + n_idx * kr + sr_idx * kr_per_sr, - packed_values + dst_idx, - sizeof(T) * kr_per_sr); - dst_idx += kr_per_sr; - } - } -} - // Size in bytes of 1 packed weights column size_t inline packed_weights_size_per_n( int k, @@ -344,7 +290,7 @@ TORCHAO_ALWAYS_INLINE inline void pack_weights_impl( } // Pack buffer - internal::pack_values(packed_values, buffer.data(), nr, kr, sr); + torchao::packing::pack_values(packed_values, buffer.data(), nr, kr, sr); if constexpr (has_lut) { internal::pack_buffer_for_lut( packed_weights_byte_ptr, packed_values); @@ -498,7 +444,7 @@ void unpack_weights_at_n_idx( internal::unpack_buffer( packed_values, packed_weights_byte_ptr); packed_weights_byte_ptr += packed_buffer_bytes; - internal::unpack_values(buffer.data(), packed_values, nr, kr, sr); + torchao::packing::unpack_values(buffer.data(), packed_values, nr, kr, sr); // Write weight_qvals for (int j = 0; j < nr; j++) { diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h new file mode 100644 index 0000000000..b0fea65afb --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h @@ -0,0 +1,263 @@ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include + +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut { + +/** + * @brief Calculates the total size in bytes required for the packed weight. + * + * @param m The number of rows in the source activation matrix. + * @param k The number of columns in the source activation matrix. + * @param mr The row-tiling factor of the micro-kernel. + * @param kr The column-tiling factor of the micro-kernel. + * @param sr The split ratio of the micro-kernel. + */ +inline size_t packed_activations_size(int m, int k, int mr, int kr, int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused + return activation_packing::packed_activations_size(m, k); +} + +/** + * @brief Packs a row-major activation matrix into a kernel-optimized blocked +layout. + * + * @tparam mr_ The row-tiling factor of the micro-kernel (Currently only have +1). + * @tparam kr_ The column-tiling factor of the micro-kernel (e.g., 32). + * @tparam sr_ Split ratio determine how the k dimension of a weight tile is +chunked and interleaved during the packing process. + * @param output Pointer to the destination buffer. + * @param m The number of rows in the source activation matrix. + * @param k The number of columns in the source activation matrix. + * @param input Pointer to the source activation matrix (float32, row-major). + */ +template +inline void pack_activations( + float* output, + int m, + int k, + const float* input, + int mr, + int kr, + int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused { + activation_packing::pack_activations(output, m, k, input); +} + +/** + * @brief Calculates the total size in bytes required for the packed weight + * buffer for the groupwise LUT kernel format. + * + * @param n The number of columns in the weight matrix. + * @param k The number of rows in the weight matrix. + * @param weight_nbit The number of bits per weight (e.g., 2, 3, 4). + * @param scale_group_size The number of weights along the K dim that share a + * scale factor. + * @param has_scales If true, the packed buffer will contain scale factors. + * @param has_bias If true, the packed buffer will contain bias terms. + * @param nr The column-tiling factor for the kernel (e.g., 16). + * @param kr The column-tiling factor for the kernel (e.g., 16). + * @param sr The split ratio of the micro-kernel. + * @return The total required size of the packed buffer in bytes. + */ +inline size_t packed_weights_size( + int n, + int k, + int weight_nbit, + int scale_group_size, + bool has_scales, + bool has_bias, + int nr, + int kr, + int sr) { + (void)sr; // unused + return weight_packing::packed_weights_size( + n, k, weight_nbit, scale_group_size, has_scales, has_bias, nr, kr); +} + +/** + * @brief Packs weights, LUTs, scales and bias into a kernel-optimized format. + * @tparam weight_nbit_ The true bit-width of the weights. + * @tparam nr_ The column-tiling factor for the kernel (e.g., 4). + * @tparam kr_ The column-tiling factor of the micro-kernel (e.g., 32). + * @tparam sr_ Split ratio determine how the k dimension of a weight tile is +chunked and interleaved during the packing process. + * @param packed_weights_ptr Pointer to the destination buffer. + * @param weight_qvals_indices Pointer to the quantized weight matrix (uint8, +row-major). + * @param weight_scales Pointer to the scale factors (float32, row-major). + * @param weight_luts Pointer to the LUTs (float32, row-major). + * @param n The number of columns in the weight matrix. + * @param k The number of rows in the weight matrix. + * @param scale_group_size The number of weights that share a scale factor. + * @param lut_group_size The number of weights that share a LUT. + * @param has_scales If true, the packed buffer will contain scale factors. + * @param has_bias If true, the packed buffer will contain bias terms. + * @param bias Pointer to the bias vector (float32, row-major). + */ +template +void pack_weights( + /*output*/ + void* packed_weights_ptr, + /*inputs*/ + const uint8_t* weight_qvals_indices, + const float* weight_scales, + const float* weight_luts, + int n, + int k, + int scale_group_size, + int lut_group_size, + bool has_scales, + bool has_bias, + const float* bias, + int nr, + int kr, + int sr) { + (void)nr; // unused + (void)kr; // unused + (void)sr; // unused + weight_packing::pack_weights( + packed_weights_ptr, + weight_qvals_indices, + weight_scales, + weight_luts, + n, + k, + scale_group_size, + lut_group_size, + has_scales, + has_bias, + bias); +} + +/** + * @brief Computes a group-wise low-bit GEMM using an optimized NEON kernel. + * + * This function selects the best available micro-kernel based on the provided + * tile sizes (MR and NR) and dispatches the computation. + * @tparam weight_nbit_ The true bit-width of the weights (e.g., 2, 3, 4). + * @tparam has_scales_ If true, applies the scales. + * @param output Pointer to the output matrix C. + * @param output_m_stride The stride (in elements) between rows of the output + * matrix. + * @param m Number of rows in A and C. + * @param n Number of columns in B and C. + * @param k Number of columns in A and rows in B. + * @param scale_group_size The grouping factor for scales. + * @param lut_group_size The grouping factor for LUTs. + * @param packed_weights Pointer to the pre-packed weight buffer. + * @param packed_activations Pointer to the pre-packed activation buffer. + * @param biases Pointer to the bias vector. + * @param clamp_min Minimum value for the fused clamp (ReLU) operation. + * @param clamp_max Maximum value for the fused clamp (ReLU6) operation. + * @param has_bias If true, applies the bias. + * @param has_clamp If true, applies the clamping. + */ +template +inline void groupwise_lowbit_weight_lut_kernel_1x4x32( + float* output, + int output_m_stride, + int m, + int n, + int k, + int scale_group_size, + int lut_group_size, + const void* packed_weights, + const void* packed_activations, + float clamp_min, + float clamp_max, + bool has_bias, + bool has_clamp) { + kernel::groupwise_lowbit_weight_lut_kernel_1x4x32( + output, + output_m_stride, + m, + n, + k, + scale_group_size, + lut_group_size, + packed_weights, + packed_activations, + clamp_min, + clamp_max, + has_bias, + has_clamp); +} + +/** + * @brief Calculates the byte offset for a specific row in the packed activation + * buffer. + * + * @param m_idx The row index for which to calculate the offset. + * @param k The K dimension (width) of the activation matrix. + * @return The byte offset from the start of the buffer. + */ +inline size_t +packed_activations_offset(int m_idx, int k, int mr, int kr, int sr) { + (void)mr; // unused + (void)kr; // unused + (void)sr; // unused + // For a simple padded row-major format, the offset is just m_idx * k. + return sizeof(float) * m_idx * k; +} + +/** + * @brief Calculates the byte offset for a given column index in the packed + * weights buffer. The buffer is assumed to be laid out as a series of + * contiguous blocks, where each block contains `nr` packed columns. + * + * @param n_idx The starting column index of the tile. Must be a multiple of + * `nr`. + * @param k The inner dimension of the matrix. + * @param weight_nbit The number of bits for the quantized weights. + * @param has_scales Whether weight scales are present. + * @param has_bias Whether a bias vector is packed. + * @param nr The micro-kernel tiling parameter for the N dimension. + * @param kr The micro-kernel tiling parameter for the K dimension. + * @return The byte offset into the packed weights buffer. + */ +inline size_t packed_weights_offset( + int n_idx, + int k, + int weight_nbit, + int scale_group_size, + bool has_scales, + bool has_bias, + int nr, + int kr, + int sr) { + (void)sr; // unused + assert(n_idx % nr == 0); + + const size_t packed_tile_size_for_nr_cols = packed_weights_size( + /*n=*/nr, // The size we are calculating is for a single tile of width + // `nr`. + k, + weight_nbit, + scale_group_size, + has_scales, + has_bias, + nr, + kr, + sr); + + return (n_idx / nr) * packed_tile_size_for_nr_cols; +} +} // namespace + // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h new file mode 100644 index 0000000000..b50c886d11 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h @@ -0,0 +1,239 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. +#pragma once + +#if defined(aarch64) || defined(__ARM_NEON) +#include +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut:: + kernel { + +namespace lut_utils = torchao::lut; +namespace weight_packing = torchao::kernels::cpu::aarch64::linear:: + groupwise_lowbit_weight_lut::weight_packing; + +namespace internal { + +/* + * @brief Computes a single tile of the output matrix. + * @tparam weight_nbit_ The bit-precision of the quantized weight indices. + * @tparam has_scales A compile-time flag to enable the application of scales. + * + * @param accum A NEON vector of 4 floats used as an in-out accumulator. + * @param activation_tile_ptr Pointer to the 32-float activation tile. + * @param packed_indices_ptr Pointer to the bit-packed weight indices. + * @param lut_neon The dequantization LUT, pre-formatted for NEON lookups. + * @param scale_vec A NEON vector with the four dequantization scales. + */ +template +TORCHAO_ALWAYS_INLINE static inline void compute_tile_1x4x32( + float32x4_t& accum, + const float* __restrict__ activation_tile_ptr, + const uint8_t* __restrict__ packed_indices_ptr, + const uint8x16x4_t& lut_neon, + const float32x4_t scale_vec) { + // 1. Unpack indices + uint8x16_t idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7; + bitpacking::vec_unpack_128_uintx_values( + idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7, packed_indices_ptr); + + const std::array unpacked_indices = { + idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7}; + + for (int sr_idx = 0; sr_idx < 8; ++sr_idx) { + // Load the 4 activations corresponding to this chunk + const float* activation_chunk_ptr = activation_tile_ptr + sr_idx * 4; + float32x4_t a = vld1q_f32(activation_chunk_ptr); + + // Lookup the 4x4 weight sub-tile (as columns) + float32x4_t w_col0, w_col1, w_col2, w_col3; + lut_utils::lookup_from_fp32_lut( + w_col0, w_col1, w_col2, w_col3, lut_neon, unpacked_indices[sr_idx]); + + float32x4x2_t tmp0 = vtrnq_f32(w_col0, w_col1); + float32x4x2_t tmp1 = vtrnq_f32(w_col2, w_col3); + float32x4_t w_row0 = + vcombine_f32(vget_low_f32(tmp0.val[0]), vget_low_f32(tmp1.val[0])); + float32x4_t w_row1 = + vcombine_f32(vget_low_f32(tmp0.val[1]), vget_low_f32(tmp1.val[1])); + float32x4_t w_row2 = + vcombine_f32(vget_high_f32(tmp0.val[0]), vget_high_f32(tmp1.val[0])); + float32x4_t w_row3 = + vcombine_f32(vget_high_f32(tmp0.val[1]), vget_high_f32(tmp1.val[1])); + + // Conditionally apply scales at compile time + if constexpr (has_scales) { + w_row0 = vmulq_f32(w_row0, scale_vec); + w_row1 = vmulq_f32(w_row1, scale_vec); + w_row2 = vmulq_f32(w_row2, scale_vec); + w_row3 = vmulq_f32(w_row3, scale_vec); + } + + // Use vfmaq_n_f32 to multiply each row vector by the corresponding scalar + // activation. + accum = vfmaq_n_f32( + accum, w_row0, vgetq_lane_f32(a, 0)); // accum += w_row0 * a[0] + accum = vfmaq_n_f32( + accum, w_row1, vgetq_lane_f32(a, 1)); // accum += w_row1 * a[1] + accum = vfmaq_n_f32( + accum, w_row2, vgetq_lane_f32(a, 2)); // accum += w_row2 * a[2] + accum = vfmaq_n_f32( + accum, w_row3, vgetq_lane_f32(a, 3)); // accum += w_row3 * a[3] + } +} + +/** + * @brief Stores the accumulated values to the output matrix. + * @tparam mr_ The row-tiling factor of the micro-kernel. + * @tparam nr_ The column-tiling factor of the micro-kernel. + * + * @param output The output matrix. + * @param ldc The leading dimension of the output matrix. + * @param n_cols The number of columns in the output matrix. + * @param n_tile_start The starting column index of the current tile. + * @param accum The accumulated values. + * @param bias_ptr The pointer to the bias vector. + * @param has_clamp Whether to apply clamping. + * @param clamp_min_vec The minimum value for clamping. + * @param clamp_max_vec The maximum value for clamping. + */ +template +TORCHAO_ALWAYS_INLINE static inline void post_process_and_store( + float* __restrict__ output, + int ldc, + int n_cols, + int n_tile_start, + const float32x4_t accum[mr_][nr_ / 4], + const float* __restrict__ bias_ptr, + bool has_clamp, + const float32x4_t& clamp_min_vec, + const float32x4_t& clamp_max_vec) { + constexpr int NR_VEC = nr_ / 4; + for (int m = 0; m < mr_; ++m) { + float* out_row = output + m * ldc; + for (int nb = 0; nb < NR_VEC; ++nb) { + float32x4_t res = accum[m][nb]; + if (bias_ptr != nullptr) { + float32x4_t bias_vec = vld1q_f32(bias_ptr + nb * 4); + res = vaddq_f32(res, bias_vec); + } + if (has_clamp) { + res = vmaxq_f32(res, clamp_min_vec); + res = vminq_f32(res, clamp_max_vec); + } + + const int current_n_offset = n_tile_start + nb * 4; + const int remaining_cols = n_cols - current_n_offset; + if (remaining_cols < 4) { + float temp_res[4]; + vst1q_f32(temp_res, res); + for (int i = 0; i < remaining_cols; ++i) { + *(out_row + current_n_offset + i) = temp_res[i]; + } + } else { + vst1q_f32(out_row + current_n_offset, res); + } + } + } +} + +} // namespace internal + +/* + * @brief The main kernel for groupwise low-bit weight LUT. + */ +template +void groupwise_lowbit_weight_lut_kernel_1x4x32( + float* output, + int output_m_stride, + int m, + int n, + int k, + int scale_group_size, + int lut_group_size, + const void* packed_weights, + const void* packed_activations, + float clamp_min, + float clamp_max, + bool has_bias, + bool has_clamp) { + constexpr int mr_ = 1; + constexpr int nr_ = 4; + constexpr int kr_ = 32; + + const auto* typed_activations_ptr = + static_cast(packed_activations); + const float32x4_t clamp_min_vec = vdupq_n_f32(clamp_min); + const float32x4_t clamp_max_vec = vdupq_n_f32(clamp_max); + constexpr int bytes_per_weight_tile = ((nr_ * kr_ * weight_nbit_) + 7) / 8; + + for (int m_tile_start = 0; m_tile_start < m; m_tile_start += mr_) { + const float* activation_row_ptr = typed_activations_ptr + m_tile_start * k; + const uint8_t* packed_ptr = static_cast(packed_weights); + + for (int n_tile_start = 0; n_tile_start < n; n_tile_start += nr_) { + float32x4_t accumulators[mr_][nr_ / 4] = {{vdupq_n_f32(0.0f)}}; + + uint8x16x4_t lut_neon; + // Load the 16-float LUT for this tile. + lut_utils::load_fp32_lut( + lut_neon, reinterpret_cast(packed_ptr)); + // Advance the pointer past the LUT. + packed_ptr += 16 * sizeof(float); + float32x4_t scale_vec = vdupq_n_f32(1.0f); + for (int k_tile_start = 0; k_tile_start < k; k_tile_start += kr_) { + if constexpr (has_scales) { + const float* scale_for_tile = nullptr; + + if (k_tile_start % scale_group_size == 0) { + scale_for_tile = reinterpret_cast(packed_ptr); + scale_vec = vld1q_f32(scale_for_tile); + packed_ptr += nr_ * sizeof(float); + } + } + + // The current packed_ptr points to the weight indices. + const uint8_t* indices_ptr = packed_ptr; + + internal::compute_tile_1x4x32( + accumulators[0][0], + activation_row_ptr + k_tile_start, + indices_ptr, + lut_neon, + scale_vec); + + // Advance pointer past the weights that were just used. + packed_ptr += bytes_per_weight_tile; + } + + const float* bias_for_tile = nullptr; + if (has_bias) { + bias_for_tile = reinterpret_cast(packed_ptr); + packed_ptr += nr_ * sizeof(float); + } + + float* output_row_ptr = output + m_tile_start * output_m_stride; + internal::post_process_and_store( + output_row_ptr, + output_m_stride, + n, + n_tile_start, + accumulators, + bias_for_tile, + has_clamp, + clamp_min_vec, + clamp_max_vec); + } + } +} +} // namespace + // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::kernel +#endif // defined(aarch64) || defined(__ARM_NEON) diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_activations.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_activations.h new file mode 100644 index 0000000000..bf16e04bda --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_activations.h @@ -0,0 +1,31 @@ +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut:: + activation_packing { + +inline size_t packed_activations_size(int m, int k) { + return m * k * sizeof(float); +} + +template +void pack_activations( + // Output + float* packed_activations, + // Inputs + int m, + int k, + const float* activations) { + static_assert(mr_ == 1); + std::memcpy(packed_activations, activations, sizeof(float) * m * k); +} +} // namespace + // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::activation_packing + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_weights.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_weights.h new file mode 100644 index 0000000000..021693caec --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_weights.h @@ -0,0 +1,228 @@ +#pragma once + +#if defined(aarch64) || defined(__ARM_NEON) +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut:: + weight_packing { +namespace lut_utils = torchao::lut; +namespace packing_utils = torchao::packing; + +/** + * @brief Calculates the exact buffer size in bytes for packed weights. + * + * This function computes the total memory required for a weight buffer based on + * a specific packing layout. The calculation accounts for tiled weights, a + * Look-Up Table (LUT), and optional interleaved scales and biases. It assumes + * the 'n' dimension is padded to be a multiple of the tile height 'nr'. + * + * @param n The number of output channels (columns) in the weight matrix. + * @param k The number of input channels (rows) in the weight matrix. + * @param weight_nbit The bit precision for each weight (e.g., 4, 8). + * @param scale_group_size The number of weights that share a single scale + * factor. + * @param has_scales Set to true to include space for scaling factors. + * @param has_bias Set to true to include space for a bias vector. + * @param nr The tile height used for packing along the 'n' dimension. + * @param kr The tile width used for packing along the 'k' dimension. + * @return The total required size in bytes for the complete packed buffer. + */ +inline size_t packed_weights_size( + int n, + int k, + int weight_nbit, + int scale_group_size, + bool has_scales, + bool has_bias, + int nr, + int kr) { + size_t size_per_n_strip = 0; + + // 1. Size of the LUT, written once per strip. + size_per_n_strip += 16 * sizeof(float); + + // 2. Size of the interleaved scales. + if (has_scales) { + assert( + k % scale_group_size == 0 && + "k must be a multiple of scale_group_size"); + size_t num_scale_blocks = k / scale_group_size; + size_per_n_strip += num_scale_blocks * nr * sizeof(float); + } + + // 3. Size of the packed weight tiles. + assert(k % kr == 0 && "k must be a multiple of kr"); + size_t num_k_tiles = k / kr; + size_t bytes_per_weight_tile = ((nr * kr * weight_nbit) + 7) / 8; + size_per_n_strip += num_k_tiles * bytes_per_weight_tile; + + // 4. Size of the bias, written once per strip. + if (has_bias) { + size_per_n_strip += nr * sizeof(float); + } + + // Calculate the total number of n-strips, padding n to a multiple of nr. + int num_n_strips = (n + nr - 1) / nr; + + return size_per_n_strip * num_n_strips; +} + +/** + * @brief Packs weights, LUTs, scales and bias into a kernel-optimized format. + * @details The function organizes the output buffer into "n-strips," where +each strip corresponds to a tile of `nr_` columns from the weight matrix. + * The memory layout for each strip is as follows: + * 1. **Look-Up Table (LUT):** A 16-element float LUT is written once at + * the beginning of the strip. + * 2. **Interleaved Scales:** If `has_scales` is true, dequantization + * scales are interleaved. For each group of `scale_group_size` + * elements along the k-dimension, `nr_` scale values (one for each + * column in the strip) are written. + * 3. **Packed Weight Tiles:** The core weight data is tiled into + * (`nr_` x `kr_`) blocks. These blocks are then bit-packed and + * interleaved according to the `sr_` ratio before being written. + * 4. **Bias:** If `has_bias` is true, `nr_` bias values are appended + * at the end of the strip. + * + * @tparam weight_nbit_ The true bit-width of the weights. + * @tparam nr_ The column-tiling factor for the kernel (e.g., 4). + * @tparam kr_ The column-tiling factor of the micro-kernel (e.g., 32). + * @tparam sr_ Split ratio determine how the k dimension of a weight tile is +chunked and interleaved during the packing process. + * @param packed_weights_ptr Pointer to the destination buffer. + * @param weight_qval_indices Pointer to the quantized weight matrix (uint8, +row-major). + * @param weight_scales Pointer to the scale factors (float32, row-major). + * @param weight_luts Pointer to the LUTs (float32, row-major). + * @param n The number of columns in the weight matrix. + * @param k The number of rows in the weight matrix. + * @param scale_group_size The number of weights that share a scale factor. + * @param lut_group_size The number of weights that share a LUT. + * @param has_scales If true, the packed buffer will contain scale factors. + * @param has_bias If true, the packed buffer will contain bias terms. + * @param bias Pointer to the bias vector (float32, row-major). + */ +template +TORCHAO_ALWAYS_INLINE inline void pack_weights( + // Output + void* packed_weights_ptr, + // Inputs + const uint8_t* weight_qval_indices, + const float* weight_scales, + const float* weight_luts, + int n, + int k, + int scale_group_size, + int lut_group_size, + bool has_scales, + bool has_bias, + const float* bias) { + static_assert(nr_ == 4); + static_assert(kr_ == 32); + static_assert(sr_ == 8); + static_assert(kr_ % sr_ == 0, "kr must be divisible by sr"); + assert(k % kr_ == 0 && "K must be a multiple of tile dimension kr"); + assert(scale_group_size > 0 && "Scale group size must be positive"); + assert(lut_group_size > 0 && "LUT group size must be positive"); + + // Grouping hierarchy constraint + assert( + lut_group_size % scale_group_size == 0 && + "LUT group size must be a multiple of scale group size"); + + // Group compatibility constraints with tile dimensions + assert( + lut_group_size % (k * nr_) == 0 && + "LUT group size must be compatible with tile dimensions"); + assert(scale_group_size % kr_ == 0 && "Scale group size % kr must be 0"); + + auto* out_ptr = reinterpret_cast(packed_weights_ptr); + constexpr int kLutBufferSize = 16; + std::vector lut_buffer(kLutBufferSize); + + std::vector padded_tile(nr_ * kr_); + + std::vector tmp_buffer(128); + constexpr int bytes_per_128_packed_values = + ((nr_ * kr_ * weight_nbit_) + 7) / 8; + + const int lut_size = 1 << weight_nbit_; + const int scales_per_col = k / scale_group_size; + + for (int n_idx = 0; n_idx < n; n_idx += nr_) { + int current_lut_idx = (n_idx * k) / lut_group_size; + + std::memset(lut_buffer.data(), 0, 16 * sizeof(float)); + std::memcpy(out_ptr, lut_buffer.data(), 16 * sizeof(float)); + + std::memcpy( + lut_buffer.data(), + weight_luts + current_lut_idx * lut_size, + lut_size * sizeof(float)); + std::memcpy(out_ptr, lut_buffer.data(), 16 * sizeof(float)); + out_ptr += 16 * sizeof(float); + + for (int k_idx = 0; k_idx < k; k_idx += kr_) { + int w_idx = n_idx * k + k_idx; + // Write scales if k_idx is a multiple of scale_group_size + if (has_scales && (k_idx % scale_group_size == 0)) { + int scale_idx = w_idx / scale_group_size; + // Write scales for next nr columns + for (int j = 0; j < nr_; j++) { + float scale = 0.0; + if (n_idx + j < n) { + scale = weight_scales[scale_idx + j * scales_per_col]; + } + std::memcpy(out_ptr, &scale, sizeof(float)); + out_ptr += sizeof(float); + } + } + // Write 128 packed tile (kr x nr) + std::memset(padded_tile.data(), 0, 128); + for (int j = 0; j < nr_; j++) { + if (n_idx + j < n) { + std::memcpy( + padded_tile.data() + j * kr_, + weight_qval_indices + w_idx + j * k, + kr_); + } + } + packing_utils::pack_values( + tmp_buffer.data(), padded_tile.data(), nr_, kr_, sr_); + const uint8_t* buffer = tmp_buffer.data(); + torchao::bitpacking::vec_pack_128_uintx_values( + reinterpret_cast(out_ptr), + vld1q_u8(buffer), + vld1q_u8(buffer + 16), + vld1q_u8(buffer + 32), + vld1q_u8(buffer + 48), + vld1q_u8(buffer + 64), + vld1q_u8(buffer + 80), + vld1q_u8(buffer + 96), + vld1q_u8(buffer + 112)); + out_ptr += bytes_per_128_packed_values; + } // k_idx + + if (has_bias) { + for (int i = 0; i < nr_; i++) { + float current_bias = 0.0; + if (n_idx + i < n) { + current_bias = bias[n_idx + i]; + } + std::memcpy(out_ptr, ¤t_bias, sizeof(float)); + out_ptr += sizeof(float); + } + } + } +} +} // namespace + // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::weight_packing +#endif // defined(aarch64) || defined(__ARM_NEON) diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/lut/lut.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/lut/lut.h new file mode 100644 index 0000000000..c8b76d979f --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/lut/lut.h @@ -0,0 +1,84 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +namespace torchao::lut { + +TORCHAO_ALWAYS_INLINE inline void load_fp32_lut(uint8x16x4_t& lut, const float* table) { + lut = { + vld1q_u8((const uint8_t*)&table[0]), + vld1q_u8((const uint8_t*)&table[4]), + vld1q_u8((const uint8_t*)&table[8]), + vld1q_u8((const uint8_t*)&table[12]) + }; +} + +// This function looks up float values from a 16-value LUT +// (stored as 16 consecutive floats loaded into uint8x16x4_t) +// The indices of the 16 values being looked up are contained in idx +// These values are output to out0, out1, out2, and out3 +TORCHAO_ALWAYS_INLINE inline void lookup_from_fp32_lut( + float32x4_t& out0, + float32x4_t& out1, + float32x4_t& out2, + float32x4_t& out3, + const uint8x16x4_t& lut, + const uint8x16_t idx +) { + // Performs a vectorized lookup of FP32 values from a 16-element float table. + // The input `idx` is a uint8x16_t vector containing 16 indices (0–15), + // each selecting a float from the LUT. Since each float is 4 bytes, we compute + // the byte offsets for each selected float: + // - `idx0` = idx * 4 (byte 0 of each float) + // - `idx1` = idx0 + 1 (byte 1) + // - `idx2` = idx0 + 2 (byte 2) + // - `idx3` = idx0 + 3 (byte 3) + // + // These are grouped into a 4-way NEON table `idx_tbl = {idx0, idx1, idx2, idx3}`. + // + // To reconstruct full FP32 values (4 bytes each) from the byte lookup, we use + // `vqtbl4q_u8(idx_tbl, ...)` with a special interleaving `offsets` vector: + // - `offsets = { 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51 }` + // + // This offset pattern selects the 4 bytes for float0 (0, 16, 32, 48), float1 (1, 17, 33, 49), etc. + // + // We repeat this with offset vectors incremented by 4 and 8 and 12 to produce + // `out1_idx`, `out2_idx`, and `out3_idx`, each forming the byte indices for + // the next group of 4 floats. + // + // Finally, we use `vqtbl4q_u8(lut, outN_idx)` to gather bytes from the original LUT, + // and `vreinterpretq_f32_u8(...)` to convert the byte-wise result into + // actual `float32x4_t` values: `out0`, `out1`, `out2`, and `out3` + + uint8x16_t idx0 = vshlq_n_u8(idx, 2); + uint8x16_t idx1 = vaddq_u8(idx0, vdupq_n_u8(1)); + uint8x16_t idx2 = vaddq_u8(idx0, vdupq_n_u8(2)); + uint8x16_t idx3 = vaddq_u8(idx0, vdupq_n_u8(3)); + + // 4-way interleave idx0, idx1, idx2, idx3 to create out0_idx, out1_idx, out2_idx, out3_idx + uint8x16x4_t idx_tbl = {idx0, idx1, idx2, idx3}; + uint8x16_t offsets = { 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51 }; + uint8x16_t out0_idx = vqtbl4q_u8(idx_tbl, offsets); + uint8x16_t out1_idx = vqtbl4q_u8(idx_tbl, vaddq_u8(offsets, vdupq_n_u8(4))); + uint8x16_t out2_idx = vqtbl4q_u8(idx_tbl, vaddq_u8(offsets, vdupq_n_u8(8))); + uint8x16_t out3_idx = vqtbl4q_u8(idx_tbl, vaddq_u8(offsets, vdupq_n_u8(12))); + + out0 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, out0_idx)); + out1 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, out1_idx)); + out2 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, out2_idx)); + out3 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, out3_idx)); +} + +} // namespace torchao::lut + + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h index 5ed3b686fd..925bbbb4bd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h index c976be39f5..2c34cebc3c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h index 19bde9dad9..80417f37e4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h index 4fc393fcaf..28f173e9bc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h index a3dd44a10b..ffcd0a1f1d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul.h similarity index 91% rename from torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul.h index 86b14a52aa..371dc55666 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul.h @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. // TODO: this file will be deleted and replaced by -// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h +// torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h // It exists now to prevent breaking existing code in the interim. #pragma once @@ -309,10 +309,10 @@ void kernel( } // namespace fp32_a_input_channelwise_8bit_b_f32 } // namespace torchao::kernels::cpu::aarch64::quantized_matmul -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul_utils.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul_utils.h index 0a3c8463a8..db577c39a8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul_utils.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include #include #include diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/packing/utils.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/packing/utils.h new file mode 100644 index 0000000000..32ee7000b9 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/packing/utils.h @@ -0,0 +1,67 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +namespace torchao::packing { + +// Packs nr * kr values for GEMM with packing params (nr, kr, sr) +// It takes (kr / sr) values from each of nr columns and writes to packed_values +// This is repeated sr times +template +void pack_values( + // Output + T* packed_values, + // Inputs + const T* values, + int nr, + int kr, + int sr) { + assert(kr % sr == 0); + int kr_per_sr = kr / sr; + int dst_idx = 0; + for (int sr_idx = 0; sr_idx < sr; sr_idx++) { + for (int n_idx = 0; n_idx < nr; n_idx++) { + // Take kr_per_sr values from column n_idx + std::memcpy( + packed_values + dst_idx, + values + n_idx * kr + sr_idx * kr_per_sr, + sizeof(T) * kr_per_sr); + dst_idx += kr_per_sr; + } + } +} + +// Undoes pack_values +template +void unpack_values( + // Output + T* values, + // Inputs + const T* packed_values, + int nr, + int kr, + int sr) { + // packed_values and values should have size nr * kr + // This function takes (kr / sr) from each column of nr columns and writes to + // output This is repeated sr times + assert(kr % sr == 0); + int kr_per_sr = kr / sr; + int dst_idx = 0; + for (int sr_idx = 0; sr_idx < sr; sr_idx++) { + for (int n_idx = 0; n_idx < nr; n_idx++) { + // Take kr_per_sr values from column n_idx + std::memcpy( + values + n_idx * kr + sr_idx * kr_per_sr, + packed_values + dst_idx, + sizeof(T) * kr_per_sr); + dst_idx += kr_per_sr; + } + } +} + +} // namespace torchao::packing diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp index 3460d67fba..42301dc2fa 100644 --- a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp @@ -6,7 +6,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.h diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp index 3a41307cb3..1b9d2aa97b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp @@ -6,7 +6,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum( diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp similarity index 93% rename from torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp index 89707eb0ac..ea4efcf1cc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp @@ -6,7 +6,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include void torchao::kernels::cpu::aarch64::reduction::find_min_and_max( diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/reduction.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/reduction.h diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/CMakeLists.txt new file mode 100644 index 0000000000..8d214b2e61 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/CMakeLists.txt @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_tests) + + # Delay test discovery till runtime. Useful for cross-compiling. +set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) + +set(TEST_TARGET_PREFIX "torchao_tests_torch_free_kernels_aarch64_") + +add_library( + ${TEST_TARGET_PREFIX}dep + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp +) + +enable_testing() + +add_executable(${TEST_TARGET_PREFIX}test_quantization test_quantization.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_quantization + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_reduction test_reduction.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_reduction + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_bitpacking test_bitpacking.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_bitpacking + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_linear test_linear.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep + torchao_kernels_aarch64 +) + +add_executable(${TEST_TARGET_PREFIX}test_embedding_lut test_embedding_lut.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_embedding_lut + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_embedding test_embedding.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_embedding + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_weight_packing test_weight_packing.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_weight_packing + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_qmatmul test_qmatmul.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_qmatmul + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_lut test_lut.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_lut + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility test_bitpack_fallback_compatibility.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +include(GoogleTest) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_quantization) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_reduction) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_bitpacking) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_linear) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_embedding) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_embedding_lut) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_weight_packing) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_qmatmul) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_lut) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility) diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpack_fallback_compatibility.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpack_fallback_compatibility.cpp new file mode 100644 index 0000000000..ccae74cbcd --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpack_fallback_compatibility.cpp @@ -0,0 +1,686 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include + +#include +#include +#include + +// --- Compatibility Tests for uint1 --- + +TEST(test_bitpacking_64_uint1_values, CppToNeon) { + int unpacked_bytes = 64; + int nbit = 1; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint1_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3; + torchao::bitpacking::internal::vec_unpack_64_uint1_values( + u0, u1, u2, u3, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint1_values, NeonToCpp) { + int unpacked_bytes = 64; + int nbit = 1; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint1_values( + packed.data(), i0, i1, i2, i3); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_64_uint1_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint1_values, CppToNeon) { + int unpacked_bytes = 128; + int nbit = 1; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_128_uint1_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3, u4, u5, u6, u7; + torchao::bitpacking::internal::vec_unpack_128_uint1_values( + u0, u1, u2, u3, u4, u5, u6, u7, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + vst1q_u8(unpacked.data() + 64, u4); + vst1q_u8(unpacked.data() + 80, u5); + vst1q_u8(unpacked.data() + 96, u6); + vst1q_u8(unpacked.data() + 112, u7); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint1_values, NeonToCpp) { + int unpacked_bytes = 128; + int nbit = 1; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3, i4, i5, i6, i7; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + i4, i5, i6, i7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint1_values( + packed.data(), i0, i1, i2, i3, i4, i5, i6, i7); + + torchao::kernels::cpu::fallback::bitpacking::internal:: + unpack_128_uint1_values(unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +// --- Compatibility Tests for uint2 --- + +TEST(test_bitpacking_32_uint2_values, CppToNeon) { + int unpacked_bytes = 32; + int nbit = 2; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint2_values( + packed.data(), input.data()); + + uint8x8_t u0, u1, u2, u3; + torchao::bitpacking::internal::vec_unpack_32_uint2_values( + u0, u1, u2, u3, packed.data()); + vst1_u8(unpacked.data(), u0); + vst1_u8(unpacked.data() + 8, u1); + vst1_u8(unpacked.data() + 16, u2); + vst1_u8(unpacked.data() + 24, u3); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_32_uint2_values, NeonToCpp) { + int unpacked_bytes = 32; + int nbit = 2; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x8_t i0, i1, i2, i3; + torchao::bitpacking::internal::vec_load_32_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_pack_32_uint2_values( + packed.data(), i0, i1, i2, i3); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_32_uint2_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint2_values, CppToNeon) { + int unpacked_bytes = 64; + int nbit = 2; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint2_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3; + torchao::bitpacking::internal::vec_unpack_64_uint2_values( + u0, u1, u2, u3, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint2_values, NeonToCpp) { + int unpacked_bytes = 64; + int nbit = 2; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint2_values( + packed.data(), i0, i1, i2, i3); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_64_uint2_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +// --- Compatibility Tests for uint3 --- + +TEST(test_bitpacking_64_uint3_values, CppToNeon) { + int unpacked_bytes = 64; + int nbit = 3; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint3_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3; + torchao::bitpacking::internal::vec_unpack_64_uint3_values( + u0, u1, u2, u3, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint3_values, NeonToCpp) { + int unpacked_bytes = 64; + int nbit = 3; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint3_values( + packed.data(), i0, i1, i2, i3); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_64_uint3_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint3_values, CppToNeon) { + int unpacked_bytes = 128; + int nbit = 3; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_128_uint3_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3, u4, u5, u6, u7; + torchao::bitpacking::internal::vec_unpack_128_uint3_values( + u0, u1, u2, u3, u4, u5, u6, u7, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + vst1q_u8(unpacked.data() + 64, u4); + vst1q_u8(unpacked.data() + 80, u5); + vst1q_u8(unpacked.data() + 96, u6); + vst1q_u8(unpacked.data() + 112, u7); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint3_values, NeonToCpp) { + int unpacked_bytes = 128; + int nbit = 3; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3, i4, i5, i6, i7; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + i4, i5, i6, i7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint3_values( + packed.data(), i0, i1, i2, i3, i4, i5, i6, i7); + + torchao::kernels::cpu::fallback::bitpacking::internal:: + unpack_128_uint3_values(unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +// --- Compatibility Tests for uint4 --- + +TEST(test_bitpacking_16_uint4_values, CppToNeon) { + int unpacked_bytes = 16; + int nbit = 4; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_16_uint4_values( + packed.data(), input.data()); + + uint8x16_t unpacked0; + torchao::bitpacking::internal::vec_unpack_16_uint4_values( + unpacked0, packed.data()); + vst1q_u8(unpacked.data(), unpacked0); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_16_uint4_values, NeonToCpp) { + int unpacked_bytes = 16; + int nbit = 4; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t input0 = vld1q_u8(input.data()); + torchao::bitpacking::internal::vec_pack_16_uint4_values( + packed.data(), input0); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_16_uint4_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_32_uint4_values, CppToNeon) { + int unpacked_bytes = 32; + int nbit = 4; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint4_values( + packed.data(), input.data()); + + uint8x16_t unpacked0, unpacked1; + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + unpacked0, unpacked1, packed.data()); + vst1q_u8(unpacked.data(), unpacked0); + vst1q_u8(unpacked.data() + 16, unpacked1); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_32_uint4_values, NeonToCpp) { + int unpacked_bytes = 32; + int nbit = 4; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t input0 = vld1q_u8(input.data()); + uint8x16_t input1 = vld1q_u8(input.data() + 16); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed.data(), input0, input1); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_32_uint4_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +// --- Compatibility Tests for uint5 --- + +TEST(test_bitpacking_64_uint5_values, CppToNeon) { + int unpacked_bytes = 64; + int nbit = 5; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint5_values( + packed.data(), input.data()); + + uint8x16_t unpacked0, unpacked1, unpacked2, unpacked3; + torchao::bitpacking::internal::vec_unpack_64_uint5_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + vst1q_u8(unpacked.data(), unpacked0); + vst1q_u8(unpacked.data() + 16, unpacked1); + vst1q_u8(unpacked.data() + 32, unpacked2); + vst1q_u8(unpacked.data() + 48, unpacked3); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint5_values, NeonToCpp) { + int unpacked_bytes = 64; + int nbit = 5; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t input0, input1, input2, input3; + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint5_values( + packed.data(), input0, input1, input2, input3); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_64_uint5_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint5_values, CppToNeon) { + int unpacked_bytes = 128; + int nbit = 5; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_128_uint5_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3, u4, u5, u6, u7; + torchao::bitpacking::internal::vec_unpack_128_uint5_values( + u0, u1, u2, u3, u4, u5, u6, u7, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + vst1q_u8(unpacked.data() + 64, u4); + vst1q_u8(unpacked.data() + 80, u5); + vst1q_u8(unpacked.data() + 96, u6); + vst1q_u8(unpacked.data() + 112, u7); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint5_values, NeonToCpp) { + int unpacked_bytes = 128; + int nbit = 5; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3, i4, i5, i6, i7; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + i4, i5, i6, i7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint5_values( + packed.data(), i0, i1, i2, i3, i4, i5, i6, i7); + + torchao::kernels::cpu::fallback::bitpacking::internal:: + unpack_128_uint5_values(unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +// --- Compatibility Tests for uint6 --- + +TEST(test_bitpacking_32_uint6_values, CppToNeon) { + int unpacked_bytes = 32; + int nbit = 6; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint6_values( + packed.data(), input.data()); + + uint8x16_t u0, u1; + torchao::bitpacking::internal::vec_unpack_32_uint6_values( + u0, u1, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_32_uint6_values, NeonToCpp) { + int unpacked_bytes = 32; + int nbit = 6; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0 = vld1q_u8(input.data()); + uint8x16_t i1 = vld1q_u8(input.data() + 16); + torchao::bitpacking::internal::vec_pack_32_uint6_values( + packed.data(), i0, i1); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_32_uint6_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint6_values, CppToNeon) { + int unpacked_bytes = 64; + int nbit = 6; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint6_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3; + torchao::bitpacking::internal::vec_unpack_64_uint6_values( + u0, u1, u2, u3, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint6_values, NeonToCpp) { + int unpacked_bytes = 64; + int nbit = 6; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint6_values( + packed.data(), i0, i1, i2, i3); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_64_uint6_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +// --- Compatibility Tests for uint7 --- + +TEST(test_bitpacking_64_uint7_values, CppToNeon) { + int unpacked_bytes = 64; + int nbit = 7; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint7_values( + packed.data(), input.data()); + + uint8x16_t unpacked0, unpacked1, unpacked2, unpacked3; + torchao::bitpacking::internal::vec_unpack_64_uint7_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + vst1q_u8(unpacked.data(), unpacked0); + vst1q_u8(unpacked.data() + 16, unpacked1); + vst1q_u8(unpacked.data() + 32, unpacked2); + vst1q_u8(unpacked.data() + 48, unpacked3); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint7_values, NeonToCpp) { + int unpacked_bytes = 64; + int nbit = 7; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t input0, input1, input2, input3; + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint7_values( + packed.data(), input0, input1, input2, input3); + + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_64_uint7_values( + unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint7_values, CppToNeon) { + int unpacked_bytes = 128; + int nbit = 7; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_128_uint7_values( + packed.data(), input.data()); + + uint8x16_t u0, u1, u2, u3, u4, u5, u6, u7; + torchao::bitpacking::internal::vec_unpack_128_uint7_values( + u0, u1, u2, u3, u4, u5, u6, u7, packed.data()); + vst1q_u8(unpacked.data(), u0); + vst1q_u8(unpacked.data() + 16, u1); + vst1q_u8(unpacked.data() + 32, u2); + vst1q_u8(unpacked.data() + 48, u3); + vst1q_u8(unpacked.data() + 64, u4); + vst1q_u8(unpacked.data() + 80, u5); + vst1q_u8(unpacked.data() + 96, u6); + vst1q_u8(unpacked.data() + 112, u7); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_128_uint7_values, NeonToCpp) { + int unpacked_bytes = 128; + int nbit = 7; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + uint8x16_t i0, i1, i2, i3, i4, i5, i6, i7; + torchao::bitpacking::internal::vec_load_64_uint8_values( + i0, i1, i2, i3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + i4, i5, i6, i7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint7_values( + packed.data(), i0, i1, i2, i3, i4, i5, i6, i7); + + torchao::kernels::cpu::fallback::bitpacking::internal:: + unpack_128_uint7_values(unpacked.data(), packed.data()); + + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpacking.cpp similarity index 83% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpacking.cpp index 7e7ccaea26..d052ae1d47 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpacking.cpp @@ -8,15 +8,15 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include TEST(test_bitpacking_8_uint1_values, PackUnpackAreSame) { @@ -209,21 +209,7 @@ TEST(test_bitpacking_64_uint2_values, PackUnpackAreSame) { } } -TEST(test_bitpacking_8_uint3_values, PackUnpackAreSame) { - int unpacked_bytes = 8; - int packed_bytes = 3; - auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 3); - std::vector packed(packed_bytes, 0); - std::vector unpacked(unpacked_bytes, 0); - torchao::bitpacking::internal::pack_8_uint3_values( - packed.data(), input.data()); - torchao::bitpacking::internal::unpack_8_uint3_values( - unpacked.data(), packed.data()); - for (int i = 0; i < unpacked_bytes; ++i) { - EXPECT_EQ(input[i], unpacked[i]); - } -} TEST(test_bitpacking_64_uint3_values, PackUnpackAreSame) { int unpacked_bytes = 64; @@ -921,4 +907,134 @@ TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(2); TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(3); TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(4); + +template +void test_vec_uintx_packing_unpacking_32() { + constexpr int unpacked_values = 32; + constexpr int packed_bytes = unpacked_values * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_values, nbit); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0 = vld1q_u8(input.data()); + uint8x16_t input1 = vld1q_u8(input.data() + 16); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + + torchao::bitpacking::vec_pack_32_uintx_values(packed.data(), input0, input1); + torchao::bitpacking::vec_unpack_32_uintx_values(unpacked0, unpacked1, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + } +} + +template +void test_vec_uintx_packing_unpacking_64() { + constexpr int unpacked_values = 64; + constexpr int packed_bytes = unpacked_values * nbit / 8; + + auto input = torchao::get_random_lowbit_vector(unpacked_values, nbit); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0 = vld1q_u8(input.data()); + uint8x16_t input1 = vld1q_u8(input.data() + 16); + uint8x16_t input2 = vld1q_u8(input.data() + 32); + uint8x16_t input3 = vld1q_u8(input.data() + 48); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::vec_pack_64_uintx_values(packed.data(), input0, input1, input2, input3); + torchao::bitpacking::vec_unpack_64_uintx_values(unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + +template +void test_vec_uintx_packing_unpacking_128() { + constexpr int unpacked_values = 128; + constexpr int packed_bytes = unpacked_values * nbit / 8; + + auto input = torchao::get_random_lowbit_vector(unpacked_values, nbit); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0 = vld1q_u8(input.data()); + uint8x16_t input1 = vld1q_u8(input.data() + 16); + uint8x16_t input2 = vld1q_u8(input.data() + 32); + uint8x16_t input3 = vld1q_u8(input.data() + 48); + uint8x16_t input4 = vld1q_u8(input.data() + 64); + uint8x16_t input5 = vld1q_u8(input.data() + 80); + uint8x16_t input6 = vld1q_u8(input.data() + 96); + uint8x16_t input7 = vld1q_u8(input.data() + 112); + + uint8x16_t unpacked0, unpacked1, unpacked2, unpacked3; + uint8x16_t unpacked4, unpacked5, unpacked6, unpacked7; + + torchao::bitpacking::vec_pack_128_uintx_values( + packed.data(), input0, input1, input2, input3, input4, input5, input6, input7); + torchao::bitpacking::vec_unpack_128_uintx_values( + unpacked0, unpacked1, unpacked2, unpacked3, unpacked4, unpacked5, unpacked6, unpacked7, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + EXPECT_EQ(input4[i], unpacked4[i]); + EXPECT_EQ(input5[i], unpacked5[i]); + EXPECT_EQ(input6[i], unpacked6[i]); + EXPECT_EQ(input7[i], unpacked7[i]); + } +} + +#define TEST_UINTX_PACKING_UNPACKING_32(nbit) \ + TEST(test_vec_uintx_packing_unpacking_32_##nbit, RoundtripIsCorrect) { \ + test_vec_uintx_packing_unpacking_32(); \ + } + +#define TEST_UINTX_PACKING_UNPACKING_64(nbit) \ + TEST(test_vec_uintx_packing_unpacking_64_##nbit, RoundtripIsCorrect) { \ + test_vec_uintx_packing_unpacking_64(); \ + } + +#define TEST_UINTX_PACKING_UNPACKING_128(nbit) \ + TEST(test_vec_uintx_packing_unpacking_128_##nbit, RoundtripIsCorrect) { \ + test_vec_uintx_packing_unpacking_128(); \ + } + +TEST_UINTX_PACKING_UNPACKING_32(1); +TEST_UINTX_PACKING_UNPACKING_32(2); +TEST_UINTX_PACKING_UNPACKING_32(3); +TEST_UINTX_PACKING_UNPACKING_32(4); +TEST_UINTX_PACKING_UNPACKING_32(5); +TEST_UINTX_PACKING_UNPACKING_32(6); +TEST_UINTX_PACKING_UNPACKING_32(7); +TEST_UINTX_PACKING_UNPACKING_32(8); + +TEST_UINTX_PACKING_UNPACKING_64(1); +TEST_UINTX_PACKING_UNPACKING_64(2); +TEST_UINTX_PACKING_UNPACKING_64(3); +TEST_UINTX_PACKING_UNPACKING_64(4); +TEST_UINTX_PACKING_UNPACKING_64(5); +TEST_UINTX_PACKING_UNPACKING_64(6); +TEST_UINTX_PACKING_UNPACKING_64(7); +TEST_UINTX_PACKING_UNPACKING_64(8); + +TEST_UINTX_PACKING_UNPACKING_128(1); +TEST_UINTX_PACKING_UNPACKING_128(2); +TEST_UINTX_PACKING_UNPACKING_128(3); +TEST_UINTX_PACKING_UNPACKING_128(4); +TEST_UINTX_PACKING_UNPACKING_128(5); +TEST_UINTX_PACKING_UNPACKING_128(6); +TEST_UINTX_PACKING_UNPACKING_128(7); +TEST_UINTX_PACKING_UNPACKING_128(8); #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding.cpp similarity index 96% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding.cpp index 8fe7e69574..e5cdfb0a1b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding.cpp @@ -7,9 +7,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include +#include +#include +#include #include float kTol = 0.0001; diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding_lut.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding_lut.cpp new file mode 100644 index 0000000000..5802a179d0 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding_lut.cpp @@ -0,0 +1,135 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include + +float kTol = 0.0001; + +template +void test_embedding( + int num_embeddings, + int embedding_dim, + int scale_group_size, + int lut_group_size, + bool has_scales) { + auto test_case = torchao::lut_embedding_test_case::generate( + num_embeddings, + embedding_dim, + scale_group_size, + lut_group_size, + has_scales); + + const size_t packed_embedding_size = + torchao::kernels::cpu::aarch64::embedding::packed_embedding_size( + weight_nbit, + num_embeddings, + embedding_dim, + scale_group_size, + lut_group_size, + has_scales); + + auto packed = std::vector(packed_embedding_size, 0); + auto output = std::vector(num_embeddings * embedding_dim, 0.0); + + for (int i = 0; i < num_embeddings; i++) { + torchao::kernels::cpu::aarch64::embedding::pack_embedding_row_at_index_lut< + weight_nbit>( + packed.data(), + i, + test_case.weight_qval_idxs.data(), + test_case.weight_scales.data(), + test_case.weight_luts.data(), + num_embeddings, + embedding_dim, + scale_group_size, + lut_group_size, + has_scales); + } + + for (int i = 0; i < num_embeddings; i++) { + torchao::kernels::cpu::aarch64::embedding:: + dequantize_embedding_row_at_idx_lut( + output.data() + i * embedding_dim, + packed.data(), + i, + num_embeddings, + embedding_dim, + scale_group_size, + lut_group_size, + has_scales); + } + + for (int i = 0; i < num_embeddings * embedding_dim; i++) { + EXPECT_NEAR(output[i], test_case.expected_outputs[i], kTol); + } +} + +struct LutEmbeddingBaseParams { + int num_embeddings; + int embedding_dim; + int scale_group_size; + int lut_group_size; + bool has_scales; +}; + +class LutEmbeddingParamTest + : public ::testing::TestWithParam> { + protected: + // run_test now correctly accepts the base parameters + template + void run_test(const LutEmbeddingBaseParams& params) { + test_embedding( + params.num_embeddings, + params.embedding_dim, + params.scale_group_size, + params.lut_group_size, + params.has_scales); + }; +}; + +TEST_P(LutEmbeddingParamTest, PackDequantizeEndToEnd) { + const auto& base_params = std::get<0>(GetParam()); + const int weight_nbit = std::get<1>(GetParam()); + + switch (weight_nbit) { + case 4: + run_test<4>(base_params); + break; + case 3: + run_test<3>(base_params); + break; + case 2: + run_test<2>(base_params); + break; + case 1: + run_test<1>(base_params); + break; + default: + FAIL() << "Unsupported weight_nbit: " << weight_nbit; + } +} + +INSTANTIATE_TEST_SUITE_P( + LutEmbeddingParamSweep, + LutEmbeddingParamTest, + ::testing::Combine( + ::testing::Values( + LutEmbeddingBaseParams{8, 128, 64, 32, true}, + LutEmbeddingBaseParams{8, 128, 32, 32, true}, + LutEmbeddingBaseParams{4, 256, 128, 64, false}, + LutEmbeddingBaseParams{1, 64, 64, 64, true}, + LutEmbeddingBaseParams{16, 512, 64, 32, true}, + LutEmbeddingBaseParams{3, 96, 32, 32, true}, + LutEmbeddingBaseParams{8, 128, 64, 128, true}, + LutEmbeddingBaseParams{8, 128, 64, 256, true}, + LutEmbeddingBaseParams{8, 128, 64, 512, true}), + ::testing::Values(1, 2, 3, 4))); +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_linear.cpp similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_linear.cpp index 6d6101e3cf..bf99823052 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_linear.cpp @@ -10,9 +10,9 @@ #include #include -#include -#include -#include +#include +#include +#include float kTol = 0.0001; diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_lut.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_lut.cpp new file mode 100644 index 0000000000..6d9214eeba --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_lut.cpp @@ -0,0 +1,686 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#if defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include +#include +#include +#include +#include + +namespace lut_utils = torchao::lut; +namespace kernel_api = + torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut; + +TEST(test_fp32_lut, LutLookup) { + auto lut = torchao::get_random_vector(16, -1.0, 1.0); + auto idx = torchao::get_random_lowbit_vector(16, 4); + + uint8x16_t idx_vec = vld1q_u8(idx.data()); + uint8x16x4_t lut_vec; + torchao::lut::load_fp32_lut(lut_vec, lut.data()); + + float32x4_t out0, out1, out2, out3; + torchao::lut::lookup_from_fp32_lut(out0, out1, out2, out3, lut_vec, idx_vec); + + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(out0[i], lut[idx[i]]); + EXPECT_EQ(out1[i], lut[idx[i + 4]]); + EXPECT_EQ(out2[i], lut[idx[i + 8]]); + EXPECT_EQ(out3[i], lut[idx[i + 12]]); + } +} + +template < + int weight_nbit_, + bool has_scales_, + int mr_, + int nr_, + int kr_, + int sr_> +void test_groupwise_lowbit_lut_kernel( + int m, + int k, + int n, + int flat_scale_group_size, + int flat_lut_group_size, + bool has_bias, + bool has_clamp) { + namespace kernel_api = + torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut; + // 1. Generate test case + auto test_case = torchao::groupwise_lowbit_weight_lut_test_case:: + generate_with_decoupled_grouping( + m, + k, + n, + /*scale_group_size=*/flat_scale_group_size, + /*lut_group_size=*/flat_lut_group_size, + /*weight_nbit=*/weight_nbit_, + /*has_scales=*/has_scales_, + has_bias, + has_clamp); + // 2. Pack Activations + const auto& source_activations = test_case.activations; + std::vector packed_activations_buffer( + kernel_api::packed_activations_size(m, k, mr_, kr_, sr_)); + kernel_api::pack_activations( + packed_activations_buffer.data(), + m, + k, + source_activations.data(), + mr_, + kr_, + sr_); + // 3. Pack Weights + std::vector packed_weights(kernel_api::packed_weights_size( + n, + k, + weight_nbit_, + flat_scale_group_size, + has_scales_, + has_bias, + nr_, + kr_, + sr_)); + kernel_api::pack_weights( + packed_weights.data(), + test_case.weight_qval_indices.data(), + test_case.weight_scales.data(), + test_case.weight_luts.data(), + n, + k, + flat_scale_group_size, + flat_lut_group_size, + has_scales_, + has_bias, + test_case.bias.data(), + nr_, + kr_, + sr_); + + // 4. Run the kernel + std::vector output(m * n); + kernel_api:: + groupwise_lowbit_weight_lut_kernel_1x4x32( + output.data(), + n, + m, + n, + k, + flat_scale_group_size, + flat_lut_group_size, + packed_weights.data(), + packed_activations_buffer.data(), + test_case.clamp_min, + test_case.clamp_max, + has_bias, + has_clamp); + + // 5. Compare results + constexpr float kTol = 1e-4; + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol) + << "Mismatch at index " << i; + } +} + +TEST(test_groupwise_lowbit_lut_kernel, 4bit_aligned) { + constexpr int weight_nbit_ = 4; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); +} + +TEST(test_groupwise_lowbit_lut_kernel, 4bit_mismatch) { + constexpr int weight_nbit_ = 4; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); +} + +TEST(test_groupwise_lowbit_lut_kernel, 3bit_mismatch) { + constexpr int weight_nbit_ = 3; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/512, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); +} + +TEST(test_groupwise_lowbit_lut_kernel, 2bit_mismatch) { + constexpr int weight_nbit_ = 2; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/true, + /*has_clamp=*/true); +} + +TEST(test_groupwise_lowbit_lut_kernel, 1bit_mismatch) { + constexpr int weight_nbit_ = 1; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, + /*flat_lut_group_size=*/512, + /*has_bias=*/true, + /*has_clamp=*/true); +} + +TEST(test_groupwise_lowbit_lut_kernel, 3bit_aligned) { + constexpr int weight_nbit_ = 3; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); +} + +TEST(test_groupwise_lowbit_lut_kernel, 2bit_aligned) { + constexpr int weight_nbit_ = 2; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); +} + +TEST(test_groupwise_lowbit_lut_kernel, 1bit_aligned) { + constexpr int weight_nbit_ = 1; + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/false); + + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/false, + /*has_clamp=*/true); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/false); + test_groupwise_lowbit_lut_kernel( + /*m=*/8, + /*k=*/64, + /*n=*/16, + /*flat_scale_group_size=*/32, // Must be multiple of k*NR = 256 + /*flat_lut_group_size=*/256, // Must be multiple of k*NR = 256 + /*has_bias=*/true, + /*has_clamp=*/true); +} + +struct KernelTestParams { + int m; + int k; + int n; + int flat_scale_group_size; + int flat_lut_group_size; + bool has_bias; + bool has_clamp; +}; + +class ComprehensiveKernelTest + : public ::testing::TestWithParam {}; + +TEST_P(ComprehensiveKernelTest, kernel_test) { + const KernelTestParams& params = GetParam(); + + constexpr int mr = 1; + constexpr int nr = 4; + constexpr int kr = 32; + constexpr int sr = 8; + constexpr bool has_scales = true; + + for (int weight_nbit : {1, 2, 3, 4}) { + switch (weight_nbit) { + case 1: + test_groupwise_lowbit_lut_kernel<1, has_scales, mr, nr, kr, sr>( + params.m, + params.k, + params.n, + params.flat_scale_group_size, + params.flat_lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 2: + test_groupwise_lowbit_lut_kernel<2, has_scales, mr, nr, kr, sr>( + params.m, + params.k, + params.n, + params.flat_scale_group_size, + params.flat_lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 3: + test_groupwise_lowbit_lut_kernel<3, has_scales, mr, nr, kr, sr>( + params.m, + params.k, + params.n, + params.flat_scale_group_size, + params.flat_lut_group_size, + params.has_bias, + params.has_clamp); + break; + case 4: + test_groupwise_lowbit_lut_kernel<4, has_scales, mr, nr, kr, sr>( + params.m, + params.k, + params.n, + params.flat_scale_group_size, + params.flat_lut_group_size, + params.has_bias, + params.has_clamp); + break; + default: + FAIL() << "Unsupported weight_nbit value: " << weight_nbit; + } + } +} + +INSTANTIATE_TEST_SUITE_P( + KernelEdgeCases, + ComprehensiveKernelTest, + ::testing::Values( + // --- Varying Dimensions --- + // Test cases where n is a multiple of 4 (since lut_group_size = 256) + KernelTestParams{8, 64, 16, 32, 256, true, true}, + KernelTestParams{8, 64, 12, 32, 256, true, true}, + KernelTestParams{8, 64, 8, 32, 256, true, true}, + KernelTestParams{8, 64, 4, 32, 256, true, true}, + + // Test cases where n is a multiple of 8 (since lut_group_size = 512) + KernelTestParams{8, 64, 24, 32, 512, true, true}, + KernelTestParams{8, 64, 16, 32, 512, true, true}, + KernelTestParams{8, 64, 8, 32, 512, true, true}, + + // Test cases where n is a multiple of 16 (since lut_group_size = 1024) + KernelTestParams{8, 64, 32, 32, 1024, true, true}, + KernelTestParams{8, 64, 16, 32, 1024, true, true}, + + // Test unaligned M + KernelTestParams{7, 64, 16, 32, 256, true, true}, + KernelTestParams{6, 64, 16, 32, 256, true, true}, + KernelTestParams{5, 64, 16, 32, 256, true, true}, + KernelTestParams{4, 64, 16, 32, 256, true, true}, + KernelTestParams{3, 64, 16, 32, 256, true, true}, + KernelTestParams{2, 64, 16, 32, 256, true, true}, + KernelTestParams{1, 64, 16, 32, 256, true, true}, + + // --- Varying Group Sizes --- + // Test where one LUT group covers multiple scale groups + KernelTestParams{8, 64, 16, 32, 512, true, true}, + // Test with different group sizes that are not equal + KernelTestParams{8, 64, 16, 32, 1024, true, true}, + KernelTestParams{8, 64, 16, 32, 1024, true, true}, + KernelTestParams{8, 64, 16, 32, 1024, true, true}, + // A single scale group is exactly one row of tiles. + KernelTestParams{8, 64, 16, 32, 256, true, true}, + // All flags off (the simplest path) + KernelTestParams{8, 64, 16, 32, 256, false, false}, + + // All flags on + KernelTestParams{8, 64, 16, 32, 256, true, true}, + + // Other combinations + KernelTestParams{8, 64, 16, 32, 256, true, true}, + KernelTestParams{8, 64, 16, 32, 256, true, false}, + // A single group covers the entire matrix. + + // --- Varying Boolean Flags --- + // Test with only scales enabled + KernelTestParams{8, 64, 16, 32, 256, false, false}, + // Test with only bias enabled + KernelTestParams{8, 64, 16, 32, 256, true, false}, + // Test with only clamp enabled + KernelTestParams{8, 64, 16, 32, 256, false, true}, + // Test with scales and clamp + KernelTestParams{8, 64, 16, 32, 256, false, true}, + + // --- Edges cases --- + KernelTestParams{8, 64, 16, 32, 1024, true, true}, + // A single tile matrix. + KernelTestParams{1, 32, 4, 32, 128, true, true}, + // Group sizes are exactly equal to the padded matrix size. + KernelTestParams{8, 64, 16, 32, 1024, true, true})); + +void PrintTo(const KernelTestParams& params, std::ostream* os) { + *os << "KernelTestParams(m=" << params.m << ", k=" << params.k + << ", n=" << params.n << ", scale_gs=" << params.flat_scale_group_size + << ", lut_gs=" << params.flat_lut_group_size + << ", bias=" << std::boolalpha << params.has_bias + << ", clamp=" << std::boolalpha << params.has_clamp << ")"; +} + +#endif // defined(aarch64) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_qmatmul.cpp similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_qmatmul.cpp index 18c9986393..5d46937ccf 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_qmatmul.cpp @@ -10,9 +10,9 @@ #include #include -#include -#include -#include +#include +#include +#include float kTol = 0.0001; diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_quantization.cpp similarity index 92% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_quantization.cpp index bb19528de7..ebe3fbdfa8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_quantization.cpp @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include // Demonstrate some basic assertions. diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_reduction.cpp similarity index 93% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_reduction.cpp index 0720f2dcf8..44dbafafa5 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_reduction.cpp @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils.h similarity index 58% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils.h index aeb9042210..e5742d3f56 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils.h @@ -8,61 +8,15 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include +#include +#include +#include #include #include #include #include namespace torchao { -inline std::vector -get_random_vector(int size, float min = -1.0, float max = 1.0) { - assert(min < max); - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto dist = std::bind(std::uniform_real_distribution(min, max), rng); - std::vector res(size); - std::generate(res.begin(), res.end(), std::ref(dist)); - return res; -} - -inline std::vector get_random_lowbit_vector(int size, int nbit) { - assert(nbit >= 1); - assert(nbit <= 8); - - int min = 0; - int max = (1 << nbit) - 1; - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); - - std::vector res(size); - std::generate(res.begin(), res.end(), std::ref(dist)); - return res; -} - -inline std::vector get_random_signed_lowbit_vector(int size, int nbit) { - assert(nbit >= 1); - assert(nbit <= 8); - - int min = 0; - int max = (1 << nbit) - 1; - int offset = (1 << (nbit - 1)); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); - - std::vector res(size); - std::vector tmp(size); - std::generate(tmp.begin(), tmp.end(), std::ref(dist)); - for (int i = 0; i < size; i++) { - res[i] = tmp[i] - offset; - } - return res; -} // TODO move these to a common utils inline uint16_t get_bf16_from_float(float f) { @@ -575,6 +529,134 @@ struct lowbit_embedding_test_case { } }; +template +struct lut_embedding_test_case { + // --- Struct Members --- + int num_embeddings; + int embedding_dim; + int scale_group_size; + int lut_group_size; + bool has_scales; + + // Source Data for LUT-based quantization + std::vector weight_qval_idxs; // Unsigned indices into the LUT + std::vector weight_scales; // Grouped scales + std::vector weight_luts; // The lookup tables themselves + + // Ground Truth + std::vector expected_outputs; // Dequantized float values + + // --- Constructor --- + lut_embedding_test_case( + int num_embeddings_, + int embedding_dim_, + int scale_group_size_, + int lut_group_size_, + bool has_scales_, + std::vector weight_qval_idxs_, + std::vector weight_scales_, + std::vector weight_luts_, + std::vector expected_outputs_) + : num_embeddings(num_embeddings_), + embedding_dim(embedding_dim_), + scale_group_size(scale_group_size_), + lut_group_size(lut_group_size_), + has_scales(has_scales_), + weight_qval_idxs(weight_qval_idxs_), + weight_scales(weight_scales_), + weight_luts(weight_luts_), + expected_outputs(expected_outputs_) { + assert((num_embeddings * embedding_dim) % lut_group_size == 0); + assert(embedding_dim % scale_group_size == 0); + assert(this->weight_qval_idxs.size() == num_embeddings * embedding_dim); + if (has_scales) { + assert(this->weight_scales.size() == num_embeddings * (embedding_dim / scale_group_size)); + } + assert(this->expected_outputs.size() == num_embeddings * embedding_dim); + } + + private: + static lut_embedding_test_case _generate( + int num_embeddings, + int embedding_dim, + int scale_group_size, + int lut_group_size, + bool has_scales) { + const int lut_size = 1 << weight_nbit_; + const int total_weights = num_embeddings * embedding_dim; + const int total_lut_groups = + (total_weights + lut_group_size - 1) / lut_group_size; + const int total_scale_groups = has_scales + ? ((total_weights + scale_group_size - 1) / scale_group_size) + : 0; + + // 1. Generate the test case parameters + // Generate random source data + std::mt19937 gen(std::random_device{}()); + auto weight_luts = + get_random_vector(total_lut_groups * lut_size, -1.0f, 1.0f); + + // Generate random quantized indices for each weight. + auto weight_qval_idxs = + get_random_lowbit_vector(total_weights, weight_nbit_); + + // Generate random scales for each weight. + std::vector weight_scales; + if (has_scales) { + weight_scales = get_random_vector(total_scale_groups, 0.5f, 1.5f); + } + + // 2. Calculate the expected outputs by applying the LUT dequantization + auto expected_outputs = std::vector(total_weights); + for (int i = 0; i < num_embeddings; ++i) { + for (int j = 0; j < embedding_dim; ++j) { + const size_t linear_idx = i * embedding_dim + j; + const size_t lut_idx = linear_idx / lut_group_size; + + const size_t lut_offset = lut_idx * lut_size; + const float* current_lut = weight_luts.data() + lut_offset; + + // Scale logic is unchanged. + float scale = 1.0f; + if (has_scales) { + const size_t scale_group_idx = linear_idx / scale_group_size; + scale = weight_scales[scale_group_idx]; + } + + uint8_t q_idx = weight_qval_idxs[linear_idx]; + expected_outputs[linear_idx] = current_lut[q_idx] * scale; + } + } + + // 3. Return the complete test case + return lut_embedding_test_case( + num_embeddings, + embedding_dim, + scale_group_size, + lut_group_size, + has_scales, + weight_qval_idxs, + weight_scales, + weight_luts, + expected_outputs); + } + + public: + static lut_embedding_test_case generate( + int num_embeddings, + int embedding_dim, + int scale_group_size, + int lut_group_size, + bool has_scales) { + return _generate( + num_embeddings, + embedding_dim, + scale_group_size, + lut_group_size, + has_scales); + } +}; + struct groupwise_lowbit_weight_lut_test_case { //-------------------------------------------------------------------------- // Parameters @@ -589,177 +671,371 @@ struct groupwise_lowbit_weight_lut_test_case { //-------------------------------------------------------------------------- // Data Tensors //-------------------------------------------------------------------------- - std::vector expected_output; - std::vector activations; - std::vector bias; - std::vector weight_qval_indices; // Indices into a LUT for each weight - std::vector weight_luts; // The pool of unique LUTs - std::vector weight_scales; // The pool of unique scales + std::vector expected_output; + std::vector activations; + std::vector bias; + std::vector + weight_qval_indices; // Indices into a LUT for each weight + std::vector weight_luts; // The pool of unique LUTs + std::vector weight_scales; // The pool of unique scales //-------------------------------------------------------------------------- // Constructor //-------------------------------------------------------------------------- groupwise_lowbit_weight_lut_test_case( - int m_, int k_, int n_, int scale_group_size_, int lut_group_size_, int weight_nbit_, bool has_scales_, bool has_bias_, bool has_clamp_, - float clamp_min_, float clamp_max_, - std::vector expected_output_, std::vector activations_, - std::vector bias_, std::vector weight_qval_indices_, - std::vector weight_luts_, std::vector weight_scales_) - : m(m_), k(k_), n(n_), - scale_group_size(scale_group_size_), lut_group_size(lut_group_size_), weight_nbit(weight_nbit_), + int m_, + int k_, + int n_, + int scale_group_size_, + int lut_group_size_, + int weight_nbit_, + bool has_scales_, + bool has_bias_, + bool has_clamp_, + float clamp_min_, + float clamp_max_, + std::vector expected_output_, + std::vector activations_, + std::vector bias_, + std::vector weight_qval_indices_, + std::vector weight_luts_, + std::vector weight_scales_) + : m(m_), + k(k_), + n(n_), + scale_group_size(scale_group_size_), + lut_group_size(lut_group_size_), + weight_nbit(weight_nbit_), has_scales(has_scales_), - has_bias(has_bias_), has_clamp(has_clamp_), clamp_min(clamp_min_), clamp_max(clamp_max_), + has_bias(has_bias_), + has_clamp(has_clamp_), + clamp_min(clamp_min_), + clamp_max(clamp_max_), expected_output(expected_output_), activations(activations_), bias(bias_), weight_qval_indices(weight_qval_indices_), weight_luts(weight_luts_), - weight_scales(weight_scales_) - {} + weight_scales(weight_scales_) {} //-------------------------------------------------------------------------- // Generator Functions (Factories) //-------------------------------------------------------------------------- -private: + private: /** * @brief The private "master" generator that provides maximum flexibility. * - * This function is the core engine. It takes the exact number of scales and LUTs - * to generate and constructs the test case. All other public generators are - * wrappers around this one. + * This function is the core engine. It takes the exact number of scales and + * LUTs to generate and constructs the test case. All other public generators + * are wrappers around this one. */ static groupwise_lowbit_weight_lut_test_case _generate_master( - int m, int k, int n, - int scale_group_size, // Directly controls scale change frequency - int lut_group_size, // Directly controls LUT change frequency - int weight_nbit, bool has_scales, - bool has_bias, bool has_clamp) { - + int m, + int k, + int n, + int scale_group_size, // Directly controls scale change frequency + int lut_group_size, // Directly controls LUT change frequency + int weight_nbit, + bool has_scales, + bool has_bias, + bool has_clamp) { // --- 0. Validation and Setup --- const int total_weights = n * k; // Frequencies are controlled by their group sizes. assert(total_weights % scale_group_size == 0); - assert(total_weights % lut_group_size == 0); - // The number of unique scales/LUTs is derived directly from their group size. + // The number of unique scales/LUTs is derived directly from their group + // size. const int num_scales = total_weights / scale_group_size; - const int num_luts = total_weights / lut_group_size; + const int num_luts = (total_weights + lut_group_size - 1) / lut_group_size; const int lut_size = 1 << weight_nbit; std::mt19937 gen(std::random_device{}()); // --- 1. Generate Primary Inputs --- auto activations = get_random_vector(m * k, -1.0f, 1.0f); std::vector bias_vec(n, 0.0f); - if (has_bias) bias_vec = get_random_vector(n, -0.5f, 0.5f); - float clamp_min = -std::numeric_limits::infinity(), clamp_max = std::numeric_limits::infinity(); + if (has_bias) + bias_vec = get_random_vector(n, -0.5f, 0.5f); + float clamp_min = -std::numeric_limits::infinity(), + clamp_max = std::numeric_limits::infinity(); if (has_clamp) { auto r = get_random_vector(2, -5.0f, 5.0f); - clamp_min = std::min(r[0], r[1]); clamp_max = std::max(r[0], r[1]); + clamp_min = std::min(r[0], r[1]); + clamp_max = std::max(r[0], r[1]); } // --- 2. Generate Quantization Data --- // 2a. Generate the pools of unique scales and LUTs. std::vector weight_scales; if (has_scales) { - // Normal case: generate random scales. - weight_scales = get_random_vector(num_scales, 0.001f, 0.1f); + // Normal case: generate random scales. + weight_scales = get_random_vector(num_scales, 0.001f, 0.1f); } else { - // LUT-only case: create a vector where every scale is 1.0f. - weight_scales.assign(num_scales, 1.0f); + // LUT-only case: create a vector where every scale is 1.0f. + weight_scales.assign(num_scales, 1.0f); } - auto weight_luts = get_random_vector(num_luts * lut_size, -0.2f, 0.2f); // Independent random LUTs + auto weight_luts = get_random_vector( + num_luts * lut_size, -0.2f, 0.2f); // Independent random LUTs // 2b. Generate random quantized indices for each weight. auto weight_qval_indices = std::vector(total_weights); std::uniform_int_distribution qval_dis(0, lut_size - 1); - for (int i = 0; i < total_weights; ++i) weight_qval_indices[i] = static_cast(qval_dis(gen)); - - // --- 3. Compute Expected Output using the IMPLICIT mappings --- - std::vector expected_output(m * n); - for (int m_idx = 0; m_idx < m; ++m_idx) { - for (int n_idx = 0; n_idx < n; ++n_idx) { - float res = 0.0f; - for (int k_idx = 0; k_idx < k; ++k_idx) { - float activation_val = activations[m_idx * k + k_idx]; - int weight_idx = n_idx * k + k_idx; - uint8_t qval_idx = weight_qval_indices[weight_idx]; - - int32_t scale_idx = weight_idx / scale_group_size; - int32_t lut_idx = weight_idx / lut_group_size; - - // Dequantize: scale * LUT_value - float scale = weight_scales[scale_idx]; - float lut_val = weight_luts[lut_idx * lut_size + qval_idx]; - res += activation_val * (scale * lut_val); + for (int i = 0; i < total_weights; ++i) + weight_qval_indices[i] = static_cast(qval_dis(gen)); + + // --- 3. Compute Expected Output using the IMPLICIT mappings --- + std::vector expected_output(m * n); + for (int m_idx = 0; m_idx < m; ++m_idx) { + for (int n_idx = 0; n_idx < n; ++n_idx) { + float res = 0.0f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + float activation_val = activations[m_idx * k + k_idx]; + int weight_idx = n_idx * k + k_idx; + uint8_t qval_idx = weight_qval_indices[weight_idx]; + + int32_t scale_idx = weight_idx / scale_group_size; + int32_t lut_idx = weight_idx / lut_group_size; + + // Dequantize: scale * LUT_value + float scale = weight_scales[scale_idx]; + float lut_val = weight_luts[lut_idx * lut_size + qval_idx]; + res += activation_val * (scale * lut_val); + } + res += bias_vec[n_idx]; + if (has_clamp) { + res = std::clamp(res, clamp_min, clamp_max); + } + expected_output[m_idx * n + n_idx] = res; } - res += bias_vec[n_idx]; - if (has_clamp) { res = std::clamp(res, clamp_min, clamp_max); } - expected_output[m_idx * n + n_idx] = res; } - } - - // --- 4. Construct and Return --- - return groupwise_lowbit_weight_lut_test_case( - m, k, n, scale_group_size, lut_group_size, weight_nbit, has_scales, - has_bias, has_clamp, clamp_min, clamp_max, - expected_output, - activations, - bias_vec, - weight_qval_indices, - weight_luts, - weight_scales); + // --- 4. Construct and Return --- + return groupwise_lowbit_weight_lut_test_case( + m, + k, + n, + scale_group_size, + lut_group_size, + weight_nbit, + has_scales, + has_bias, + has_clamp, + clamp_min, + clamp_max, + expected_output, + activations, + bias_vec, + weight_qval_indices, + weight_luts, + weight_scales); } -public: + public: /** - * @brief OVERLOAD 1: Simple generator where scales and LUTs share the same grouping. + * @brief OVERLOAD 1: Simple generator where scales and LUTs share the same + * grouping. * - * This is for the simplest case where a block of weights gets one scale and one LUT, - * and this pattern repeats. + * This is for the simplest case where a block of weights gets one scale and + * one LUT, and this pattern repeats. */ static groupwise_lowbit_weight_lut_test_case generate_per_group( - int m, int k, int n, - int group_size, // The size of the block for both scales and LUTs - int weight_nbit, bool has_scales, - bool has_bias, bool has_clamp) { - - std::cout << "[Generator Info] Using 'Per-Group' model.\n" - << " - Both scales and LUTs will switch every " << group_size << " weights." << std::endl; - + int m, + int k, + int n, + int group_size, // The size of the block for both scales and LUTs + int weight_nbit, + bool has_scales, + bool has_bias, + bool has_clamp) { // Just call the decoupled generator with the same group size for both. return _generate_master( - m, k, n, - group_size, /* scale_group_size */ - group_size, /* lut_group_size */ - weight_nbit, - has_scales, - has_bias, has_clamp - ); + m, + k, + n, + group_size, /* scale_group_size */ + group_size, /* lut_group_size */ + weight_nbit, + has_scales, + has_bias, + has_clamp); } /** - * @brief OVERLOAD 2: Advanced generator with separate grouping for scales and LUTs. + * @brief OVERLOAD 2: Advanced generator with separate grouping for scales and + * LUTs. */ static groupwise_lowbit_weight_lut_test_case generate_with_decoupled_grouping( - int m, int k, int n, - int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales, - bool has_bias, bool has_clamp) { + int m, + int k, + int n, + int scale_group_size, + int lut_group_size, + int weight_nbit, + bool has_scales, + bool has_bias, + bool has_clamp) { + return _generate_master( + m, + k, + n, + scale_group_size, + lut_group_size, + weight_nbit, + has_scales, + has_bias, + has_clamp); + } +}; - std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n" - << " - Scales will switch every " << scale_group_size << " weights.\n" - << " - LUTs will switch every " << lut_group_size << " weights." << std::endl; +#if defined(__ARM_FEATURE_BF16) +std::vector to_bfloat16_vector(const std::vector& vec) { + std::vector bf16_vec(vec.size()); + for (size_t i = 0; i < vec.size(); ++i) { + // This conversion simulates the precision loss + bf16_vec[i] = vcvt_f32_bf16(vdup_n_f32(vec[i])); + } + return bf16_vec; +} - return _generate_master( - m, k, n, - scale_group_size, lut_group_size, - weight_nbit, has_scales, - has_bias, has_clamp +struct groupwise_lowbit_weight_lut_test_case_bf16 { + //-------------------------------------------------------------------------- + // Parameters + //-------------------------------------------------------------------------- + int m, k, n; + int scale_group_size; + int lut_group_size; + int weight_nbit; + bool has_scales, has_bias, has_clamp; + float clamp_min, clamp_max; + + //-------------------------------------------------------------------------- + // Data Tensors + //-------------------------------------------------------------------------- + std::vector expected_output; + std::vector activations; + std::vector bias; + std::vector + weight_qval_indices; // Indices into a LUT for each weight + std::vector weight_luts; + std::vector weight_scales; + + // ... existing constructor and generate functions ... + + // New generator for the BFMMLA kernel + static groupwise_lowbit_weight_lut_test_case generate( + int m, + int k, + int n, + int scale_group_size, + int lut_group_size, + int weight_nbit, + bool has_scales, + bool has_bias, + bool has_clamp) { + // 1. Generate float data first + // --- 0. Validation and Setup --- + const int total_weights = n * k; + // Frequencies are controlled by their group sizes. + assert(total_weights % scale_group_size == 0); + + // The number of unique scales/LUTs is derived directly from their group + // size. + const int num_scales = total_weights / scale_group_size; + const int num_luts = (total_weights + lut_group_size - 1) / lut_group_size; + const int lut_size = 1 << weight_nbit; + std::mt19937 gen(std::random_device{}()); + + // --- 1. Generate Primary Inputs --- + auto activations = get_random_vector(m * k, -1.0f, 1.0f); + std::vector bias_vec(n, 0.0f); + if (has_bias) + bias_vec = get_random_vector(n, -0.5f, 0.5f); + float clamp_min = -std::numeric_limits::infinity(), + clamp_max = std::numeric_limits::infinity(); + if (has_clamp) { + auto r = get_random_vector(2, -5.0f, 5.0f); + clamp_min = std::min(r[0], r[1]); + clamp_max = std::max(r[0], r[1]); + } + + // --- 2. Generate Quantization Data --- + // 2a. Generate the pools of unique scales and LUTs. + std::vector weight_scales; + if (has_scales) { + // Normal case: generate random scales. + weight_scales = get_random_vector(num_scales, 0.001f, 0.1f); + } else { + // LUT-only case: create a vector where every scale is 1.0f. + weight_scales.assign(num_scales, 1.0f); + } + + auto weight_luts = get_random_vector( + num_luts * lut_size, -0.2f, 0.2f); // Independent random LUTs + + // 2b. Generate random quantized indices for each weight. + auto weight_qval_indices = std::vector(total_weights); + std::uniform_int_distribution qval_dis(0, lut_size - 1); + for (int i = 0; i < total_weights; ++i) + weight_qval_indices[i] = static_cast(qval_dis(gen)); + + std::vector weight_scales_bf16 = + to_bfloat16_vector(weight_scales); + + std::vector weight_luts_bf16 = to_bfloat16_vector(weight_luts); + + // --- 3. Compute Expected Output using SIMULATED bfloat16 precision --- + std::vector expected_output(m * n); + for (int m_idx = 0; m_idx < m; ++m_idx) { + for (int n_idx = 0; n_idx < n; ++n_idx) { + float res = 0.0f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + float activation_val = activations[m_idx * k + k_idx]; + int weight_idx = n_idx * k + k_idx; + uint8_t qval_idx = weight_qval_indices[weight_idx]; + + int32_t scale_idx = weight_idx / scale_group_size; + int32_t lut_idx = weight_idx / lut_group_size; + + // Dequantize: scale * LUT_value + // CRITICAL CHANGE: Simulate bfloat16 precision before multiplying + bfloat16_t scale_bf16 = weight_scales_bf16[scale_idx]; + bfloat16_t lut_val_bf16 = + weight_luts_bf16[lut_idx * lut_size + qval_idx]; + float dequantized_weight = float(scale_bf16) * float(lut_val_bf16); + + res += activation_val * dequantized_weight; + } + res += bias_vec[n_idx]; + if (has_clamp) { + res = std::clamp(res, clamp_min, clamp_max); + } + expected_output[m_idx * n + n_idx] = res; + } + } + return groupwise_lowbit_weight_lut_test_case_bf16( + m, + k, + n, + scale_group_size, + lut_group_size, + weight_nbit, + has_scales, + has_bias, + has_clamp, + clamp_min, + clamp_max, + expected_output, + activations, + bias_vec, + weight_qval_indices, + weight_luts_bf16, // Pass the b16 version + weight_scales_bf16 // Pass the b16 version ); } -}; +}; // End of struct +#endif // defined(__ARM_FEATURE_BF16) } // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils_quantized_attention.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils_quantized_attention.h index 52fb0851bc..ba6fb83069 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils_quantized_attention.h @@ -8,9 +8,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include -#include +#include +#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_weight_packing.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_weight_packing.cpp similarity index 95% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_weight_packing.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_weight_packing.cpp index fba4fba391..b64d4b2754 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_weight_packing.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_weight_packing.cpp @@ -5,8 +5,8 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include +#include +#include template void test_weight_packing( diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp index 0274b0889e..3818fac2d0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/valpack.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/valpack.h diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/fallback/CMakeLists.txt new file mode 100644 index 0000000000..bf488ffab5 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if (TORCHAO_BUILD_TESTS) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tests) +endif() diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/bitpack.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/bitpack.h new file mode 100644 index 0000000000..c28c6ec90d --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/bitpack.h @@ -0,0 +1,179 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { +/** + * @brief Packs 128 unsigned 8-bit integers into a packed format of 'nbit' bits. + * + * @tparam nbit The number of bits to pack each value into (1-8). + * @param packed Pointer to the destination memory for the packed data. + * @param unpacked_values Pointer to the source memory with 128 uint8_t values. + */ +template +inline void pack_128_uint_values( + uint8_t* packed, + const uint8_t* unpacked_values) { + static_assert(nbit >= 1 && nbit <= 8, "nbit must be between 1 and 8"); + + // Dispatch to the correct packing function + if constexpr (nbit == 1) { + pack_128_uint1_values(packed, unpacked_values); + } else if constexpr (nbit == 2) { + pack_64_uint2_values(packed, unpacked_values); + pack_64_uint2_values(packed + 16, unpacked_values + 64); + } else if constexpr (nbit == 3) { + pack_128_uint3_values(packed, unpacked_values); + } else if constexpr (nbit == 4) { + pack_32_uint4_values(packed, unpacked_values); + pack_32_uint4_values(packed + 16, unpacked_values + 32); + pack_32_uint4_values(packed + 32, unpacked_values + 64); + pack_32_uint4_values(packed + 48, unpacked_values + 96); + } else if constexpr (nbit == 5) { + pack_128_uint5_values(packed, unpacked_values); + } else if constexpr (nbit == 6) { + pack_64_uint6_values(packed, unpacked_values); + pack_64_uint6_values(packed + 48, unpacked_values + 64); + } else if constexpr (nbit == 7) { + pack_128_uint7_values(packed, unpacked_values); + } else if constexpr (nbit == 8) { + // For 8-bit, it's a direct memory copy + for (int i = 0; i < 128; ++i) { + packed[i] = unpacked_values[i]; + } + } +} +/** + * @brief Unpacks 'nbit' data into 128 unsigned 8-bit integers. + * + * @tparam nbit The number of bits per value in the packed format (1-8). + * @param unpacked_values Pointer to the destination memory (128 uint8_t + * values). + * @param packed Pointer to the source packed data. + */ +template +inline void unpack_128_uint_values( + uint8_t* unpacked_values, + const uint8_t* packed) { + static_assert(nbit >= 1 && nbit <= 8, "nbit must be between 1 and 8"); + + // Dispatch to the correct unpacking function, writing directly to the output. + if constexpr (nbit == 1) { + unpack_128_uint1_values(unpacked_values, packed); + } else if constexpr (nbit == 2) { + unpack_64_uint2_values(unpacked_values, packed); + unpack_64_uint2_values(unpacked_values + 64, packed + 16); + } else if constexpr (nbit == 3) { + unpack_128_uint3_values(unpacked_values, packed); + } else if constexpr (nbit == 4) { + unpack_32_uint4_values(unpacked_values, packed); + unpack_32_uint4_values(unpacked_values + 32, packed + 16); + unpack_32_uint4_values(unpacked_values + 64, packed + 32); + unpack_32_uint4_values(unpacked_values + 96, packed + 48); + } else if constexpr (nbit == 5) { + unpack_128_uint5_values(unpacked_values, packed); + } else if constexpr (nbit == 6) { + unpack_64_uint6_values(unpacked_values, packed); + unpack_64_uint6_values(unpacked_values + 64, packed + 48); + } else if constexpr (nbit == 7) { + unpack_128_uint7_values(unpacked_values, packed); + } else if constexpr (nbit == 8) { + // For 8-bit, it's a direct memory copy + for (int i = 0; i < 128; ++i) { + unpacked_values[i] = packed[i]; + } + } +} + +/** + * @brief Packs 128 signed 8-bit integers into a packed format of 'nbit' bits. + * + * @tparam nbit The number of bits to pack each value into (1-8). + * @param packed Pointer to the destination memory. + * @param unpacked Pointer to the source memory containing 128 int8_t values. + */ +template +inline void pack_128_lowbit_int_values( + uint8_t* packed, + const int8_t* unpacked) { + // 1. Convert signed input to a temporary buffer of unsigned values. + uint8_t temp_unpacked[128]; + if constexpr (nbit < 8) { + const int8_t shift = 1 << (nbit - 1); + for (int i = 0; i < 128; ++i) { + temp_unpacked[i] = static_cast(unpacked[i] + shift); + } + } else { // nbit == 8 + for (int i = 0; i < 128; ++i) { + temp_unpacked[i] = static_cast(unpacked[i]); + } + } + + // 2. Call the generalized uint packing function. + pack_128_uint_values(packed, temp_unpacked); +} + +template +inline void unpack_128_lowbit_int_values( + int8_t* unpacked, + const uint8_t* packed) { + // 1. Get the raw unsigned values by calling the base function. + uint8_t temp_unpacked[128]; + unpack_128_uint_values(temp_unpacked, packed); + + // 2. Perform the signed conversion. + if constexpr (nbit < 8) { + const int8_t unshift = -(1 << (nbit - 1)); + for (int i = 0; i < 128; ++i) { + unpacked[i] = static_cast(temp_unpacked[i]) + unshift; + } + } else { // nbit == 8 + for (int i = 0; i < 128; ++i) { + unpacked[i] = static_cast(temp_unpacked[i]); + } + } +} + +/** + * @brief Unpacks 'nbit' data and de-quantizes it using a lookup table (LUT). + * + * @tparam nbit The number of bits per value in the packed format (1-4). + * @param unpacked Pointer to the destination memory (128 int8_t values). + * @param packed Pointer to the source packed data. + * @param lut Pointer to the lookup table (must have 2^nbit entries). + */ +template +inline void unpack_128_lowbit_values_with_lut( + int8_t* unpacked, + const uint8_t* packed, + const int8_t* lut) { + static_assert(nbit >= 1 && nbit <= 4, "LUT version only supports nbit <= 4"); + + // Create a temporary buffer on the stack for the indices. + uint8_t indices[128]; + + // 1. Call the utility function to handle all the unpacking logic. + unpack_128_uint_values(indices, packed); + + // 2. Apply the lookup table. + for (int i = 0; i < 128; ++i) { + unpacked[i] = lut[indices[i]]; + } +} +} // namespace internal +} // namespace torchao::kernels::cpu::fallback::bitpacking diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint1.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint1.h new file mode 100644 index 0000000000..08e231716b --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint1.h @@ -0,0 +1,154 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { + +/** + * @brief Packs 8 bytes, each containing a 1-bit value (0 or 1), into a single + * byte. + * @param packed Pointer to the destination memory (1 byte). + * @param unpacked Pointer to the source memory (8 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void pack_8_uint1_values( + uint8_t* packed, + const uint8_t* unpacked) { + packed[0] = (unpacked[0] << 7) | (unpacked[1] << 6) | (unpacked[2] << 5) | + (unpacked[3] << 4) | (unpacked[4] << 3) | (unpacked[5] << 2) | + (unpacked[6] << 1) | (unpacked[7] << 0); +} + +/** + * @brief Unpacks a single byte into 8 bytes, each containing a 1-bit value. + * @param unpacked Pointer to the destination memory (8 bytes). + * @param packed Pointer to the source memory (1 byte). + */ +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint1_values( + uint8_t* unpacked, + const uint8_t* packed) { + const uint8_t packed_byte = packed[0]; + unpacked[0] = (packed_byte >> 7) & 1; + unpacked[1] = (packed_byte >> 6) & 1; + unpacked[2] = (packed_byte >> 5) & 1; + unpacked[3] = (packed_byte >> 4) & 1; + unpacked[4] = (packed_byte >> 3) & 1; + unpacked[5] = (packed_byte >> 2) & 1; + unpacked[6] = (packed_byte >> 1) & 1; + unpacked[7] = (packed_byte >> 0) & 1; +} + +/** + * @brief Packs 64 bytes (each a 1-bit value) into 8 bytes. + * @param packed Pointer to the destination memory (8 bytes). + * @param unpacked Pointer to the source memory (64 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_64_uint1_values` function to ensure compatibility. The unpacked + * data is assumed to be organized as four 16-byte blocks. + */ +TORCHAO_ALWAYS_INLINE inline void pack_64_uint1_values( + uint8_t* packed, + const uint8_t* unpacked) { + const uint8_t* unpacked0 = unpacked; + const uint8_t* unpacked1 = unpacked + 16; + const uint8_t* unpacked2 = unpacked + 32; + const uint8_t* unpacked3 = unpacked + 48; + + for (int i = 0; i < 8; ++i) { + // Combine 4 bits for the low nibble of the output byte + uint8_t low_nibble = (unpacked0[i] << 3) | (unpacked1[i] << 2) | + (unpacked2[i] << 1) | (unpacked3[i] << 0); + + // Combine 4 bits for the high nibble of the output byte + uint8_t high_nibble_src = (unpacked0[i + 8] << 3) | + (unpacked1[i + 8] << 2) | (unpacked2[i + 8] << 1) | + (unpacked3[i + 8] << 0); + + // Assemble the final byte + packed[i] = low_nibble | (high_nibble_src << 4); + } +} + +/** + * @brief Unpacks 8 bytes into 64 bytes (each a 1-bit value). + * @param unpacked Pointer to the destination memory (64 bytes). + * @param packed Pointer to the source memory (8 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_64_uint1_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_64_uint1_values( + uint8_t* unpacked, + const uint8_t* packed) { + uint8_t* unpacked0 = unpacked; + uint8_t* unpacked1 = unpacked + 16; + uint8_t* unpacked2 = unpacked + 32; + uint8_t* unpacked3 = unpacked + 48; + + uint8_t combined[16]; + for (int i = 0; i < 8; ++i) { + combined[i] = packed[i] & 0x0F; // Low nibbles + combined[i + 8] = packed[i] >> 4; // High nibbles + } + + // Unpack from the combined buffer into the four destination blocks + for (int i = 0; i < 16; ++i) { + const uint8_t temp = combined[i]; + unpacked0[i] = (temp >> 3) & 1; + unpacked1[i] = (temp >> 2) & 1; + unpacked2[i] = (temp >> 1) & 1; + unpacked3[i] = (temp >> 0) & 1; + } +} + +/** + * @brief Packs 128 bytes (each a 1-bit value) into 16 bytes. + * @param packed Pointer to the destination memory (16 bytes). + * @param unpacked Pointer to the source memory (128 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_128_uint1_values` function (a transpose-and-pack operation) to + * ensure compatibility. The unpacked data is assumed to be organized as eight + * 16-byte blocks. + */ +TORCHAO_ALWAYS_INLINE inline void pack_128_uint1_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 16; ++i) { + packed[i] = (unpacked[i + 16 * 0] << 7) | (unpacked[i + 16 * 1] << 6) | + (unpacked[i + 16 * 2] << 5) | (unpacked[i + 16 * 3] << 4) | + (unpacked[i + 16 * 4] << 3) | (unpacked[i + 16 * 5] << 2) | + (unpacked[i + 16 * 6] << 1) | (unpacked[i + 16 * 7] << 0); + } +} + +/** + * @brief Unpacks 16 bytes into 128 bytes (each a 1-bit value). + * @param unpacked Pointer to the destination memory (128 bytes). + * @param packed Pointer to the source memory (16 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_128_uint1_values` function (an unpack-and-transpose operation) + * to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_128_uint1_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 16; ++i) { + const uint8_t packed_byte = packed[i]; + unpacked[i + 16 * 0] = (packed_byte >> 7) & 1; + unpacked[i + 16 * 1] = (packed_byte >> 6) & 1; + unpacked[i + 16 * 2] = (packed_byte >> 5) & 1; + unpacked[i + 16 * 3] = (packed_byte >> 4) & 1; + unpacked[i + 16 * 4] = (packed_byte >> 3) & 1; + unpacked[i + 16 * 5] = (packed_byte >> 2) & 1; + unpacked[i + 16 * 6] = (packed_byte >> 1) & 1; + unpacked[i + 16 * 7] = (packed_byte >> 0) & 1; + } +} +} // namespace internal +} // namespace torchao::kernels::cpu::fallback::bitpacking diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint2.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint2.h new file mode 100644 index 0000000000..9dc1cce463 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint2.h @@ -0,0 +1,119 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { + +/** + * @brief Packs 4 bytes, each containing a 2-bit value (0-3), into a single + * byte. + * @param packed Pointer to the destination memory (1 byte). + * @param unpacked Pointer to the source memory (4 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void pack_4_uint2_values( + uint8_t* packed, + const uint8_t* unpacked) { + // unpacked = {v0, v1, v2, v3} -> packed[0] = | v0 | v1 | v2 | v3 | + packed[0] = (unpacked[0] << 6) | (unpacked[1] << 4) | (unpacked[2] << 2) | + (unpacked[3]); +} + +/** + * @brief Unpacks a single byte into 4 bytes, each containing a 2-bit value. + * @param unpacked Pointer to the destination memory (4 bytes). + * @param packed Pointer to the source memory (1 byte). + */ +TORCHAO_ALWAYS_INLINE inline void unpack_4_uint2_values( + uint8_t* unpacked, + const uint8_t* packed) { + unpacked[0] = (packed[0] >> 6) & 0x03; // Mask 0b11000000 + unpacked[1] = (packed[0] >> 4) & 0x03; // Mask 0b00110000 + unpacked[2] = (packed[0] >> 2) & 0x03; // Mask 0b00001100 + unpacked[3] = packed[0] & 0x03; // Mask 0b00000011 +} + +/** + * @brief Packs 32 bytes (each a 2-bit value) into 8 bytes. + * @param packed Pointer to the destination memory (8 bytes). + * @param unpacked Pointer to the source memory (32 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_32_uint2_values` function (a transpose-and-pack operation) to + * ensure compatibility. The unpacked data is assumed to be organized as four + * 8-byte blocks. + */ +TORCHAO_ALWAYS_INLINE inline void pack_32_uint2_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 8; ++i) { + packed[i] = (unpacked[i + 8 * 0] << 6) | (unpacked[i + 8 * 1] << 4) | + (unpacked[i + 8 * 2] << 2) | (unpacked[i + 8 * 3] << 0); + } +} + +/** + * @brief Unpacks 8 bytes into 32 bytes (each a 2-bit value). + * @param unpacked Pointer to the destination memory (32 bytes). + * @param packed Pointer to the source memory (8 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_32_uint2_values` function (an unpack-and-transpose operation) + * to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_32_uint2_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 8; ++i) { + const uint8_t packed_byte = packed[i]; + unpacked[i + 8 * 0] = (packed_byte >> 6) & 0x03; + unpacked[i + 8 * 1] = (packed_byte >> 4) & 0x03; + unpacked[i + 8 * 2] = (packed_byte >> 2) & 0x03; + unpacked[i + 8 * 3] = (packed_byte >> 0) & 0x03; + } +} + +/** + * @brief Packs 64 bytes (each a 2-bit value) into 16 bytes. + * @param packed Pointer to the destination memory (16 bytes). + * @param unpacked Pointer to the source memory (64 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_64_uint2_values` function (a transpose-and-pack operation) to + * ensure compatibility. The unpacked data is assumed to be organized as four + * 16-byte blocks. + */ +TORCHAO_ALWAYS_INLINE inline void pack_64_uint2_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 16; ++i) { + packed[i] = (unpacked[i + 16 * 0] << 6) | (unpacked[i + 16 * 1] << 4) | + (unpacked[i + 16 * 2] << 2) | (unpacked[i + 16 * 3] << 0); + } +} + +/** + * @brief Unpacks 16 bytes into 64 bytes (each a 2-bit value). + * @param unpacked Pointer to the destination memory (64 bytes). + * @param packed Pointer to the source memory (16 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_64_uint2_values` function (an unpack-and-transpose operation) + * to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_64_uint2_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 16; ++i) { + const uint8_t packed_byte = packed[i]; + unpacked[i + 16 * 0] = (packed_byte >> 6) & 0x03; + unpacked[i + 16 * 1] = (packed_byte >> 4) & 0x03; + unpacked[i + 16 * 2] = (packed_byte >> 2) & 0x03; + unpacked[i + 16 * 3] = (packed_byte >> 0) & 0x03; + } +} + +} // namespace internal +} // namespace torchao::kernels::cpu::fallback::bitpacking diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint3.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint3.h new file mode 100644 index 0000000000..277317d5a2 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint3.h @@ -0,0 +1,195 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { + +/** + * @brief Packs 8 bytes, each holding a 3-bit value (0-7), into 3 bytes. + * + * The packing scheme is non-trivial. Given 8 input values v0..v7, they are + * arranged into 3 bytes (b0, b1, b2) as follows: + * - b0: [v6(low 2 bits), v0(all 3 bits), v1(all 3 bits)] + * - b1: [v7(low 2 bits), v2(all 3 bits), v3(all 3 bits)] + * - b2: [v6(high 1 bit), v7(high 1 bit), v4(all 3 bits), v5(all 3 bits)] + * + * @param packed Pointer to the destination memory (3 bytes). + * @param unpacked Pointer to the source memory (8 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void pack_8_uint3_values( + uint8_t* packed, + const uint8_t* unpacked) { + // byte 0 + packed[0] = ((unpacked[6] & 0x03) << 6) | ((unpacked[0] & 0x07) << 3) | + (unpacked[1] & 0x07); + + // byte 1 + packed[1] = ((unpacked[7] & 0x03) << 6) | ((unpacked[2] & 0x07) << 3) | + (unpacked[3] & 0x07); + + // byte 2 + packed[2] = ((unpacked[6] & 0x04) << 5) | ((unpacked[7] & 0x04) << 4) | + ((unpacked[4] & 0x07) << 3) | (unpacked[5] & 0x07); +} + +/** + * @brief Unpacks 3 bytes into 8 bytes, each containing a 3-bit value. + * @param unpacked Pointer to the destination memory (8 bytes). + * @param packed Pointer to the source memory (3 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint3_values( + uint8_t* unpacked, + const uint8_t* packed) { + const uint8_t b0 = packed[0]; + const uint8_t b1 = packed[1]; + const uint8_t b2 = packed[2]; + + unpacked[0] = (b0 >> 3) & 0x07; + unpacked[1] = b0 & 0x07; + + unpacked[2] = (b1 >> 3) & 0x07; + unpacked[3] = b1 & 0x07; + + unpacked[4] = (b2 >> 3) & 0x07; + unpacked[5] = b2 & 0x07; + + unpacked[6] = (b0 >> 6) | ((b2 >> 5) & 0x04); + unpacked[7] = (b1 >> 6) | ((b2 >> 4) & 0x04); +} + +/** + * @brief Packs 64 bytes (each a 3-bit value) into 24 bytes. + * @param packed Pointer to the destination memory (24 bytes). + * @param unpacked Pointer to the source memory (64 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_64_uint3_values` function (a transpose-and-pack operation) to + * ensure compatibility. The unpacked data is assumed to be organized as eight + * 8-byte blocks. + */ +TORCHAO_ALWAYS_INLINE inline void pack_64_uint3_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 8; ++i) { + const uint8_t unpacked0 = unpacked[i + 8 * 0]; + const uint8_t unpacked1 = unpacked[i + 8 * 1]; + const uint8_t unpacked2 = unpacked[i + 8 * 2]; + const uint8_t unpacked3 = unpacked[i + 8 * 3]; + const uint8_t unpacked4 = unpacked[i + 8 * 4]; + const uint8_t unpacked5 = unpacked[i + 8 * 5]; + const uint8_t unpacked6 = unpacked[i + 8 * 6]; + const uint8_t unpacked7 = unpacked[i + 8 * 7]; + + // byte 0 + packed[i] = ((unpacked6 & 0x03) << 6) | ((unpacked0 & 0x07) << 3) | + (unpacked1 & 0x07); + + // byte 1 + packed[i + 8] = ((unpacked7 & 0x03) << 6) | ((unpacked2 & 0x07) << 3) | + (unpacked3 & 0x07); + + // byte 2 + packed[i + 16] = ((unpacked6 & 0x04) << 5) | ((unpacked7 & 0x04) << 4) | + ((unpacked4 & 0x07) << 3) | (unpacked5 & 0x07); + } +} + +/** + * @brief Unpacks 24 bytes into 64 bytes (each a 3-bit value). + * @param unpacked Pointer to the destination memory (64 bytes). + * @param packed Pointer to the source memory (24 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_64_uint3_values` function (an unpack-and-transpose operation) + * to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_64_uint3_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 8; ++i) { + const uint8_t b0 = packed[i]; + const uint8_t b1 = packed[i + 8]; + const uint8_t b2 = packed[i + 16]; + + unpacked[i + 8 * 0] = (b0 >> 3) & 0x07; + unpacked[i + 8 * 1] = b0 & 0x07; + unpacked[i + 8 * 2] = (b1 >> 3) & 0x07; + unpacked[i + 8 * 3] = b1 & 0x07; + unpacked[i + 8 * 4] = (b2 >> 3) & 0x07; + unpacked[i + 8 * 5] = b2 & 0x07; + unpacked[i + 8 * 6] = (b0 >> 6) | ((b2 >> 5) & 0x04); + unpacked[i + 8 * 7] = (b1 >> 6) | ((b2 >> 4) & 0x04); + } +} + +/** + * @brief Packs 128 bytes (each a 3-bit value) into 48 bytes. + * @param packed Pointer to the destination memory (48 bytes). + * @param unpacked Pointer to the source memory (128 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_128_uint3_values` function (a transpose-and-pack operation) to + * ensure compatibility. The unpacked data is assumed to be organized as eight + * 16-byte blocks. + */ +TORCHAO_ALWAYS_INLINE inline void pack_128_uint3_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 16; ++i) { + const uint8_t unpacked0 = unpacked[i + 16 * 0]; + const uint8_t unpacked1 = unpacked[i + 16 * 1]; + const uint8_t unpacked2 = unpacked[i + 16 * 2]; + const uint8_t unpacked3 = unpacked[i + 16 * 3]; + const uint8_t unpacked4 = unpacked[i + 16 * 4]; + const uint8_t unpacked5 = unpacked[i + 16 * 5]; + const uint8_t unpacked6 = unpacked[i + 16 * 6]; + const uint8_t unpacked7 = unpacked[i + 16 * 7]; + + // byte 0 + packed[i] = ((unpacked6 & 0x03) << 6) | ((unpacked0 & 0x07) << 3) | + (unpacked1 & 0x07); + + // byte 1 + packed[i + 16] = ((unpacked7 & 0x03) << 6) | ((unpacked2 & 0x07) << 3) | + (unpacked3 & 0x07); + + // byte 2 + packed[i + 32] = ((unpacked6 & 0x04) << 5) | ((unpacked7 & 0x04) << 4) | + ((unpacked4 & 0x07) << 3) | (unpacked5 & 0x07); + } +} + +/** + * @brief Unpacks 48 bytes into 128 bytes (each a 3-bit value). + * @param unpacked Pointer to the destination memory (128 bytes). + * @param packed Pointer to the source memory (48 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_128_uint3_values` function (an unpack-and-transpose operation) + * to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_128_uint3_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 16; ++i) { + const uint8_t b0 = packed[i]; + const uint8_t b1 = packed[i + 16]; + const uint8_t b2 = packed[i + 32]; + + unpacked[i + 16 * 0] = (b0 >> 3) & 0x07; + unpacked[i + 16 * 1] = b0 & 0x07; + unpacked[i + 16 * 2] = (b1 >> 3) & 0x07; + unpacked[i + 16 * 3] = b1 & 0x07; + unpacked[i + 16 * 4] = (b2 >> 3) & 0x07; + unpacked[i + 16 * 5] = b2 & 0x07; + unpacked[i + 16 * 6] = (b0 >> 6) | ((b2 >> 5) & 0x04); + unpacked[i + 16 * 7] = (b1 >> 6) | ((b2 >> 4) & 0x04); + } +} + +} // namespace internal +} // namespace torchao::kernels::cpu::fallback::bitpacking diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint4.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint4.h new file mode 100644 index 0000000000..4b98a47143 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint4.h @@ -0,0 +1,109 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { +/** + * @brief Packs 2 bytes, each holding a 4-bit value (0-15), into a single + * byte. The first value goes into the high nibble, the second into the low + * nibble. + * @param packed Pointer to the destination memory (1 byte). + * @param unpacked Pointer to the source memory (2 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void pack_2_uint4_values( + uint8_t* packed, + const uint8_t* unpacked) { + // This is compatible with the scalar NEON version. + packed[0] = (unpacked[0] << 4) | (unpacked[1] & 0x0F); +} + +/** + * @brief Unpacks a single byte into 2 bytes, each containing a 4-bit value. + * @param unpacked Pointer to the destination memory (2 bytes). + * @param packed Pointer to the source memory (1 byte). + */ +TORCHAO_ALWAYS_INLINE inline void unpack_2_uint4_values( + uint8_t* unpacked, + const uint8_t* packed) { + // This is compatible with the scalar NEON version. + unpacked[0] = packed[0] >> 4; + unpacked[1] = packed[0] & 0x0F; +} + +/** + * @brief Packs 16 bytes (each a 4-bit value) into 8 bytes. + * @param packed Pointer to the destination memory (8 bytes). + * @param unpacked Pointer to the source memory (16 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_16_uint4_values` function (a transpose-and-pack operation) to + * ensure compatibility. It packs unpacked[i] and unpacked[i+8] into + * packed[i]. + */ +TORCHAO_ALWAYS_INLINE inline void pack_16_uint4_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 8; ++i) { + packed[i] = ((unpacked[i + 8] & 0x0F) << 4) | (unpacked[i] & 0x0F); + } +} + +/** + * @brief Unpacks 8 bytes into 16 bytes (each a 4-bit value). + * @param unpacked Pointer to the destination memory (16 bytes). + * @param packed Pointer to the source memory (8 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_16_uint4_values` function (an unpack-and-transpose operation) + * to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_16_uint4_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 8; ++i) { + unpacked[i] = packed[i] & 0x0F; + unpacked[i + 8] = packed[i] >> 4; + } +} + +/** + * @brief Packs 32 bytes (each a 4-bit value) into 16 bytes. + * @param packed Pointer to the destination memory (16 bytes). + * @param unpacked Pointer to the source memory (32 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_32_uint4_values` function (a transpose-and-pack operation) to + * ensure compatibility. It packs unpacked[i] and unpacked[i+16] into + * packed[i]. + */ +TORCHAO_ALWAYS_INLINE inline void pack_32_uint4_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 16; ++i) { + packed[i] = ((unpacked[i + 16] & 0x0F) << 4) | (unpacked[i] & 0x0F); + } +} + +/** + * @brief Unpacks 16 bytes into 32 bytes (each a 4-bit value). + * @param unpacked Pointer to the destination memory (32 bytes). + * @param packed Pointer to the source memory (16 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_32_uint4_values` function (an unpack-and-transpose operation) + * to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_32_uint4_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 16; ++i) { + unpacked[i] = packed[i] & 0x0F; + unpacked[i + 16] = packed[i] >> 4; + } +} +} // namespace internal +} // namespace torchao::kernels::cpu::fallback::bitpacking diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint5.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint5.h new file mode 100644 index 0000000000..3de577e05f --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint5.h @@ -0,0 +1,175 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { + +/** + * @brief Packs 8 bytes, each holding a 5-bit value (0-31), into 5 bytes. + * + * @param packed Pointer to the destination memory (5 bytes). + * @param unpacked Pointer to the source memory (8 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void pack_8_uint5_values( + uint8_t* packed, + const uint8_t* unpacked) { + // pack 8 uint5 values (u0..u7) into 5 bytes (p0..p4) + // p0 = u0_all | u1_low_3_bits + // p1 = u2_all | u3_low_3_bits + // p2 = u4_all | u5_low_3_bits + // p3 = u6_all | u7_low_3_bits + // p4 = u1_high_2_bits | u3_high_2_bits | u5_high_2_bits | u7_high_2_bits + packed[0] = (unpacked[0] & 0x1F) | ((unpacked[1] & 0x1F) << 5); + packed[1] = (unpacked[2] & 0x1F) | ((unpacked[3] & 0x1F) << 5); + packed[2] = (unpacked[4] & 0x1F) | ((unpacked[5] & 0x1F) << 5); + packed[3] = (unpacked[6] & 0x1F) | ((unpacked[7] & 0x1F) << 5); + packed[4] = ((unpacked[1] & 0x1F) >> 3) | (((unpacked[3] & 0x1F) >> 3) << 2) | + (((unpacked[5] & 0x1F) >> 3) << 4) | (((unpacked[7] & 0x1F) >> 3) << 6); +} + +/** + * @brief Unpacks 5 bytes into 8 bytes, each containing a 5-bit value. + * + * @param unpacked Pointer to the destination memory (8 bytes). + * @param packed Pointer to the source memory (5 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint5_values( + uint8_t* unpacked, + const uint8_t* packed) { + const uint8_t p0 = packed[0]; + const uint8_t p1 = packed[1]; + const uint8_t p2 = packed[2]; + const uint8_t p3 = packed[3]; + const uint8_t p4 = packed[4]; + + // This is compatible with the scalar NEON version. + unpacked[0] = p0 & 0x1F; + unpacked[1] = (p0 >> 5) | ((p4 & 0x03) << 3); + unpacked[2] = p1 & 0x1F; + unpacked[3] = (p1 >> 5) | ((p4 & 0x0C) << 1); + unpacked[4] = p2 & 0x1F; + unpacked[5] = (p2 >> 5) | ((p4 & 0x30) >> 1); + unpacked[6] = p3 & 0x1F; + unpacked[7] = (p3 >> 5) | ((p4 & 0xC0) >> 3); +} + +/** + * @brief Packs 64 bytes (each a 5-bit value) into 40 bytes. + * @param packed Pointer to the destination memory (40 bytes). + * @param unpacked Pointer to the source memory (64 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_64_uint5_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void pack_64_uint5_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Pack the first 32 bytes (p0, p1) + for (int i = 0; i < 16; ++i) { + packed[i] = (unpacked[i] & 0x1F) | ((unpacked[i + 16] & 0x1F) << 5); + packed[i + 16] = (unpacked[i + 32] & 0x1F) | ((unpacked[i + 48] & 0x1F) << 5); + } + + // Pack the final 8 bytes (p2) + for (int i = 0; i < 8; ++i) { + uint8_t val1 = (unpacked[16 + i] >> 3) & 0x03; + uint8_t val2 = (unpacked[24 + i] >> 3) & 0x03; + uint8_t val3 = (unpacked[48 + i] >> 3) & 0x03; + uint8_t val4 = (unpacked[56 + i] >> 3) & 0x03; + packed[32 + i] = val1 | (val2 << 2) | (val3 << 4) | (val4 << 6); + } +} + +/** + * @brief Unpacks 40 bytes into 64 bytes (each a 5-bit value). + * @param unpacked Pointer to the destination memory (64 bytes). + * @param packed Pointer to the source memory (40 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_64_uint5_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_64_uint5_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 16; ++i) { + const uint8_t p0 = packed[i]; + const uint8_t p1 = packed[i + 16]; + // p2 is only 8 bytes wide, so we use modulo to access it correctly. + const uint8_t p2 = packed[32 + (i % 8)]; + + unpacked[i] = p0 & 0x1F; + unpacked[i + 32] = p1 & 0x1F; + + if (i < 8) { + unpacked[i + 16] = (p0 >> 5) | ((p2 & 0x03) << 3); + unpacked[i + 48] = (p1 >> 5) | ((p2 & 0x30) >> 1); + } else { + unpacked[i + 16] = (p0 >> 5) | ((p2 & 0x0C) << 1); + unpacked[i + 48] = (p1 >> 5) | ((p2 & 0xC0) >> 3); + } + } +} + +/** + * @brief Packs 128 bytes (each a 5-bit value) into 80 bytes. + * @param packed Pointer to the destination memory (80 bytes). + * @param unpacked Pointer to the source memory (128 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_128_uint5_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void pack_128_uint5_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Pack the first 64 bytes (p0, p1, p2, p3) + for (int i = 0; i < 16; ++i) { + packed[i] = (unpacked[i] & 0x1F) | ((unpacked[i + 16] & 0x1F) << 5); + packed[i + 16] = (unpacked[i + 32] & 0x1F) | ((unpacked[i + 48] & 0x1F) << 5); + packed[i + 32] = (unpacked[i + 64] & 0x1F) | ((unpacked[i + 80] & 0x1F) << 5); + packed[i + 48] = (unpacked[i + 96] & 0x1F) | ((unpacked[i + 112] & 0x1F) << 5); + } + + // Pack the final 16 bytes (p4) + for (int i = 0; i < 16; ++i) { + uint8_t val1 = (unpacked[16 + i] >> 3) & 0x03; + uint8_t val2 = (unpacked[48 + i] >> 3) & 0x03; + uint8_t val3 = (unpacked[80 + i] >> 3) & 0x03; + uint8_t val4 = (unpacked[112 + i] >> 3) & 0x03; + packed[64 + i] = val1 | (val2 << 2) | (val3 << 4) | (val4 << 6); + } +} + +/** + * @brief Unpacks 80 bytes into 128 bytes (each a 5-bit value). + * @param unpacked Pointer to the destination memory (128 bytes). + * @param packed Pointer to the source memory (80 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_128_uint5_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_128_uint5_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 16; ++i) { + const uint8_t p0 = packed[i]; + const uint8_t p1 = packed[i + 16]; + const uint8_t p2 = packed[i + 32]; + const uint8_t p3 = packed[i + 48]; + const uint8_t p4 = packed[i + 64]; + + unpacked[i + 16 * 0] = p0 & 0x1F; + unpacked[i + 16 * 1] = (p0 >> 5) | ((p4 & 0x03) << 3); + unpacked[i + 16 * 2] = p1 & 0x1F; + unpacked[i + 16 * 3] = (p1 >> 5) | ((p4 & 0x0C) << 1); + unpacked[i + 16 * 4] = p2 & 0x1F; + unpacked[i + 16 * 5] = (p2 >> 5) | ((p4 & 0x30) >> 1); + unpacked[i + 16 * 6] = p3 & 0x1F; + unpacked[i + 16 * 7] = (p3 >> 5) | ((p4 & 0xC0) >> 3); + } +} + +}} diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint6.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint6.h new file mode 100644 index 0000000000..2fcd9334ec --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint6.h @@ -0,0 +1,142 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { + +/** + * @brief Packs 4 bytes, each holding a 6-bit value (0-63), into 3 bytes. + * + * @param packed Pointer to the destination memory (3 bytes). + * @param unpacked Pointer to the source memory (4 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values( + uint8_t* packed, + const uint8_t* unpacked) { + // pack 4 uint6 values (u0..u3) into 3 bytes (p0..p2) + // p0's low 6 bits = u0; p0's high 2 bits = u3's low 2 bits + // p1's low 6 bits = u1; p1's high 2 bits = u3's mid 2 bits + // p2's low 6 bits = u2; p2's high 2 bits = u3's high 2 bits + const uint8_t u3 = unpacked[3] & 0x3F; + packed[0] = (unpacked[0] & 0x3F) | ((u3 & 0x03) << 6); + packed[1] = (unpacked[1] & 0x3F) | ((u3 & 0x0C) << 4); + packed[2] = (unpacked[2] & 0x3F) | ((u3 & 0x30) << 2); +} + +/** + * @brief Unpacks 3 bytes into 4 bytes, each containing a 6-bit value. + * + * @param unpacked Pointer to the destination memory (4 bytes). + * @param packed Pointer to the source memory (3 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values( + uint8_t* unpacked, + const uint8_t* packed) { + // This is compatible with the scalar NEON version. + unpacked[0] = packed[0] & 0x3F; + unpacked[1] = packed[1] & 0x3F; + unpacked[2] = packed[2] & 0x3F; + unpacked[3] = ((packed[0] & 0xC0) >> 6) | ((packed[1] & 0xC0) >> 4) | + ((packed[2] & 0xC0) >> 2); +} + +/** + * @brief Packs 32 bytes (each a 6-bit value) into 24 bytes. + * @param packed Pointer to the destination memory (24 bytes). + * @param unpacked Pointer to the source memory (32 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_32_uint6_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void pack_32_uint6_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 8; ++i) { + const uint8_t u0 = unpacked[i]; + const uint8_t u1 = unpacked[i + 8]; + const uint8_t u2 = unpacked[i + 16]; + const uint8_t u3 = unpacked[i + 24]; + + packed[i] = (u0 & 0x3F) | ((u3 & 0x03) << 6); + packed[i + 8] = (u1 & 0x3F) | ((u3 & 0x0C) << 4); + packed[i + 16] = (u2 & 0x3F) | ((u3 & 0x30) << 2); + } +} + +/** + * @brief Unpacks 24 bytes into 32 bytes (each a 6-bit value). + * @param unpacked Pointer to the destination memory (32 bytes). + * @param packed Pointer to the source memory (24 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_32_uint6_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_32_uint6_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 8; ++i) { + const uint8_t p0 = packed[i]; + const uint8_t p1 = packed[i + 8]; + const uint8_t p2 = packed[i + 16]; + + unpacked[i] = p0 & 0x3F; + unpacked[i + 8] = p1 & 0x3F; + unpacked[i + 16] = p2 & 0x3F; + unpacked[i + 24] = + ((p0 & 0xC0) >> 6) | ((p1 & 0xC0) >> 4) | ((p2 & 0xC0) >> 2); + } +} + +/** + * @brief Packs 64 bytes (each a 6-bit value) into 48 bytes. + * @param packed Pointer to the destination memory (48 bytes). + * @param unpacked Pointer to the source memory (64 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_64_uint6_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void pack_64_uint6_values( + uint8_t* packed, + const uint8_t* unpacked) { + for (int i = 0; i < 16; ++i) { + const uint8_t u0 = unpacked[i]; + const uint8_t u1 = unpacked[i + 16]; + const uint8_t u2 = unpacked[i + 32]; + const uint8_t u3 = unpacked[i + 48]; + + packed[i] = (u0 & 0x3F) | ((u3 & 0x03) << 6); + packed[i + 16] = (u1 & 0x3F) | ((u3 & 0x0C) << 4); + packed[i + 32] = (u2 & 0x3F) | ((u3 & 0x30) << 2); + } +} + +/** + * @brief Unpacks 48 bytes into 64 bytes (each a 6-bit value). + * @param unpacked Pointer to the destination memory (64 bytes). + * @param packed Pointer to the source memory (48 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_64_uint6_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_64_uint6_values( + uint8_t* unpacked, + const uint8_t* packed) { + for (int i = 0; i < 16; ++i) { + const uint8_t p0 = packed[i]; + const uint8_t p1 = packed[i + 16]; + const uint8_t p2 = packed[i + 32]; + + unpacked[i] = p0 & 0x3F; + unpacked[i + 16] = p1 & 0x3F; + unpacked[i + 32] = p2 & 0x3F; + unpacked[i + 48] = + ((p0 & 0xC0) >> 6) | ((p1 & 0xC0) >> 4) | ((p2 & 0xC0) >> 2); + } +} + +} // namespace internal +} // namespace torchao::kernels::cpu::fallback::bitpacking diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint7.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint7.h new file mode 100644 index 0000000000..60493a20b2 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint7.h @@ -0,0 +1,140 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace torchao::kernels::cpu::fallback::bitpacking { +namespace internal { +/** + * @brief Packs 8 bytes, each holding a 7-bit value (0-127), into 7 bytes. + * + * @param packed Pointer to the destination memory (7 bytes). + * @param unpacked Pointer to the source memory (8 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void pack_8_uint7_values( + uint8_t* packed, + const uint8_t* unpacked) { + // pack 8 uint7 values (u0..u7) into 7 bytes (p0..p6) + // The 7 bits of u7 are distributed across the most significant bit (MSB) + // of each of the 7 packed bytes. + // p0 = u7_bit_0 | u0_all_7_bits + // p1 = u7_bit_1 | u1_all_7_bits + // ... + // p6 = u7_bit_6 | u6_all_7_bits + const uint8_t u7 = unpacked[7] & 0x7F; + + for (int i = 0; i < 7; ++i) { + uint8_t u7_bit = (u7 >> i) & 1; + packed[i] = (unpacked[i] & 0x7F) | (u7_bit << 7); + } +} + +/** + * @brief Unpacks 7 bytes into 8 bytes, each containing a 7-bit value. + * + * @param unpacked Pointer to the destination memory (8 bytes). + * @param packed Pointer to the source memory (7 bytes). + */ +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint7_values( + uint8_t* unpacked, + const uint8_t* packed) { + unpacked[7] = 0; + for (int i = 0; i < 7; ++i) { + // The low 7 bits of the packed byte are the original value. + unpacked[i] = packed[i] & 0x7F; + // The high bit of the packed byte is the i-th bit of the 8th value. + uint8_t u7_bit = packed[i] >> 7; + unpacked[7] |= (u7_bit << i); + } +} + +/** + * @brief Packs 64 bytes (each a 7-bit value) into 56 bytes. + * @param packed Pointer to the destination memory (56 bytes). + * @param unpacked Pointer to the source memory (64 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_64_uint7_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void pack_64_uint7_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Transpose-and-pack operation + for (int j = 0; j < 8; ++j) { // Iterate through columns + const uint8_t u7 = unpacked[56 + j] & 0x7F; + for (int i = 0; i < 7; ++i) { // Iterate through rows + uint8_t u7_bit = (u7 >> i) & 1; + packed[i * 8 + j] = (unpacked[i * 8 + j] & 0x7F) | (u7_bit << 7); + } + } +} + +/** + * @brief Unpacks 56 bytes into 64 bytes (each a 7-bit value). + * @param unpacked Pointer to the destination memory (64 bytes). + * @param packed Pointer to the source memory (56 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_64_uint7_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_64_uint7_values( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpack-and-transpose operation + for (int j = 0; j < 8; ++j) { // Iterate through columns + uint8_t u7 = 0; + for (int i = 0; i < 7; ++i) { // Iterate through rows + unpacked[i * 8 + j] = packed[i * 8 + j] & 0x7F; + u7 |= ((packed[i * 8 + j] >> 7) & 1) << i; + } + unpacked[56 + j] = u7; + } +} + +/** + * @brief Packs 128 bytes (each a 7-bit value) into 112 bytes. + * @param packed Pointer to the destination memory (112 bytes). + * @param unpacked Pointer to the source memory (128 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_pack_128_uint7_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void pack_128_uint7_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Transpose-and-pack operation + for (int j = 0; j < 16; ++j) { // Iterate through columns + const uint8_t u7 = unpacked[112 + j] & 0x7F; + for (int i = 0; i < 7; ++i) { // Iterate through rows + uint8_t u7_bit = (u7 >> i) & 1; + packed[i * 16 + j] = (unpacked[i * 16 + j] & 0x7F) | (u7_bit << 7); + } + } +} + +/** + * @brief Unpacks 112 bytes into 128 bytes (each a 7-bit value). + * @param unpacked Pointer to the destination memory (128 bytes). + * @param packed Pointer to the source memory (112 bytes). + * @note This implementation mirrors the logic of the ARM NEON + * `vec_unpack_128_uint7_values` function to ensure compatibility. + */ +TORCHAO_ALWAYS_INLINE inline void unpack_128_uint7_values( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpack-and-transpose operation + for (int j = 0; j < 16; ++j) { // Iterate through columns + uint8_t u7 = 0; + for (int i = 0; i < 7; ++i) { // Iterate through rows + unpacked[i * 16 + j] = packed[i * 16 + j] & 0x7F; + u7 |= ((packed[i * 16 + j] >> 7) & 1) << i; + } + unpacked[112 + j] = u7; + } +} + +} // namespace internal +} // namespace torchao::kernels::cpu::fallback::bitpacking diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/csrc/cpu/torch_free_kernels/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h similarity index 100% rename from torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h b/torchao/csrc/cpu/torch_free_kernels/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h similarity index 100% rename from torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/tests/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/CMakeLists.txt new file mode 100644 index 0000000000..eab4f9e54b --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/CMakeLists.txt @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_tests) + +set(TEST_TARGET_PREFIX "torchao_tests_torch_free_kernels_fallback_") + +enable_testing() + +add_executable(${TEST_TARGET_PREFIX}test_bitpacking test_bitpacking.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_bitpacking + PRIVATE + GTest::gtest_main +) + +include(GoogleTest) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_bitpacking) diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/tests/test_bitpacking.cpp b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/test_bitpacking.cpp new file mode 100644 index 0000000000..32177e63da --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/test_bitpacking.cpp @@ -0,0 +1,217 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. +// test pack with cpp unpack with arm_neon +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TEST(FallbackBitpackingTest, PackUnpack8_uint1) { + int unpacked_bytes = 8; + int packed_bytes = 1; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 1); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_8_uint1_values( + packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_8_uint1_values( + unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +TEST(FallbackBitpackingTest, PackUnpack4_uint2) { + int unpacked_bytes = 4; + int packed_bytes = 1; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 2); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_4_uint2_values( + packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_4_uint2_values( + unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +TEST(FallbackBitpackingTest, PackUnpack8_uint3) { + int unpacked_bytes = 8; + int packed_bytes = 3; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 3); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_8_uint3_values( + packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_8_uint3_values( + unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +TEST(FallbackBitpackingTest, PackUnpack32_uint4) { + int unpacked_bytes = 32; + int packed_bytes = 16; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 4); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint4_values( + packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_32_uint4_values( + unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +TEST(FallbackBitpackingTest, PackUnpack8_uint5) { + int unpacked_bytes = 8; + int packed_bytes = 5; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 5); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_8_uint5_values( + packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_8_uint5_values( + unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +TEST(FallbackBitpackingTest, PackUnpack4_uint6) { + int unpacked_bytes = 4; + int packed_bytes = 3; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 6); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_4_uint6_values( + packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_4_uint6_values( + unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +TEST(FallbackBitpackingTest, PackUnpack8_uint7) { + int unpacked_bytes = 8; + int packed_bytes = 7; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 7); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal::pack_8_uint7_values( + packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::unpack_8_uint7_values( + unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +// --- Template test for the main dispatcher function --- +template +void test_bitpacking_128_lowbit_values() { + const int unpacked_bytes = 128; + const int packed_bytes = unpacked_bytes * nbit / 8; + + auto input = torchao::get_random_signed_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes); + std::vector unpacked(unpacked_bytes); + + torchao::kernels::cpu::fallback::bitpacking::internal:: + pack_128_lowbit_int_values(packed.data(), input.data()); + torchao::kernels::cpu::fallback::bitpacking::internal:: + unpack_128_lowbit_int_values(unpacked.data(), packed.data()); + + ASSERT_EQ(input, unpacked); +} + +// --- Template test for the LUT dispatcher function --- +template +void test_bitpacking_128_lowbit_values_with_lut() { + const int unpacked_bytes = 128; + const int packed_bytes = unpacked_bytes * nbit / 8; + const int num_lut_entries = 1 << nbit; + + // 1. Create a LUT and random indices + auto lut = torchao::get_random_signed_lowbit_vector(num_lut_entries, 8); + auto indices = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + + // 2. Create the ground truth data by applying the LUT + std::vector ground_truth(unpacked_bytes); + for (int i = 0; i < unpacked_bytes; ++i) { + ground_truth[i] = lut[indices[i]]; + } + + // 3. Pack the indices + std::vector packed(packed_bytes); + if constexpr (nbit == 1) + torchao::kernels::cpu::fallback::bitpacking::internal:: + pack_128_uint1_values(packed.data(), indices.data()); + if constexpr (nbit == 2) { + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint2_values( + packed.data(), indices.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::pack_64_uint2_values( + packed.data() + 16, indices.data() + 64); + } + if constexpr (nbit == 3) + torchao::kernels::cpu::fallback::bitpacking::internal:: + pack_128_uint3_values(packed.data(), indices.data()); + if constexpr (nbit == 4) { + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint4_values( + packed.data(), indices.data()); + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint4_values( + packed.data() + 16, indices.data() + 32); + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint4_values( + packed.data() + 32, indices.data() + 64); + torchao::kernels::cpu::fallback::bitpacking::internal::pack_32_uint4_values( + packed.data() + 48, indices.data() + 96); + } + + // 4. Unpack using the LUT function + std::vector unpacked(unpacked_bytes); + torchao::kernels::cpu::fallback::bitpacking::internal:: + unpack_128_lowbit_values_with_lut( + unpacked.data(), packed.data(), lut.data()); + + // 5. Verify the result matches the ground truth + ASSERT_EQ(ground_truth, unpacked); +} + +// --- Instantiate all test cases using macros --- +#define TEST_BITPACKING_128_LOWBIT_VALUES(nbit) \ + TEST(GenericBitpacking128, Lowbit_##nbit) { \ + test_bitpacking_128_lowbit_values(); \ + } + +#define TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(nbit) \ + TEST(GenericBitpacking128, Lowbit_with_lut_##nbit) { \ + test_bitpacking_128_lowbit_values_with_lut(); \ + } + +TEST_BITPACKING_128_LOWBIT_VALUES(1); +TEST_BITPACKING_128_LOWBIT_VALUES(2); +TEST_BITPACKING_128_LOWBIT_VALUES(3); +TEST_BITPACKING_128_LOWBIT_VALUES(4); +TEST_BITPACKING_128_LOWBIT_VALUES(5); +TEST_BITPACKING_128_LOWBIT_VALUES(6); +TEST_BITPACKING_128_LOWBIT_VALUES(7); +TEST_BITPACKING_128_LOWBIT_VALUES(8); + +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(1); +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(2); +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(3); +TEST_BITPACKING_128_LOWBIT_VALUES_WITH_LUT(4); diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/csrc/cpu/torch_free_kernels/interface/quantized_matmul.h similarity index 94% rename from torchao/experimental/kernels/cpu/interface/quantized_matmul.h rename to torchao/csrc/cpu/torch_free_kernels/interface/quantized_matmul.h index 826fe9e85b..da3fd32747 100644 --- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h +++ b/torchao/csrc/cpu/torch_free_kernels/interface/quantized_matmul.h @@ -8,11 +8,11 @@ #include -#include -#include +#include +#include #if defined(__aarch64__) && defined(__ARM_NEON) -#include +#include #endif // defined(__aarch64__) && defined(__ARM_NEON) namespace torchao::kernels::cpu::quantized_matmul { diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/csrc/cpu/torch_free_kernels/interface/test_qmatmul_interface.cpp similarity index 99% rename from torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp rename to torchao/csrc/cpu/torch_free_kernels/interface/test_qmatmul_interface.cpp index 0fbe33ccdc..5ce1593732 100644 --- a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/interface/test_qmatmul_interface.cpp @@ -11,7 +11,7 @@ #include #include -#include +#include float kTol = 0.0001; diff --git a/torchao/experimental/kernels/cpu/aarch64/macro.h b/torchao/csrc/cpu/torch_free_kernels/macro.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/macro.h rename to torchao/csrc/cpu/torch_free_kernels/macro.h diff --git a/torchao/csrc/cpu/torch_free_kernels/test_utils.h b/torchao/csrc/cpu/torch_free_kernels/test_utils.h new file mode 100644 index 0000000000..29b72b51c0 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/test_utils.h @@ -0,0 +1,62 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include + +namespace torchao { +inline std::vector +get_random_vector(int size, float min = -1.0, float max = 1.0) { + assert(min < max); + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_real_distribution(min, max), rng); + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +inline std::vector get_random_lowbit_vector(int size, int nbit) { + assert(nbit >= 1); + assert(nbit <= 8); + + int min = 0; + int max = (1 << nbit) - 1; + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); + + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +inline std::vector get_random_signed_lowbit_vector(int size, int nbit) { + assert(nbit >= 1); + assert(nbit <= 8); + + int min = 0; + int max = (1 << nbit) - 1; + int offset = (1 << (nbit - 1)); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); + + std::vector res(size); + std::vector tmp(size); + std::generate(tmp.begin(), tmp.end(), std::ref(dist)); + for (int i = 0; i < size; i++) { + res[i] = tmp[i] - offset; + } + return res; +} +} // namespace torchao diff --git a/torchao/csrc/cuda/activation24/sparse_gemm.cu b/torchao/csrc/cuda/activation24/sparse_gemm.cu index f837bcc3aa..48b4c6e687 100644 --- a/torchao/csrc/cuda/activation24/sparse_gemm.cu +++ b/torchao/csrc/cuda/activation24/sparse_gemm.cu @@ -95,10 +95,10 @@ struct SparseRowwiseKernel { float, ElementOut, cutlass::layout::RowMajor, - 1, + 8, ElementOut, cutlass::layout::RowMajor, - 1, + 8, cutlass::epilogue::TmaWarpSpecializedCooperative, EpilogueEVT>::CollectiveOp; @@ -172,10 +172,10 @@ struct SparseRowwiseKernel { float, ElementOut, cutlass::layout::RowMajor, - 1, + 8, ElementOut, cutlass::layout::RowMajor, - 1, + 8, cutlass::epilogue::TmaWarpSpecializedCooperative, EpilogueEVT>::CollectiveOp; diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu b/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu new file mode 100644 index 0000000000..7546dc7b7b --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu @@ -0,0 +1,180 @@ +// CUDA bridge for MXFP8 quantization + +#include "mxfp8_quantize.cuh" +#include +#include +#include + + +namespace mxfp8 { + +// Convert PyTorch scalar type to our DType enum +DType get_input_dtype(const torch::Tensor &t) { + switch (t.scalar_type()) { + case torch::kFloat32: + return DType::kFloat32; + case torch::kFloat16: + return DType::kFloat16; + case torch::kBFloat16: + return DType::kBFloat16; + case torch::kUInt8: + return DType::kByte; + default: + TORCH_CHECK(false, "Unsupported input tensor dtype: ", t.scalar_type()); + } +} + +ScaleCalculationMode get_scaling_mode(const std::string &scaling_mode) { + if (scaling_mode.compare("floor") == 0) { + return ScaleCalculationMode::FLOOR; + } else if (scaling_mode.compare("rceil") == 0) { + return ScaleCalculationMode::RCEIL; + } else { + TORCH_CHECK(false, "Unsupported scaling mode: ", scaling_mode, ". Only ['floor', 'rceil'] are supported."); + } +} + +// Convert FP8 format string to DType enum +DType get_output_dtype(const std::string &fp8_format) { + if (fp8_format.compare("e4m3") == 0) { + return DType::kFloat8E4M3; + } else { + TORCH_CHECK(false, "Unsupported FP8 format: ", fp8_format, + ". Only 'e4m3' is supported."); + } +} + +void mxfp8_quantize_cuda(const torch::Tensor &input, + torch::Tensor &output_rowwise, + torch::Tensor &output_colwise, + torch::Tensor &scales_rowwise, + torch::Tensor &scales_colwise, + int64_t scale_dim_x, + int64_t scale_dim_y, + const std::string &fp8_format, + const std::string &scaling_mode) { + + // Get tensor properties + const int64_t rows = input.size(0); + const int64_t cols = input.size(1); + + // Get data pointers + const void *input_ptr = input.data_ptr(); + void *output_rowwise_ptr = + output_rowwise.numel() > 0 ? output_rowwise.data_ptr() : nullptr; + void *output_colwise_ptr = + output_colwise.numel() > 0 ? output_colwise.data_ptr() : nullptr; + e8m0_t *scales_rowwise_ptr = + scales_rowwise.numel() > 0 + ? reinterpret_cast(scales_rowwise.data_ptr()) + : nullptr; + e8m0_t *scales_colwise_ptr = + scales_colwise.numel() > 0 + ? reinterpret_cast(scales_colwise.data_ptr()) + : nullptr; + + // Get CUDA stream + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Get strides of scale ptrs + int64_t scale_rowwise_stride_dim0 = scales_rowwise.strides()[0]; + int64_t scale_rowwise_stride_dim1 = scales_rowwise.strides()[1]; + int64_t scale_colwise_stride_dim0 = scales_colwise.strides()[0]; + int64_t scale_colwise_stride_dim1 = scales_colwise.strides()[1]; + +#if defined(DEBUG) + printf("mxfp8_quantize_cuda:\n"); + printf("Quantizing input tensor of size %ld x %ld\n", rows, cols); + printf("scaling_mode: %s\n", scaling_mode.c_str()); + printf("Scale dim x: %ld\n", scale_dim_x); + printf("Scale dim y: %ld\n", scale_dim_y); + printf("Rowwise scale shape: %ld x %ld\n", scales_rowwise.sizes()[0], scales_rowwise.sizes()[1]); + printf("Colwise scale shape: %ld x %ld\n", scales_colwise.sizes()[0], scales_colwise.sizes()[1]); + printf("scale_rowwise_stride_dim0 = %ld\n", scale_rowwise_stride_dim0); + printf("scale_rowwise_stride_dim1 = %ld\n", scale_rowwise_stride_dim1); + printf("scale_colwise_stride_dim0 = %ld\n", scale_colwise_stride_dim0); + printf("scale_colwise_stride_dim1 = %ld\n", scale_colwise_stride_dim1); +#endif + + // Call the quantization kernel + MXFP8Quantizer::quantize(input_ptr, + output_rowwise_ptr, output_colwise_ptr, + scales_rowwise_ptr, scales_colwise_ptr, + scale_rowwise_stride_dim0, scale_rowwise_stride_dim1, + scale_colwise_stride_dim0, scale_colwise_stride_dim1, + rows, cols, + get_input_dtype(input), get_output_dtype(fp8_format), + scale_dim_x, scale_dim_y, + get_scaling_mode(scaling_mode), + stream); +} + +void mxfp8_quantize_3d_cuda(const torch::Tensor &input, + torch::Tensor &output_colwise, + torch::Tensor &scales_colwise, + int64_t scale_dim_n, + const std::string &fp8_format, + const std::string &scaling_mode) { + + // Get tensor properties for 3D tensor (E, N, K) + const int64_t E = input.size(0); + const int64_t N = input.size(1); + const int64_t K = input.size(2); + + // Get data pointers + const void *input_ptr = input.data_ptr(); + void *output_colwise_ptr = output_colwise.data_ptr(); + e8m0_t *scales_colwise_ptr = + reinterpret_cast(scales_colwise.data_ptr()); + + // Get CUDA stream + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Get strides of scales tensor + int64_t scales_colwise_stride_dim0 = scales_colwise.stride(0); + int64_t scales_colwise_stride_dim1 = scales_colwise.stride(1); + int64_t scales_colwise_stride_dim2 = scales_colwise.stride(2); + + // Get input tensor strides for generic layout support + int64_t input_stride_dim0 = input.stride(0); // E dimension stride + int64_t input_stride_dim1 = input.stride(1); // N dimension stride + int64_t input_stride_dim2 = input.stride(2); // K dimension stride + + // Get output tensor strides (shoudl be col major) + int64_t output_stride_dim0 = output_colwise.stride(0); // E dimension stride + int64_t output_stride_dim1 = output_colwise.stride(1); // N dimension stride + int64_t output_stride_dim2 = output_colwise.stride(2); // K dimension stride + + +#if defined(DEBUG) + printf("mxfp8_quantize_3d_cuda:\n"); + printf("Quantizing 3D input tensor of size %ld x %ld x %ld\n", E, N, K); + printf("scaling_mode: %s\n", scaling_mode.c_str()); + printf("Scale dim n: %ld\n", scale_dim_n); + printf("Output scale shape: %ld x %ld x %ld\n", + scales_colwise.sizes()[0], scales_colwise.sizes()[1], scales_colwise.sizes()[2]); + printf("scales_colwise_stride_dim0 = %ld\n", scales_colwise_stride_dim0); + printf("scales_colwise_stride_dim1 = %ld\n", scales_colwise_stride_dim1); + printf("input_stride_dim0 = %ld\n", input_stride_dim0); + printf("input_stride_dim1 = %ld\n", input_stride_dim1); + printf("input_stride_dim2 = %ld\n", input_stride_dim2); + printf("output_stride_dim0 = %ld\n", output_stride_dim0); + printf("output_stride_dim1 = %ld\n", output_stride_dim1); + printf("output_stride_dim2 = %ld\n", output_stride_dim2); +#endif + + // Call the 3D quantization kernel + MXFP8Quantizer::quantize_3d(input_ptr, + output_colwise_ptr, + scales_colwise_ptr, + E, N, K, + input_stride_dim0, input_stride_dim1, input_stride_dim2, + output_stride_dim0, output_stride_dim1, output_stride_dim2, + scales_colwise_stride_dim0, scales_colwise_stride_dim1, scales_colwise_stride_dim2, + get_input_dtype(input), get_output_dtype(fp8_format), + scale_dim_n, + get_scaling_mode(scaling_mode), + stream); +} + +} // namespace mxfp8 diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp new file mode 100644 index 0000000000..d445fcad4d --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp @@ -0,0 +1,195 @@ +// PyBind wrapping for the mxfp8 extension +#include +#include +#include +#include +#include + +namespace mxfp8 { + +// Forward declarations +void mxfp8_quantize_cuda(const torch::Tensor &input, + torch::Tensor &output_rowwise, + torch::Tensor &output_columnwise, + torch::Tensor &scales_rowwise, + torch::Tensor &scales_colwise, + int64_t scale_dim_x, + int64_t scale_dim_y, + const std::string &fp8_format, + const std::string &scaling_mode); + +void mxfp8_quantize_3d_cuda(const torch::Tensor &input, + torch::Tensor &output_colwise, + torch::Tensor &scales_colwise, + int64_t scale_dim_n, + const std::string &fp8_format, + const std::string &scaling_mode); + +// Helper for tensor validation +void check_cuda_tensor(const torch::Tensor &t, const char *name) { + TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(t.is_contiguous(), name, " must be contiguous"); +} + +// Helper to validate FP8 format +void validate_fp8_format(const std::string &fp8_format) { + TORCH_CHECK(fp8_format.compare("e4m3") == 0, + "fp8_format must be 'e4m3', got: ", fp8_format); +} + +// Helper to validate scale dimensions +void validate_scale_dimensions(int64_t scale_dim_x, int64_t scale_dim_y) { + TORCH_CHECK(scale_dim_x == 1 || scale_dim_x == 32, + "scale_dim_x must be 1 or 32, got: ", scale_dim_x); + TORCH_CHECK(scale_dim_y == 1 || scale_dim_y == 32, + "scale_dim_y must be 1 or 32, got: ", scale_dim_y); +} + +// Main quantization function +std::tuple +mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise, + int64_t scale_dim_x, int64_t scale_dim_y, + const std::string &fp8_format, + const std::string &scaling_mode) { + + // Validate inputs + TORCH_CHECK(!rowwise, "rowwise scaling is not supported yet"); + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(input.scalar_type() == torch::kFloat32 || + input.scalar_type() == torch::kFloat16 || + input.scalar_type() == torch::kBFloat16, + "Input must be float32, float16, or bfloat16"); + TORCH_CHECK(rowwise || colwise, + "At least one of rowwise or colwise must be true"); + + validate_scale_dimensions(scale_dim_x, scale_dim_y); + validate_fp8_format(fp8_format); + + const int64_t rows = input.size(0); + const int64_t cols = input.size(1); + TORCH_CHECK((rows >= 32) && (rows % 32 == 0), "rows must be a multiple of 32"); + TORCH_CHECK((cols >= 32) && (cols % 32 == 0), "cols must be a multiple of 32"); + + c10::cuda::CUDAGuard device_guard(input.device()); + + // Create tensor options + const auto options_fp8 = torch::TensorOptions() + .dtype(torch::kFloat8_e4m3fn) // FP8 stored as uint8 + .device(input.device()); + + const auto options_scale = torch::TensorOptions() + .dtype(torch::kFloat8_e8m0fnu) // E8M0 stored as uint8 + .device(input.device()); + + // Allocate output tensors + torch::Tensor output_rowwise, output_colwise; + torch::Tensor scales_rowwise, scales_colwise; + + if (rowwise) { + const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x; + output_rowwise = torch::empty({rows, cols}, options_fp8); + scales_rowwise = torch::empty({rows, num_col_blocks}, options_scale); + } else { + output_rowwise = torch::empty({0}, options_fp8); + scales_rowwise = torch::empty({0}, options_scale); + } + + if (colwise) { + const int64_t num_row_blocks = (rows + scale_dim_y - 1) / scale_dim_y; + output_colwise = torch::empty_strided({rows, cols}, {1, rows}, options_fp8); + // Need scales_colwise to be this shape so the 'col' dim stride is 1, + // for colwise scaling, we can avoid uncoalesced writes to global memory. + // This is because each of the 32 threads in a warp will be computing + // a scale for a different column of 32 input data values, then each writing + // that scale to global memory - so the stride along this `col` dim should be 1 + // so writes can be coalesced into a single transaction. + scales_colwise = torch::empty_strided({cols, num_row_blocks}, {1, cols}, options_scale); + } else { + output_colwise = torch::empty({0}, options_fp8); + scales_colwise = torch::empty({0}, options_scale); + } + + // Call CUDA kernels + mxfp8_quantize_cuda(input, + output_rowwise, output_colwise, + scales_rowwise, scales_colwise, + rowwise ? scale_dim_x : 1, // scale_dim_x + colwise ? scale_dim_y : 1, // scale_dim_y + fp8_format, scaling_mode); + + return std::make_tuple(output_rowwise, output_colwise, scales_rowwise, + scales_colwise); +} + +// 3D tensor quantization function +std::tuple +mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n, + const std::string &fp8_format, + const std::string &scaling_mode) { + + // Validate inputs + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + // Note: We don't check contiguous for 3D as it may have column major strides + TORCH_CHECK(input.dim() == 3, "input must be 3D"); + TORCH_CHECK(input.scalar_type() == torch::kFloat32 || + input.scalar_type() == torch::kFloat16 || + input.scalar_type() == torch::kBFloat16, + "Input must be float32, float16, or bfloat16"); + TORCH_CHECK(scale_dim_n == 32, "scale_dim_n must be 32 for now"); + + validate_fp8_format(fp8_format); + + const int64_t E = input.size(0); + const int64_t N = input.size(1); + const int64_t K = input.size(2); + + // Check dimensions are valid for 3D kernel + TORCH_CHECK((N >= 32) && (N % 32 == 0), "N must be a multiple of 32"); + TORCH_CHECK((K >= 32) && (K % 32 == 0), "K must be a multiple of 32"); + + + c10::cuda::CUDAGuard device_guard(input.device()); + + // Create tensor options + const auto options_fp8 = torch::TensorOptions() + .dtype(torch::kFloat8_e4m3fn) + .device(input.device()); + + const auto options_scale = torch::TensorOptions() + .dtype(torch::kFloat8_e8m0fnu) + .device(input.device()); + + // Create output tensor with column major layout (required for downstream ops) + torch::Tensor output_colwise = torch::empty_strided( + {E, N, K}, {N * K, 1, N}, options_fp8); + + // Create scales tensor with shape (E, num_n_blocks, K) + const int64_t num_n_blocks = (N + scale_dim_n - 1) / scale_dim_n; + torch::Tensor scales_colwise = torch::empty({E, num_n_blocks, K}, options_scale); + + // Call CUDA kernel + mxfp8_quantize_3d_cuda(input, output_colwise, scales_colwise, + scale_dim_n, fp8_format, scaling_mode); + + return std::make_tuple(output_colwise, scales_colwise); +} + +} // namespace mxfp8 + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "MXFP8 Quantization PyTorch Extension"; + + m.def("quantize", &mxfp8::mxfp8_quantize, "MXFP8 quantization", + py::arg("input"), py::arg("rowwise") = true, py::arg("colwise") = false, + py::arg("scale_dim_x") = 32, py::arg("scale_dim_y") = 32, + py::arg("fp8_format") = "e4m3", + py::arg("scaling_mode") = "floor"); + + m.def("quantize_3d", &mxfp8::mxfp8_quantize_3d, "MXFP8 3D quantization", + py::arg("input"), py::arg("scale_dim_n") = 32, + py::arg("fp8_format") = "e4m3", + py::arg("scaling_mode") = "floor"); +} diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh b/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh new file mode 100644 index 0000000000..fbaeb129d9 --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh @@ -0,0 +1,1445 @@ +// Adapted from https://github.com/NVIDIA/TransformerEngine +// License - Apache-2.0 +// https://github.com/NVIDIA/TransformerEngine/blob/main/LICENSE +// * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Portions (c) Meta Platforms, Inc. and affiliates. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Use official CUDA PTX library +#include "ptx.cuh" +#include +#include + +#define MIN_CUDA_SM 1000 // SM90 = 900, SM100 = 1000 + +// Check if we're compiling for supported architecture +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < MIN_CUDA_SM) +#warning \ + "MXFP8 quantization requires SM90+ (Hopper) or SM100+ (Blackwell) architecture. Kernel will be disabled for this architecture." +#endif + +// Architecture detection for native FP8 support +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 +#define HAS_NATIVE_FP8_CONVERSION 1 +#else +#define HAS_NATIVE_FP8_CONVERSION 0 +#endif + +// Macro to check CUDA error. +#define CUDA_CHECK(call) \ +do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA Error in %s at line %d: %s\n", \ + __FILE__, __LINE__, cudaGetErrorString(err)); \ + throw std::runtime_error(cudaGetErrorString(err)); \ + } \ +} while (0) + +enum class DType { + kByte, + kFloat32, + kFloat16, + kBFloat16, + kFloat8E4M3, + kFloat8E5M2 +}; + +enum class ScaleCalculationMode { + FLOOR, // uses software scaling + RCEIL, // uses hardware scaling +}; + +// Data types +using e8m0_t = uint8_t; +using bfloat16 = nv_bfloat16; +using fp8e4m3 = __nv_fp8_e4m3; + +constexpr size_t get_dtype_bits(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return 32; + case DType::kBFloat16: + return 16; + case DType::kFloat8E4M3: + return 8; + default: + // TODO: something smarter than this + return 0; + } +} + +// FP32 constants +constexpr int32_t FP32_MANTISSA_BITS = 23; +constexpr int32_t FP32_EXPONENT_BIAS = 127; + +// BF16 constants +constexpr int32_t BF16_MANTISSA_BITS = 7; +constexpr int32_t BF16_EXPONENT_BIAS = 127; + +// FP8E4M3 constants +constexpr int32_t F8E4M3_MAX_POW2 = 8; +constexpr float F8E4M3_MAX = 448.0; + +// FP8E8M0 constants +constexpr int32_t E8M0_EXPONENT_BIAS = 127; + +// 1. Base template (for unsupported types) +template struct DataTypeTraits { + static constexpr bool is_supported = false; +}; + +// 2. Specialization for float32 +template <> struct DataTypeTraits { + static constexpr bool is_supported = true; + static constexpr int mantissa_bits = 23; + static constexpr int exponent_bias = 127; + + __device__ static __forceinline__ float to_float(const float val) { + return val; + } +}; + +// 3. Specialization for bfloat16 +template <> struct DataTypeTraits { + static constexpr bool is_supported = true; + static constexpr int mantissa_bits = 7; + static constexpr int exponent_bias = 127; + + __device__ static __forceinline__ float to_float(const nv_bfloat16 val) { + return __bfloat162float(val); + } +}; + +__device__ static __forceinline__ e8m0_t +calculate_e8m0_biased_scale(const float amax) { + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L239 + const int32_t int_amax = *reinterpret_cast(&amax); + const int32_t extracted_pow2 = + ((int_amax >> FP32_MANTISSA_BITS) & 0b11111111) - FP32_EXPONENT_BIAS; + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L244 + int32_t scale_unbiased = extracted_pow2 - F8E4M3_MAX_POW2; + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L256 + scale_unbiased = max(scale_unbiased, -E8M0_EXPONENT_BIAS); + scale_unbiased = min(scale_unbiased, E8M0_EXPONENT_BIAS + 1); + int32_t scale_with_e8m0_bias = scale_unbiased + E8M0_EXPONENT_BIAS; + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L261C9-L261C26 + const e8m0_t e8m0_biased_scale = + *reinterpret_cast(&scale_with_e8m0_bias); + return e8m0_biased_scale; +} + +// Constants for MXFP8 kernel +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = + MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; // 1 * 1 = 1 +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 64/16 = 4 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 64 / 4 = 16 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 +constexpr size_t MXFP8_ITERATIONS = + MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); + +constexpr size_t THREADS_PER_WARP = 32; // lol + +// Utility macros +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) + +// Vector type for loading/storing multiple elements +template struct Vec { + union { + T elt[N]; + } data; + + __device__ inline void clear() { +#pragma unroll + for (int i = 0; i < N; ++i) { + data.elt[i] = T(0); + } + } + + __device__ inline void load_from(const T *ptr) { +#pragma unroll + for (int i = 0; i < N; ++i) { + data.elt[i] = ptr[i]; + } + } + + __device__ inline void store_to(T *ptr) const { +#pragma unroll + for (int i = 0; i < N; ++i) { + ptr[i] = data.elt[i]; + } + } +}; + +// Source: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L971 +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) + ? 1 + : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + +// Source: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L937 +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile("{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && + !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + +// Quantization limits +// Source: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L929 +template struct Quantized_Limits { + static constexpr float max_norm = 448.0f; // For E4M3 + static constexpr float max_norm_rcp = 1.0f / max_norm; +}; + +// Warp reduction utilities +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L867 +/** + * Max reduction in subwarps + * E.g., if nvec=4, each warp processes 128 elements (32 x 4), that covers four + * MXFP8 scaling factors. To compute an actual scaling factor for 32 + * consequentive elements, only 8 threads need to participate, thus splitting + * the warp into 4x smaller subwarps 8-thread width. 'Butterfly' reduction is + * used inside subwarps. + */ +template +__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = subwarp_width / 2; offset > 0; offset /= 2) { + const float val_other = + __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp + // lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width); + return val_tmp; +} + +// Source: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L813C1-L824C2 +template +__device__ __forceinline__ float warp_reduce_max(const float m) { + float tmp = m; +#pragma unroll + for (int delta = num_elems / 2; delta > 0; delta /= 2) { + const float other_m = __shfl_down_sync(0xFFFFFFFF, tmp, delta); + __builtin_assume(tmp >= 0); + __builtin_assume(other_m >= 0); + tmp = fmaxf(tmp, other_m); + } + return tmp; +} + +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/utils.cuh#L841C1-L857C2 +template +__device__ __forceinline__ compute_t reduce_max(const compute_t m, + const int warpid) { + __shared__ float staging[num_warps]; + constexpr int warp_size = 32; + const float my_max = m; + const float my_warp_max = warp_reduce_max(my_max); + if (threadIdx.x % 32 == 0) { + staging[warpid] = my_warp_max; + } + __syncthreads(); + compute_t result = 0.f; + if (warpid == 0) { + const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0; + result = warp_reduce_max(my_max); + } + return result; +} + +// https://stackoverflow.com/a/51549250 +// TODO: handle -0 case +__device__ __forceinline__ float atomicMaxFloat(float *addr, float value) { + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int *)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int *)addr, __float_as_uint(value))); + + return old; +} + +// TMA descriptor creation +inline CUtensorMapDataType get_dtype_for_tma(DType dtype) { + switch (dtype) { + case DType::kFloat32: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + case DType::kFloat16: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + case DType::kBFloat16: + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case DType::kFloat8E4M3: + case DType::kFloat8E5M2: + case DType::kByte: + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + default: + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } +} + +void* get_driver_ptr() { + // Only initialize driver_ptr once during the lifetime of the program. + static void *driver_ptr = nullptr; + if (!driver_ptr) { + cudaDriverEntryPointQueryResult result; + CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, + cudaEnableDefault, &result)); + } + return driver_ptr; +} + +inline void create_3D_tensor_map_output(CUtensorMap &tensorMap, + void *data_ptr, + DType dtype, + const size_t E, + const size_t N, + const size_t K, + uint32_t shmem_e, + uint32_t shmem_n, + uint32_t shmem_k, + const size_t type_num_bits) { + // Get function pointer to cuTensorMapEncodeTiled + void *driver_ptr = get_driver_ptr(); + auto cuTensorMapEncodeTiled = + reinterpret_cast(driver_ptr); + + + // Rank of the tensor is 3 + constexpr uint32_t rank = 3; + + // Dimensions must be ordered from fastest to slowest moving dimension. + // Given shape (E, N, K) and strides (N * K, 1, N), the order is N, K, E. + uint64_t size[rank] = {N, K, E}; + + // The stride array has rank-1 elements. + // stride[0] = byte stride for the second-fastest dimension (K). + // stride[1] = byte stride for the third-fastest dimension (E). + const size_t bytes_per_elem = type_num_bits / 8; + uint64_t stride[rank - 1] = { + N * bytes_per_elem, // Stride for K dim: N elements * bytes/element + N * K * bytes_per_elem}; // Stride for E dim: N*K elements * bytes/element + + // Box dimensions (tile size) must follow the same fastest-to-slowest order. + uint32_t boxSize[rank] = {shmem_n, shmem_k, shmem_e}; + + // Element strides within the tile (box). For a contiguous copy, this is always 1. + uint32_t elemStride[rank] = {1, 1, 1}; + + cuTensorMapEncodeTiled( + &tensorMap, get_dtype_for_tma(dtype), rank, data_ptr, size, stride, + boxSize, elemStride, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); +} + +// Reference: +// https://github.com/NVIDIA/TransformerEngine/blob/1ae1d228d725a488621deba685bd26d6ee1cdb21/transformer_engine/common/common.cu#L137 +// This was modified to make it compatible with our implementation and avoid +// using internal TE types. +inline void create_2D_tensor_map(CUtensorMap &tensorMap, void *data_ptr, + DType dtype, const size_t rows, + const size_t cols, uint32_t shmem_y, + uint32_t shmem_x, const size_t stride_elems, + const size_t type_num_bits) { + // Get function pointer to cuTensorMapEncodeTiled + void *driver_ptr = get_driver_ptr(); + auto cuTensorMapEncodeTiled = + reinterpret_cast(driver_ptr); + + constexpr uint32_t rank = 2; + uint64_t size[rank] = {cols, rows}; + uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / + 8}; // (cols * bits per element) / 8 + uint32_t boxSize[rank] = {shmem_x, shmem_y}; + uint32_t elemStride[rank] = {1, 1}; + + cuTensorMapEncodeTiled( + &tensorMap, get_dtype_for_tma(dtype), rank, data_ptr, size, stride, + boxSize, elemStride, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); +} + + +// Helper functions for TMA operations +__device__ inline void copy_2d_to_shared(void *smem, + const CUtensorMap *tensor_map, + uint32_t x, uint32_t y, + size_t smem_size, uint64_t *mbar, + bool is_master) { + if (is_master) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(smem), + reinterpret_cast(tensor_map), x, y, mbar); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(mbar, smem_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(mbar); + } +} + +//////////////////////////////////////////////////////////////////////////////// +// TorchAO shared quantization utils +//////////////////////////////////////////////////////////////////////////////// + +/** + * Convert e8m0 biased scale to float32 scale following torchao implementation + * torchao ref: + * https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L275C1-L277C30 + */ +__device__ __forceinline__ float e8m0_to_scale_fp32(e8m0_t e8m0_biased_scale) { + int32_t exponent_as_int32 = static_cast(e8m0_biased_scale); + int32_t float_bits = exponent_as_int32 << FP32_MANTISSA_BITS; + float scale_fp32 = *reinterpret_cast(&float_bits); + + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L286 + const float F32_MIN_NORMAL = exp2f(-FP32_EXPONENT_BIAS + 1); + scale_fp32 = max(scale_fp32, F32_MIN_NORMAL); + + return scale_fp32; +} + +/** + * Quantize a single value using torchao-style clamping and conversion + * torchao ref: + * https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L289 + */ +template +__device__ __forceinline__ OType torchao_quantize_value(float input_value, + float inv_scale_fp32) { + // Scale the input value + float data_lp = input_value * inv_scale_fp32; + + // Apply torchao-style clamping + // torchao ref: + // https://github.com/pytorch/ao/blob/00417b8b33abb75c54cdb347bd320fb6ac0a4d94/torchao/prototype/mx_formats/mx_tensor.py#L301C23-L301C74 + data_lp = min(data_lp, F8E4M3_MAX); + data_lp = max(data_lp, -F8E4M3_MAX); + + return static_cast(data_lp); +} + +/** + * Complete torchao-style quantization: calculate scale and convert values + * Template parameters ensure compile-time array size checking for safety + */ +template +__device__ __forceinline__ void +quantize_block(float amax, e8m0_t &out_scale, + const float (&input_values)[NUM_VALUES], + OType (&output_values)[NUM_VALUES]) { + + float inv_scale_fp32; + if constexpr (ScalingMode == ScaleCalculationMode::FLOOR) { + // FLOOR scaling. + out_scale = calculate_e8m0_biased_scale(amax); + + // Convert scale to float32 + float scale_fp32 = e8m0_to_scale_fp32(out_scale); + + // Calculate inverse scale for fast multiplication + inv_scale_fp32 = __fdiv_rn(1.0f, scale_fp32); + + // Quantize all values +#pragma unroll + for (int i = 0; i < NUM_VALUES; ++i) { + output_values[i] = + torchao_quantize_value(input_values[i], inv_scale_fp32); + } + + } else { + // RCEIL scaling. + out_scale = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + inv_scale_fp32 = exp2f_rcp(out_scale); + +#pragma unroll + for (int i = 0; i < NUM_VALUES; ++i) { + output_values[i] = + static_cast(input_values[i] * inv_scale_fp32); + } + } + +} + +/** + * Bounds checking helper for IMA avoidance + */ +struct BoundsChecker { + const size_t rows, cols; + const size_t chunk_offset_X, chunk_offset_Y; + + __device__ __forceinline__ BoundsChecker(size_t r, size_t c, size_t cox, + size_t coy) + : rows(r), cols(c), chunk_offset_X(cox), chunk_offset_Y(coy) {} + + __device__ __forceinline__ bool is_out_of_bounds(size_t row, + size_t col) const { + return (row >= rows) || (col >= cols); + } + + __device__ __forceinline__ bool + is_rowwise_out_of_bounds(size_t shmem_y, size_t shmem_x, int j, + size_t row_base) const { + const size_t row = row_base + shmem_y; + const size_t col = chunk_offset_X + shmem_x + j; + return is_out_of_bounds(row, col); + } + + __device__ __forceinline__ bool + is_colwise_out_of_bounds(size_t row_offset, size_t col, + size_t row_base) const { + const size_t row = row_base + row_offset; + return is_out_of_bounds(row, col); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// MXFP8 quantization kernel +//////////////////////////////////////////////////////////////////////////////// + +// Main MXFP8 quantization kernel (with TMA) +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + mxfp8_quantize_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, + const size_t scales_rowwise_stride_dim0, + const size_t scales_rowwise_stride_dim1, + const size_t scales_colwise_stride_dim0, + const size_t scales_colwise_stride_dim1) { + +#if defined(DEBUG) + printf("mxfp8_quantize_kernel: rows=%llu, cols=%llu, " + "scales_rowwise_stride_dim0=%llu, scales_rowwise_stride_dim1=%llu, " + "scales_colwise_stride_dim0=%llu, scales_colwise_stride_dim1=%llu\n", + (unsigned long long)rows, (unsigned long long)cols, + (unsigned long long)scales_rowwise_stride_dim0, + (unsigned long long)scales_rowwise_stride_dim1, + (unsigned long long)scales_colwise_stride_dim0, + (unsigned long long)scales_colwise_stride_dim1); + + if (ScalingMode == ScaleCalculationMode::FLOOR) { + printf("mxfp8_quantize_kernel: scaling_mode: floor\n"); + } else if (ScalingMode == ScaleCalculationMode::RCEIL) { + printf("mxfp8_quantize_kernel: scaling_mode: rceil\n"); + } else { + printf("mxfp8_quanitze_kenrel: unknown scaling mode\n"); + } +#endif + + + static_assert(DataTypeTraits::is_supported, + "Input data type is not supported by this kernel."); + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int block_offset_Y = + blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = + blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int scales_rowwise_block_offset_X = + blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = + blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = + blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + + // The destination shared memory buffer of a bulk tensor operation should be + // 128 e8m0_t aligned + __shared__ alignas(128) + IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_X][MXFP8_SHMEM_DIM_Y]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + +// Initialize shared memory barrier with the number of threads participating in +// the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers( + mbar, is_master_thread); + + int parity = 0; + +// Process chunks +#pragma unroll + // Calculate chunk offsets + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +// Prefetch initial data +#pragma unroll + // Kick off TMA async copy from global to shared memory + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; + ++prefetch_buff) { + const int chunk_stage_offset_Y = + chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, + chunk_stage_offset_X, chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + +// Process iterations +#pragma unroll + // Iterate through the chunk along the Y dim + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + // Prefetch next iteration data + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = + chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#if defined(DEBUG_SMEM) + // Debugging smem data + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) { + printf("Shared memory values:\n"); + for (int b = 0; b < MXFP8_BUFFERS_NUM; b++) { + for (int y = 0; y < MXFP8_SHMEM_DIM_Y; y++) { + for (int x = 0; x < MXFP8_SHMEM_DIM_X; x++) { + printf("in_sh[%d][%d][%d] = %f\n", b, y, x, + DataTypeTraits::to_float(in_sh[b][y][x])); + } + } + } + } +#endif + + // ======== RowWise SCALING ======== + + // Updated Row-wise scaling section: + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out_c; + + // Create bounds checker for this chunk + BoundsChecker bounds(rows, cols, chunk_offset_X, chunk_offset_Y); + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + // Load from shared memory into thread local registers + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + + // Calculate thread-local amax and prepare input values +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool out_of_bounds = bounds.is_rowwise_out_of_bounds( + shmem_offset_y, shmem_offset_x, j, row_base); + + // Load and convert to float + float elt = DataTypeTraits::to_float(in.data.elt[j]); + in_compute[j] = elt; + + // Update thread local amax + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + // Update block local amax + block_amax = fmaxf(block_amax, thread_amax); + + // Reduce amax across subwarp + const float subwarp_amax = + subwarp_reduce_max_broadcast(thread_amax); + + + // Apply quantization to the local block. + e8m0_t e8m0_biased_scale; + OType quantized_values[ELEMS_PER_THREAD]; + + quantize_block( + subwarp_amax, e8m0_biased_scale, in_compute, quantized_values); + + // Write scaling factor (only a single thread writes it to global + // memory) + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scales_rowwise_stride_dim0 + + global_scales_offset_X; + scales_rowwise[scale_idx] = e8m0_biased_scale; + } + + // Store quantized values +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = quantized_values[j]; + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + +#if defined(DEBUG) + if (tid_rowwise_X == 0 && tid_rowwise_Y == 0) { + printf("Rowwise: subwarp_amax=%f, e8m0_scale=%u\n", subwarp_amax, e8m0_biased_scale); + } +#endif + + } + } + // ======== End RowWise SCALING ======== + + // ======== ColWise SCALING ======== + // Column-wise scaling + + if constexpr (USE_COLWISE_SCALING) { + // Create bounds checker for this chunk + BoundsChecker bounds(rows, cols, chunk_offset_X, chunk_offset_Y); + + const size_t col = chunk_offset_X + tid_colwise_X; + const bool col_out_of_bounds = (col >= cols); + + float in_compute[SCALE_DIM_Y]; + float amax = 0; + + // Calculate amax and prepare input values +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const bool out_of_bounds = + bounds.is_colwise_out_of_bounds(i, col, row_base); + + // Load and convert to float + float elt = + DataTypeTraits::to_float(in_sh[buff][i][tid_colwise_X]); + in_compute[i] = elt; + + // Update thread local amax + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + // Apply quantization to the local block. + e8m0_t e8m0_biased_scale; + OType quantized_values[SCALE_DIM_Y]; + quantize_block( + amax, e8m0_biased_scale, in_compute, quantized_values); + + // Write scaling factor to global memory + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = + scales_colwise_chunk_offset_X + tid_colwise_X; + + // Write scale in column major memory layout, shape (cols, num_row_blocks, 1). + // Stride along `cols` dim must be 1, for coalesced writes to global memory. + const int scale_idx = + global_scales_offset_Y * scales_colwise_stride_dim1 + + global_scales_offset_X * scales_colwise_stride_dim0; + + // Bounds check for scale writing + const bool row_out_of_bounds = (row_base >= rows); + if (!row_out_of_bounds && !col_out_of_bounds) { + scales_colwise[scale_idx] = e8m0_biased_scale; + } + + // Store quantized values to shared memory +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][tid_colwise_X][i] = quantized_values[i]; + } + +#if defined(DEBUG) + if (tid_colwise_X == 0) { + printf("Colwise: amax=%f, e8m0_scale=%u\n", amax, e8m0_biased_scale); + } +#endif + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + if constexpr (USE_ROWWISE_SCALING) { + const int chunk_it_offset_y = + chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + // Swap logical destination offsets for TMA to write into column major layout. + const int chunk_it_offset_y = chunk_offset_X; + const int chunk_it_offset_x = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), + chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + destroy_barriers(mbar, is_master_thread); + // #endif +} + +// 3D MXFP8 quantization kernel using 2D TMA +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + mxfp8_quantize_kernel_3d( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + e8m0_t *const scales_colwise, + const size_t E, const size_t N, const size_t K, + const size_t scales_colwise_stride_dim0, + const size_t scales_colwise_stride_dim1, + const size_t scales_colwise_stride_dim2) { + + static_assert(DataTypeTraits::is_supported, + "Input data type is not supported by this kernel."); + + // Only support colwise scaling for 3D case + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + static_assert(USE_COLWISE_SCALING, "3D kernel only supports colwise scaling"); + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + const int block_offset_Y = + blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_colwise_block_offset_Y = + blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = + blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + const int expert_idx = blockIdx.z; + const int expert_logical_base_row = expert_idx * N; + + // The destination shared memory buffer of a bulk tensor operation should be + // 128 e8m0_t aligned + __shared__ alignas(128) + IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + + // SMEM buffer for expert must be 3d since we use cp async bulk tensor 3d ptx instruction. + // We parallelize across experts, so leading "E" dim will always be 1 for single expert. + constexpr size_t smem_e = 1; + __shared__ alignas(128) OType + out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_X][MXFP8_SHMEM_DIM_Y][smem_e]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in +// the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers( + mbar, is_master_thread); + + int parity = 0; + +// Process chunks +#pragma unroll + // Calculate chunk offsets + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +// Prefetch initial data +#pragma unroll + // Kick off TMA async copy from global to shared memory + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; + ++prefetch_buff) { + const int chunk_stage_offset_Y = + chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + + // Calculate TMA coordinates for using 2D descriptor to read 3D input data + const int tma_x_offset = chunk_stage_offset_X; + const int tma_y_offset = expert_logical_base_row + chunk_stage_offset_Y; + + copy_2d_to_shared(&in_sh[prefetch_buff], + &tensor_map_input, + tma_x_offset, + tma_y_offset, + shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + +// Process iterations +#pragma unroll + // Iterate through the chunk along the Y dim + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = expert_logical_base_row + chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + // Prefetch next iteration data + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = + chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // Calculate TMA coordinates for using 2D descriptor to read 3D input data + const int tma_x_offset = chunk_it_offset_x; + const int tma_y_offset = expert_logical_base_row + chunk_it_offset_y; + + copy_2d_to_shared(&in_sh[next_buff], + &tensor_map_input, + tma_x_offset, + tma_y_offset, + shmem_buff_size, + &mbar[next_iter], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + + // ======== 3d tensor column-wise scaling + + // Create bounds checker for this chunk - using the full tensor dimensions (E*N, K) + BoundsChecker bounds(E * N, K, chunk_offset_X, chunk_offset_Y); + + const size_t col = chunk_offset_X + tid_colwise_X; + const bool col_out_of_bounds = (col >= K); + + float in_compute[SCALE_DIM_Y]; + float amax = 0; + + // Calculate amax and prepare input values +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const bool out_of_bounds = + bounds.is_colwise_out_of_bounds(i, col, row_base); + + // Load and convert to float + float elt = + DataTypeTraits::to_float(in_sh[buff][i][tid_colwise_X]); + in_compute[i] = elt; + + // Update thread local amax + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + // Apply quantization to the local block. + e8m0_t e8m0_biased_scale; + OType quantized_values[SCALE_DIM_Y]; + quantize_block( + amax, e8m0_biased_scale, in_compute, quantized_values); + + // Write scaling factor to global memory + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = + scales_colwise_chunk_offset_X + tid_colwise_X; + + // Calculate scale offset using expert base offset plus local scale offset. + const int expert_scale_base_offset = expert_idx * scales_colwise_stride_dim0; + const int scale_idx = expert_scale_base_offset + + global_scales_offset_Y * scales_colwise_stride_dim1 + + global_scales_offset_X * scales_colwise_stride_dim2; + + // Bounds check for scale writing + const bool row_out_of_bounds = (row_base >= E * N); + if (!row_out_of_bounds && !col_out_of_bounds) { + scales_colwise[scale_idx] = e8m0_biased_scale; + } + + // Store quantized values to shared memory. + // SHMEM E dim is 1 since we parallelize across experts, so always index 0. + const int shmem_e_idx = 0; +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][tid_colwise_X][i][shmem_e_idx] = quantized_values[i]; + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + // For per expert col major, + const int output_tma_x_offset = chunk_offset_X; + const int output_tma_y_offset = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + // Pass in TMA offsets in the same order as the tensor map exists (N, K, E) + // which is fastest moving dim (stride 1) -> slowest moving. + cuda::device::experimental::cp_async_bulk_tensor_3d_shared_to_global( + &tensor_map_output, + output_tma_y_offset, // N + output_tma_x_offset, // K + expert_idx, // E + reinterpret_cast(&out_colwise_sh[buff])); + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + destroy_barriers(mbar, is_master_thread); + // #endif +} + +// Simple wrapper class for MXFP8 quantization +class MXFP8Quantizer { +public: + // Quantize a 2D tensor using MXFP8 + // input: pointer to input data + // output_rowwise: pointer to row-wise quantized output (can be nullptr) + // output_colwise: pointer to column-wise quantized output (can be nullptr) + // scales_rowwise: pointer to row-wise scaling factors (required if + // output_rowwise is not null) scales_colwise: pointer to column-wise scaling factors (required if output_colwise is not null) + // rows, cols: tensor dimensions + // input_dtype: data type of input + // output_dtype: FP8 output type (fp8e4m3 or fp8e5m2) + // scale_dim_x: block size for row-wise scaling (typically 32) + // scale_dim_y: block size for column-wise scaling (typically 32) + static void + quantize(const void *input, void *output_rowwise, void *output_colwise, + e8m0_t *scales_rowwise, e8m0_t *scales_colwise, + size_t scales_rowwise_stride_dim0, size_t scales_rowwise_stride_dim1, + size_t scales_colwise_stride_dim0, size_t scales_colwise_stride_dim1, + size_t rows, size_t cols, DType input_dtype, DType output_dtype, + size_t scale_dim_x = 32, size_t scale_dim_y = 32, + ScaleCalculationMode scaling_mode = ScaleCalculationMode::FLOOR, + cudaStream_t stream = 0) { + + // Check parameters + assert((scale_dim_x == 1 || scale_dim_x == 32) && + (scale_dim_y == 1 || scale_dim_y == 32)); + assert(output_rowwise != nullptr || output_colwise != nullptr); + + if (output_rowwise) + assert(scales_rowwise != nullptr); + if (output_colwise) + assert(scales_colwise != nullptr); + + // Calculate grid dimensions + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + // Create TMA descriptors + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + int32_t input_bits_per_elem = get_dtype_bits(input_dtype); + int32_t output_bits_per_elem = get_dtype_bits(output_dtype); + + create_2D_tensor_map(tensor_map_input, const_cast(input), + input_dtype, + rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, + cols, // stride of "slowest moving" dim + input_bits_per_elem); // bits per elem in input + + if (output_rowwise) { + create_2D_tensor_map( + tensor_map_output_rowwise, output_rowwise, output_dtype, + rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, + cols, // stride of "slowest moving" dim + output_bits_per_elem); // bits per elem in output fp8e4m3 + } + + if (output_colwise) { + create_2D_tensor_map( + tensor_map_output_colwise, output_colwise, output_dtype, + cols, rows, // Swap for column major layout + MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, + rows, // stride of "slowest moving" dim + output_bits_per_elem); // bits per elem in output fp8e4m3 + } + +// Launch kernel based on input/output types and scaling dimensions +// Only compile kernel launches for SM90+ +#if defined(__CUDACC__) && \ + (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= MIN_CUDA_SM) + + // Use TMA and mbarrier instructions +#define LAUNCH_KERNEL(IType, OType, SCALE_Y, SCALE_X, ScalingMode) \ + mxfp8_quantize_kernel \ + <<>>( \ + tensor_map_input, tensor_map_output_rowwise, \ + tensor_map_output_colwise, scales_rowwise, scales_colwise, rows, \ + cols, scales_rowwise_stride_dim0, scales_rowwise_stride_dim1, \ + scales_colwise_stride_dim0, scales_colwise_stride_dim1); + + // Validate output dtype. + if (output_dtype != DType::kFloat8E4M3) { + printf("unsupported output dtype, must be fp8e4m3\n"); + exit(1); + } + + if (scaling_mode == ScaleCalculationMode::FLOOR) { + if (input_dtype == DType::kFloat32) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(float, fp8e4m3, 1, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } + } else if (input_dtype == DType::kBFloat16) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 1, 32, ScaleCalculationMode::FLOOR); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else if (scaling_mode == ScaleCalculationMode::RCEIL) { + if (input_dtype == DType::kFloat32) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(float, fp8e4m3, 1, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(float, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } + } else if (input_dtype == DType::kBFloat16) { + if (scale_dim_x == 32 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 32 && scale_dim_y == 1) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 1, 32, ScaleCalculationMode::RCEIL); + } else if (scale_dim_x == 1 && scale_dim_y == 32) { + LAUNCH_KERNEL(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else { + printf("unsupported scaling mode\n"); + exit(1); + } + +#undef LAUNCH_KERNEL + +#endif + } + + // Quantize a 3D tensor using MXFP8 with colwise scaling + // input: pointer to input data with shape (E, N, K) + // output_colwise: pointer to column-wise quantized output in column major format. + // scales_colwise: pointer to column-wise scaling factors with shape (E, num_n_blocks, K) + // E, N, K: tensor dimensions + // scales_colwise_stride_dim0: stride for E dimension in scales + // scales_colwise_stride_dim1: stride for num_n_blocks dimension in scales + // input_dtype: data type of input + // output_dtype: FP8 output type (fp8e4m3 or fp8e5m2) + // scale_dim_n: block size for column-wise scaling along N dimension (typically 32) + static void + quantize_3d(const void *input, void *output_colwise, e8m0_t *scales_colwise, + const size_t E, size_t N, size_t K, + size_t input_stride_dim0, size_t input_stride_dim1, size_t input_stride_dim2, + size_t output_stride_dim0, size_t output_stride_dim1, size_t output_stride_dim2, + size_t scales_colwise_stride_dim0, size_t scales_colwise_stride_dim1, size_t scales_colwise_stride_dim2, + DType input_dtype, DType output_dtype, + size_t scale_dim_n = 32, + ScaleCalculationMode scaling_mode = ScaleCalculationMode::FLOOR, + cudaStream_t stream = 0) { + + // Check parameters + assert(scale_dim_n == 32); // Only support 32 for now + assert(output_colwise != nullptr); + assert(scales_colwise != nullptr); + + // Calculate grid dimensions for 3D tensor: Z handles E dimension, X,Y handle (N,K) + const size_t chunks_Y = DIVUP(N, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(K, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y, E); // 3D grid: Z dimension handles experts + + // Create TMA descriptors for each expert + // Allocate GPU-accessible memory for TMA descriptors + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + int32_t input_bits_per_elem = get_dtype_bits(input_dtype); + int32_t output_bits_per_elem = get_dtype_bits(output_dtype); + + create_2D_tensor_map( + tensor_map_input, const_cast(input), + input_dtype, + E * N, K, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, + K, // stride of "slowest moving" dim (to increment along E*N dimension, we move K elements) + input_bits_per_elem); // bits per elem in input + + size_t shmem_e = 1; + create_3D_tensor_map_output( + tensor_map_output, + output_colwise, + output_dtype, + E, N, K, + shmem_e, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, // Y = N = rows, X = K = cols + output_bits_per_elem); // bits per elem in input + + + +// Launch 3D kernel based on input/output types and scaling dimensions +// Only compile kernel launches for SM90+ +#if defined(__CUDACC__) && \ + (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= MIN_CUDA_SM) + +// Use TMA and mbarrier instructions for 3D +#define LAUNCH_KERNEL_3D(IType, OType, SCALE_Y, SCALE_X, ScalingMode) \ + mxfp8_quantize_kernel_3d \ + <<>>( \ + tensor_map_input, tensor_map_output, \ + scales_colwise, \ + E, N, K, \ + scales_colwise_stride_dim0, scales_colwise_stride_dim1, scales_colwise_stride_dim2); + + // Validate output dtype + if (output_dtype != DType::kFloat8E4M3) { + printf("unsupported output dtype, must be fp8e4m3\n"); + exit(1); + } + + if (scaling_mode == ScaleCalculationMode::FLOOR) { + if (input_dtype == DType::kFloat32) { + LAUNCH_KERNEL_3D(float, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } else if (input_dtype == DType::kBFloat16) { + LAUNCH_KERNEL_3D(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else if (scaling_mode == ScaleCalculationMode::RCEIL) { + if (input_dtype == DType::kFloat32) { + LAUNCH_KERNEL_3D(float, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } else if (input_dtype == DType::kBFloat16) { + LAUNCH_KERNEL_3D(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else { + printf("unsupported scaling mode\n"); + exit(1); + } + +#undef LAUNCH_KERNEL_3D + +#endif + } +}; diff --git a/torchao/csrc/cuda/mx_kernels/ptx.cuh b/torchao/csrc/cuda/mx_kernels/ptx.cuh new file mode 100644 index 0000000000..ba06746dbd --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/ptx.cuh @@ -0,0 +1,290 @@ +// Adapted from https://github.com/NVIDIA/TransformerEngine +// License - Apache-2.0 +// https://github.com/NVIDIA/TransformerEngine/blob/main/LICENSE +// * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Portions (c) Meta Platforms, Inc. and affiliates. + +/*! \file ptx.cuh + * \brief BW PTX + */ + +#include +#include + + +namespace ptx { + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init +__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, + const uint32_t count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval +__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void +mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), + "r"(tx_count) + : "memory"); +} + +__device__ __forceinline__ void fence_mbarrier_init_release_cluster() { + asm volatile("fence.mbarrier_init.release.cluster;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void +cp_async_bulk_tensor_1d_global_to_shared(uint64_t *dst_shmem, + const uint64_t *src_global_ptr, + const uint32_t size, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile("cp.async.bulk.shared::cta.global" + ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"( + dst_shmem_ptr), + "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, + const uint32_t offset_x, const uint32_t offset_y, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"( + dst_shmem_ptr), + "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) + : "memory"); +} + +__device__ __forceinline__ bool +mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { + uint32_t waitComplete; + asm volatile("{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(mbar_ptr), "r"(parity) + : "memory"); + return static_cast(waitComplete); +} + +__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, + const uint32_t parity) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global( + uint64_t *dst_global_ptr, const uint64_t *src_shmem, const uint32_t size) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"( + dst_global_ptr), + "r"(src_shmem_ptr), "r"(size) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( + const uint64_t *tensor_map_ptr, const uint32_t offset_x, + const uint32_t offset_y, uint64_t *src_shmem) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, " + "{%1, %2}], [%3];" ::"l"(tensor_map_ptr), + "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +__device__ __forceinline__ void cp_async_bulk_wait_group() { + asm volatile("cp.async.bulk.wait_group 0;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} + +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { + asm volatile("cp.async.bulk.wait_group.read 1;"); +} +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { + asm volatile("cp.async.bulk.wait_group.read 2;"); +} +template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { + asm volatile("cp.async.bulk.wait_group.read 4;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + +// Proxy fence (bi-directional): +__device__ __forceinline__ void fence_proxy_async() { + asm volatile("fence.proxy.async;"); +} + +__device__ __forceinline__ void fence_proxy_async_shared_cta() { + asm volatile("fence.proxy.async.shared::cta;"); +} + +} // namespace ptx + +namespace { + +template +__forceinline__ __device__ void +initialize_barriers(uint64_t *mbar, const bool is_master_thread) { + if (is_master_thread) { + // Initialize barrier. All `blockDim.x * blockDim.y` threads in block + // participate. +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); +} + +template +__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, + const bool is_master_thread) { + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +} + +__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, + const size_t num_bytes, + uint64_t *barrier, + const bool is_master_thread) { + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_1d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), num_bytes, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +} + +__forceinline__ __device__ void +copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, + const size_t chunk_Y, const size_t num_bytes, + uint64_t *barrier, const bool is_master_thread) { + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), chunk_X, chunk_Y, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +} + +__forceinline__ __device__ void copy_2d_to_sharedx2( + void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, + void *dst2, const void *src2, const size_t chunk_X2, const size_t chunk_Y2, + const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst2), + reinterpret_cast(src2), chunk_X2, chunk_Y2, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +} + +__forceinline__ __device__ void copy_2d_to_sharedx3( + void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, + void *dst2, const void *src2, const size_t chunk_X2, const size_t chunk_Y2, + void *dst3, const void *src3, const size_t chunk_X3, const size_t chunk_Y3, + const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst2), + reinterpret_cast(src2), chunk_X2, chunk_Y2, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(dst3), + reinterpret_cast(src3), chunk_X3, chunk_Y3, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +} +} // namespace diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp index bfaf6bf466..feff97f56a 100644 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -362,7 +362,7 @@ ScalingType get_scaling_type( // Check for RowWise scaling if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { -#if defined(HIPBLASLT_VEC_EXT) +#if defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC) TORCH_CHECK( scale_a.is_contiguous() && scale_b.is_contiguous(), "Both scale_a and scale_b must be contiguous for RowWise scaling."); @@ -619,17 +619,25 @@ void _scaled_gemm( computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); hipblasLtMatmulDescAttributes_t matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; hipblasLtMatmulDescAttributes_t matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; -#if defined(HIPBLASLT_VEC_EXT) +#if defined(HIPBLASLT_OUTER_VEC) + // this case is handled later with HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F +#elif defined(HIPBLASLT_VEC_EXT) if (use_rowwise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } #else - // rowwise isn't supported using cublaslt or older hipblaslt + // rowwise isn't supported using older hipblaslt TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt"); #endif computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); +#if defined(HIPBLASLT_OUTER_VEC) + if (use_rowwise) { + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + } +#endif if (result_scale_ptr != nullptr) { computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); } diff --git a/torchao/experimental/op_lib.py b/torchao/csrc_meta_ops.py similarity index 58% rename from torchao/experimental/op_lib.py rename to torchao/csrc_meta_ops.py index 456b0ca160..771bbfc4ce 100644 --- a/torchao/experimental/op_lib.py +++ b/torchao/csrc_meta_ops.py @@ -4,50 +4,10 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from pathlib import Path - import torch from torch import Tensor from torch.library import impl -# Load C++ ops - use multiple potential paths -potential_paths = [ - # Standard path from the module location - Path(__file__).parent.parent, - # Site-packages installation path - Path(torch.__file__).parent.parent / "torchao", - # For editable installs - Path(__file__).parent.parent.parent / "torchao", -] - - -def find_and_load_libtorchao_ops(potential_paths): - for lib_path in potential_paths: - libs = list(lib_path.glob("libtorchao_ops_aten.*")) - - if not libs: - continue - - assert len(libs) == 1, ( - f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}" - ) - - target_lib = libs[0] - print(f"Found library at: {target_lib}") - - try: - torch.ops.load_library(str(target_lib)) - return - except Exception as e: - print(f"Error loading library from {target_lib}: {e}") - - raise FileNotFoundError( - "Could not find libtorchao_ops_aten library in any of the provided paths" - ) - - -find_and_load_libtorchao_ops(potential_paths) - # Define meta ops. To support dynamic shapes, some meta ops need to # be defined in python instead of C++. torchao_lib = torch.library.Library("torchao", "IMPL") @@ -84,3 +44,20 @@ def _(packed_weights: Tensor, group_size: int, n: int, k: int, indices: Tensor): assert indices.dim() == 1 num_out = indices.shape[0] return torch.empty(num_out, k, dtype=torch.float32, device="meta") + + +for weight_nbit in range(1, 5): + + @impl(torchao_lib, f"_linear_groupwise_{weight_nbit}bit_weight_with_lut", "Meta") + def _( + activations: Tensor, + packed_weights: Tensor, + scale_group_size: int, + lut_group_size: int, + n: int, + k: int, + ): + assert activations.dim() == 2 + m, k_ = activations.shape + assert k_ == k + return torch.empty(m, n, dtype=activations.dtype, device="meta") diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index b0dde2cf10..07f03c7ed9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -8,8 +8,6 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) -from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8 -from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4 from .floatx import ( CutlassSemiSparseLayout, Float8Layout, @@ -64,9 +62,8 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight", "Int4XPULayout", - "to_fbgemm_int4", - "FbgemmInt4Tensor", "to_fbgemm_fp8", "FbgemmFp8Tensor", "Int8DynamicActInt4WeightCPULayout", + "Int4GroupwisePreshuffleTensor", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 39f9131a9e..0d7ed8d9e2 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -19,10 +19,10 @@ MappingType, ZeroPointDomain, _choose_qparams_affine_dont_preserve_zero, - _choose_qparams_affine_float8, _choose_qparams_affine_floatx, _choose_qparams_affine_tinygemm, _choose_qparams_and_quantize_affine_hqq, + _choose_scale_float8, _dequantize_affine_float8, _dequantize_affine_floatx, _dequantize_affine_no_zero_point, @@ -35,10 +35,7 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor logger = logging.getLogger(__name__) aten = torch.ops.aten @@ -119,6 +116,7 @@ def __init__( dtype=None, strides=None, ): + torch._C._log_api_usage_once(str(type(self))) self.tensor_impl = tensor_impl self.block_size = block_size self.quant_min = quant_min @@ -247,6 +245,9 @@ def from_hp_to_intx( zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False, + *, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, ): """Convert a high precision tensor to an integer affine quantized tensor.""" original_shape = input_float.shape @@ -290,7 +291,13 @@ def from_hp_to_intx( ) data = data.to(target_dtype) else: - if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: + if custom_scale is None != custom_zero_point is None: + raise ValueError( + "`custom_scale` and `custom_zero_point` must be both defined or both None" + ) + if custom_scale is not None and custom_zero_point is not None: + scale, zero_point = custom_scale, custom_zero_point + elif zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: scale, zero_point = _choose_qparams_affine_tinygemm( input_float, mapping_type, @@ -462,7 +469,7 @@ def from_hp_to_floatx( if target_dtype in FP8_TYPES: original_shape = input_float.shape input_float = _layout.pre_process(input_float) - scale = _choose_qparams_affine_float8( + scale = _choose_scale_float8( input_float, float8_dtype=target_dtype, block_size=block_size ) data = _quantize_affine_float8(input_float, scale, target_dtype) @@ -613,6 +620,5 @@ def _apply_fn_to_data(self, fn): # experimental will be merged in to floatx to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([AffineQuantizedTensor]) +# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([AffineQuantizedTensor]) diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py deleted file mode 100644 index b6c1d72acc..0000000000 --- a/torchao/dtypes/fbgemm_fp8_tensor.py +++ /dev/null @@ -1,277 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Optional - -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, - fill_defaults, -) - -__all__ = [ - "to_fbgemm_fp8", - "FbgemmFp8Tensor", -] - -aten = torch.ops.aten - - -class FbgemmFp8Tensor(TorchAOBaseTensor): - """ - TODO: needs padding for cutlass kernels - """ - - tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"] - tensor_attributes = ["dtype"] - - def __new__(cls, float8_data, scale, activation_scale_ub, dtype): - shape = float8_data.shape - kwargs = {} - kwargs["device"] = float8_data.device - kwargs["dtype"] = dtype - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, float8_data, scale, activation_scale_ub, dtype): - self.float8_data = float8_data - self.scale = scale - self.activation_scale_ub = activation_scale_ub - - def __tensor_flatten__(self): - return self.tensor_data_attrs, [ - getattr(self, attr) for attr in self.tensor_attributes - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - return cls( - *[tensor_data_dict[name] for name in cls.tensor_data_attrs], - *tensor_attributes, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], - *[getattr(self, attr) for attr in self.tensor_attributes], - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, " - f"activation_scale_ub={self.activation_scale_ub}, " - f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - def _quantization_type(self): - return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}" - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.float8_data.to(device), - self.scale.to(device), - self.activation_scale_ub.to(device), - self.dtype, - ) - - @classmethod - def from_float( - cls, - w: torch.Tensor, - activation_scale_ub: Optional[float] = None, - transpose_input: bool = False, - ): - if activation_scale_ub is None: - activation_scale_ub = 1200.0 - - activation_scale_ub = torch.tensor( - [activation_scale_ub], - dtype=torch.float, - device=w.device, - ) - if transpose_input: - if w.ndim == 3: - w = w.transpose(-1, -2) - else: - w = w.t() - - wq, w_scale = torch.ops.triton.quantize_fp8_row(w) - # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w) - dtype = w.dtype - del w - return FbgemmFp8Tensor( - wq, - w_scale, - activation_scale_ub=activation_scale_ub, - dtype=dtype, - ) - - -implements = FbgemmFp8Tensor.implements - - -@implements([torch.nn.functional.linear, aten.linear.default]) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - orig_act_size = input_tensor.size() - orig_out_features = weight_tensor.shape[-2] - - # not used - num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device) - xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - input_tensor, num_tokens, weight_tensor.activation_scale_ub - ) - - a_data = xq - b_data = weight_tensor.float8_data - - res = torch.ops.fbgemm.f8f8bf16_rowwise( - a_data, - b_data, - x_scale, - weight_tensor.scale, - use_fast_accum=True, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: - res = res + bias - - return res - - -@implements(torch.bmm) -def _(func, types, args, kwargs): - input_tensor, weight_tensor = ( - args[0], - args[1], - ) - orig_act_size = input_tensor.size() - # not used - num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device) - xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - input_tensor, num_tokens, weight_tensor.activation_scale_ub - ) - - a_data = xq - b_data = weight_tensor.float8_data - orig_out_features = b_data.shape[-2] - - res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( - a_data, - b_data, - x_scale, - weight_tensor.scale, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - return res - - -@implements([aten.detach.default, aten.alias.default]) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -def _same_metadata(self: "FbgemmFp8Tensor", src: "FbgemmFp8Tensor") -> bool: - return ( - isinstance(self, FbgemmFp8Tensor) - and isinstance(src, FbgemmFp8Tensor) - and self.shape == src.shape - and self.float8_data.shape == src.float8_data.shape - and self.scale.shape == src.scale.shape - and self.activation_scale_ub.shape == src.activation_scale_ub.shape - and self.dtype == src.dtype - ) - - -@implements(aten.copy_.default) -def _(func, types, args, kwargs): - self = args[0] - src = args[1] - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" - ) - - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - """Only supports slicing for dim == 1 and dim == 2 - original tensor shape has dimension (N, K) - float8_data has dimension (N, K) - scale (per row quantization) has dimension: (N,) - - since float8_data has the same dimension as original tensor, we can directly slice that - for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1 - - Note that we need to call slice on the float8_data and scale directly because slice - is an operation that need to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_fp8` - for - """ - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1 - assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" - if end >= self.shape[dim]: - end = self.shape[dim] - - assert self.float8_data.ndim == 2, ( - f"Expected packed weight to have dim 2, got {self.float8_data.dim}" - ) - - # Always slice the float8_data - sliced_data = aten.slice.Tensor( - self.float8_data, dim, start, end, step - ).contiguous() - - if dim == 0: - # scale has dimension (N,) where N is the dim 0 of `self` - # so we do the same slice on scale for dimension 0 - sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step) - else: - # since scale is per row, slicing along the dim == 1 dimension does - # not change the scale - sliced_scale = self.scale - - return return_and_correct_aliasing( - func, - args, - kwargs, - FbgemmFp8Tensor( - sliced_data, sliced_scale, self.activation_scale_ub, dtype=self.dtype - ), - ) - - -to_fbgemm_fp8 = FbgemmFp8Tensor.from_float - - -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([FbgemmFp8Tensor]) diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py deleted file mode 100644 index 0c00ee1a81..0000000000 --- a/torchao/dtypes/fbgemm_int4_tensor.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import List - -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, - fill_defaults, -) - -__all__ = [ - "to_fbgemm_int4", - "FbgemmInt4Tensor", -] - -aten = torch.ops.aten - - -try: - from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4 -except: - int4_row_quantize_zp = None - pack_int4 = None - - -class FbgemmInt4Tensor(TorchAOBaseTensor): - tensor_data_attrs = ["packed_weight", "scale", "zero_point"] - tensor_attributes = ["group_size", "shape"] - - def __new__(cls, packed_weight, scale, zero_point, group_size, shape): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["dtype"] = scale.dtype - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, packed_weight, scale, zero_point, group_size, shape): - self.packed_weight = packed_weight - self.scale = scale - self.zero_point = zero_point - self.group_size = group_size - - def __tensor_flatten__(self): - return self.tensor_data_attrs, [ - getattr(self, attr) for attr in self.tensor_attributes - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - return cls( - *[tensor_data_dict[name] for name in cls.tensor_data_attrs], - *tensor_attributes, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], - *[getattr(self, attr) for attr in self.tensor_attributes], - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, " - f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - def _quantization_type(self): - return f"shape={self.shape}, group_size={self.group_size}, device={self.device}" - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.packed_weight.to(device), - self.scale.to(device), - self.zero_point.to(device), - self.group_size, - self.shape, - ) - - @classmethod - def from_float( - cls, - w: torch.Tensor, - block_size: List[int], - transpose_input: bool = False, - ): - assert len(block_size) == w.ndim, ( - f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" - ) - if int4_row_quantize_zp is None: - raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - - if transpose_input: - if w.ndim == 3: - w = w.transpose(-1, -2) - else: - w = w.t() - - group_size = block_size[-1] - original_shape = w.shape - - if w.ndim >= 3: - wq, scale, zero_point = zip( - *[int4_row_quantize_zp(i, group_size) for i in w], strict=False - ) - wq = torch.stack([pack_int4(i) for i in wq], dim=0) - scale = torch.stack(scale, dim=0) - zero_point = torch.stack(zero_point, dim=0) - else: - wq, scale, zero_point = int4_row_quantize_zp(w, group_size) - wq = pack_int4(wq) - - scale = scale.to(w.dtype) - zero_point = zero_point.to(w.dtype) - - del w - return FbgemmInt4Tensor( - packed_weight=wq, - scale=scale, - zero_point=zero_point, - group_size=group_size, - shape=original_shape, - ) - - -implements = FbgemmInt4Tensor.implements - - -@implements([torch.nn.functional.linear, aten.linear.default]) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - orig_act_size = input_tensor.size() - orig_out_features = weight_tensor.shape[-2] - - res = torch.ops.fbgemm.bf16i4bf16_rowwise( - input_tensor, - weight_tensor.packed_weight.contiguous(), - weight_tensor.scale, - weight_tensor.zero_point, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: - res = res + bias - return res - - -@implements(torch.bmm) -def _(func, types, args, kwargs): - input_tensor, weight_tensor = ( - args[0], - args[1], - ) - orig_act_size = input_tensor.size() - orig_out_features = weight_tensor.shape[-2] - - res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( - input_tensor, - weight_tensor.packed_weight.contiguous(), - weight_tensor.scale, - weight_tensor.zero_point, - ) - res = res.reshape(*orig_act_size[:-1], orig_out_features) - return res - - -@implements([aten.detach.default, aten.alias.default]) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -def _same_metadata(self: "FbgemmInt4Tensor", src: "FbgemmInt4Tensor") -> bool: - return ( - isinstance(self, FbgemmInt4Tensor) - and isinstance(src, FbgemmInt4Tensor) - and self.shape == src.shape - and self.packed_weight.shape == src.packed_weight.shape - and self.scale.shape == src.scale.shape - and self.zero_point.shape == src.zero_point.shape - and self.group_size == src.group_size - ) - - -@implements(aten.copy_.default) -def _(func, types, args, kwargs): - self = args[0] - src = args[1] - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" - ) - - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - """Only supports slicing for dim == 1 and dim == 2 - packed_weight has dimension: (N, K/2) - scale and zero_point has dimension: (K/groups, N) - - dim, start, end, step are args that's referring to the original tensor shape - which is (N, K), and we need to map that to the transformed weight shape of packed_weight, - scale and zero_point - - when dim == 0: we do a slice on packed_weight dim 0, and on dim 1 of scale and zero_point, - also adjust the start and end indexes based on the ratio between original shape and the shape - of packed_weight and scale/zero_point - - when dim == 1: we do a slice on packed_weight dim 1 and dim 0 of scale and zero_point and do the - same adjustment based on ratio - - Note that we need to call slice on the packed_weight, scale and zero_point directly because slice - is an operation that need to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_int4` - for - """ - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1 - assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" - if end >= self.shape[dim]: - end = self.shape[dim] - - assert self.packed_weight.ndim == 2, ( - f"Expected packed weight to have dim 2, got {self.packed_weight.dim}" - ) - N, K_by_2 = self.packed_weight.shape - sz_dim0, sz_dim1 = self.scale.shape - - data_len = self.shape[dim] - - if dim == 0: - pw_len = N - sz_len = sz_dim1 - else: - pw_len = K_by_2 - sz_len = sz_dim0 - - sz_dim = 1 - dim - if pw_len == 0 or sz_len == 0: - return return_and_correct_aliasing( - func, - args, - kwargs, - self.__class__( - self.packed_weight, - self.scale, - self.zero_point, - group_size=self.group_size, - shape=self.shape, - ), - ) - - pw_ratio = data_len / pw_len - start_pw = int(start / pw_ratio) - end_pw = int(end / pw_ratio) - - sz_ratio = data_len / sz_len - start_sz = int(start / sz_ratio) - end_sz = int(end / sz_ratio) - - packed_weight = aten.slice.Tensor(self.packed_weight, dim, start_pw, end_pw, step) - scale = aten.slice.Tensor(self.scale, sz_dim, start_sz, end_sz, step) - zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step) - packed_shape0, packed_shape1 = packed_weight.shape - new_shape = (packed_shape0, packed_shape1 * 2) - new = self.__class__( - packed_weight, scale, zero_point, group_size=self.group_size, shape=new_shape - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -to_fbgemm_int4 = FbgemmInt4Tensor.from_float - - -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([FbgemmInt4Tensor]) diff --git a/torchao/dtypes/floatx/README.md b/torchao/dtypes/floatx/README.md index 16aec8362b..092ef01233 100644 --- a/torchao/dtypes/floatx/README.md +++ b/torchao/dtypes/floatx/README.md @@ -9,7 +9,7 @@ This kernel was originally designed for FP16, but was extended to work for BF16 ```python from torchao.quantization import ( quantize_, - fpx_weight_only, + FPXWeightOnlyConfig, ) model = ... @@ -17,7 +17,7 @@ model = ... # for generic Floatx EyMz where x = 1 + y + z # fp6 with ebits = 3 and mbits = 2 -quantize_(model, fpx_weight_only(3, 2)) +quantize_(model, FPXWeightOnlyConfig(3, 2)) # fully compatible with torch.compile() model.compile(mode="max-autotune", fullgraph=True) diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py index 45fe451712..e49e8e8129 100644 --- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py +++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py @@ -100,6 +100,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) + elif func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + elif func is aten.to.dtype_layout: + dense, scale, _ = args[0].get_plain() + product = dense.to(scale.dtype) * scale + return product.to( + *args[1:], + dtype=kwargs.get("dtype", dense.dtype), + device=kwargs.get("device", dense.device), + ) raise NotImplementedError( f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -123,11 +135,12 @@ def get_plain(self): # semi-structured format, so multiplying with identity matrix, # and using identity scale factors, for the conversion. cols = self.shape[1] - input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device) - input_scale = torch.ones( - (cols,), dtype=self.scale.dtype, device=self.sparse.device - ) + plain_input = torch.eye(cols, device=self.sparse.device) + input = plain_input.to(dtype=self.sparse.dtype) + plain_input_scale = torch.ones((cols,), device=self.sparse.device) + input_scale = plain_input_scale.to(dtype=self.scale.dtype) sparse_scale = torch.ones_like(self.scale) + out_dtype = torch.bfloat16 dense = ( rowwise_scaled_linear_sparse_cutlass_f8f8( diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 40091d2667..4afc5fdfee 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import warnings from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -20,8 +21,10 @@ from torchao.float8.inference import ( Float8MMConfig, _is_rowwise_scaled, + _slice_scale_for_dimension, addmm_float8_unwrapped_inference, preprocess_data, + preprocess_scale, ) from torchao.utils import _is_float8_type, fill_defaults @@ -107,6 +110,9 @@ def __init__( transposed: bool, _layout: Layout, ): + warnings.warn( + "Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details" + ) self.float8_data = float8_data self.scale = scale self.transposed = transposed @@ -299,56 +305,6 @@ def _(func, types, args, kwargs): ) -def _slice_scale_for_dimension( - scale: torch.Tensor, - data_shape: List[int], - dim: int, - start: int, - end: int, - step: int, -) -> torch.Tensor: - """ - Slice the scale tensor appropriately based on the data tensor slicing. - - This function calculates how the scale should be sliced when the data tensor - is sliced along a given dimension, taking into account the block structure. - """ - # Unsupported case for now, this would be 1 scale per data element - if scale.shape == data_shape: - return aten.slice.Tensor(scale, dim, start, end, step) - - # Reconstruct block sizes based on data shape and scale shape - block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape))) - - if dim >= len(block_sizes): - # Slicing beyond the dimensions we care about - return scale - - block_size_for_dim = block_sizes[dim] - - if block_size_for_dim == 1: - # Scale is per-element along this dimension - # Slice away as normal - return aten.slice.Tensor(scale, dim, start, end, step) - else: - # There is blocking in this dimension - # Calculate which scale elements correspond to the sliced data - scale_start = start // block_size_for_dim if start is not None else None - scale_end = ( - (end + block_size_for_dim - 1) // block_size_for_dim - if end is not None - else None - ) - - # Error on Step > 1 - if step > 1: - raise NotImplementedError( - "Slicing with step > 1 is not implemented for scale tensors." - ) - - return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) - - ########################## # Float8 Dispatch Kernels ########################## @@ -370,24 +326,6 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return check_aqt(input_tensor) and check_aqt(weight_tensor) -def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]): - """Ensures input tensor is correctly formatted for _scaled_mm""" - - # For PerTensor quantization, scale should be a scalar or have shape [1] - if input_scale.numel() == 1: - # Already a scalar, ensure it has the right shape for _scaled_mm - return input_scale.reshape(1, 1) - - # For per-row/block quantization, we need to handle the reshaping - input_scale = input_scale.unsqueeze(-1) - - # Match: #input_data.reshape(-1, input_data.shape[-1]) - if input_scale.dim() > 2: - input_scale = input_scale.reshape(-1, input_scale.shape[-1]) - - return input_scale - - def _linear_fp8_act_fp8_weight_impl( input_tensor: "AffineQuantizedTensor", weight_tensor: "AffineQuantizedTensor", diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 698a9391bd..5542a9de58 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -15,14 +15,56 @@ from torch._prims_common import make_contiguous_strides_for from torch.distributed.device_mesh import DeviceMesh -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional -NF4_OPS_TABLE: Dict[Any, Any] = {} +def nf4_all_gather_into_tensor(func, *args, **kwargs): + assert len(args) > 1, "Expected valid input" + assert len(args[0]) == 3, "Expected 3 input args" + nf4tensor = args[0][0] + group_size = args[0][1] + name = args[0][2] + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + updated_attrs[attr] = func(getattr(nf4tensor, attr), group_size, name) + updated_attrs.update( + { + "size": torch.Size((nf4tensor.size()[0] * group_size, nf4tensor.size()[1])), + } + ) + updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + return updatedNF4Tensor + + +def scatter_nf4tensor(func, *args, **kwargs): + assert len(args) > 1, "Expected valid input" + assert len(args[0][0]) == 1, "Expected 1 output tensor" + output_tensor = args[0][0][0] + input_tensors = args[0][1] + new_attr, update_work = [], [] + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + input_attrs = [] + if input_tensors: + for input_tensor in input_tensors[0]: + assert input_tensor.size() == output_tensor.size(), ( + "Input tensor size must match output tensor size, tensors are not evenly divided." + ) + if hasattr(input_tensor, attr): + input_attrs.append(getattr(input_tensor, attr)) + input_attrs = [input_attrs] + new_attr, update_work = func( + [getattr(output_tensor, attr)], input_attrs, *args[0][2:] + ) + # there are 3 works, return one of them, same as the tensor to fit the required output format + return new_attr, update_work + + +NF4_OPS_TABLE: Dict[Any, Any] = { + torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor, + torch.ops.c10d.scatter_.default: scatter_nf4tensor, +} _INNER_TENSOR_NAMES_FOR_SHARDING = [ @@ -233,7 +275,6 @@ def nf4_split(aten_op, args, kwargs=None): def nf4_new_zeros(aten_op, args, kwargs=None): nf4tensor = args[0] new_size = tuple(args[1]) - if nf4tensor.numel() % math.prod(new_size) != 0: raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}") ratio = nf4tensor.numel() // math.prod(new_size) @@ -273,19 +314,37 @@ def nf4_slice(aten_op, args, kwargs=None): aten.view.default, ] ) -@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=") def nf4_view(aten_op, args, kwargs=None): nf4tensor = args[0] size = args[1] - if size[0] != -1: - raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") - updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) - updated_attrs.update( - { - "size": [nf4tensor.numel()], - "stride": (1,), - } - ) + if len(size) == 1: + if size[0] != -1: + raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") + else: + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + updated_attrs.update( + { + "size": [nf4tensor.numel()], + "stride": (1,), + } + ) + elif len(size) == 2: + if nf4tensor.numel() != size[0] * size[1]: + raise NotImplementedError("NF4Tensor size does not match view size.") + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + attr_size = [getattr(nf4tensor, attr).size()] + updated_attrs[attr] = aten_op( + getattr(nf4tensor, attr), *attr_size, **kwargs + ) + updated_attrs.update( + { + "stride": (size[1], 1), + } + ) + else: + raise NotImplementedError("aten.view(NF4Tensor) with empty size") return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) @@ -457,6 +516,20 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None): return tensors +@implements( + [ + torch.ops._c10d_functional.wait_tensor.default, + ] +) +def wait_tensor(func, *args, **kwargs): + nf4tensor = args[0][0] + updated_attrs = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + updated_attrs[attr] = func(getattr(nf4tensor, attr)) + updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + return updatedNF4Tensor + + @dataclass(frozen=True) class SubclassTensorArgs: original_shape: torch.Size @@ -868,7 +941,7 @@ def allowed_subclasses(type): f"NF4Tensor dispatch: attempting to run {func}, this is not supported" ) - # Do not force the Float8Tensor type on the returned tensor + # Do not force the Float8TrainingTensor type on the returned tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -1081,6 +1154,5 @@ def nf4_constructor( ) -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([NF4Tensor]) - torch.serialization.add_safe_globals([NF4Tensor]) +torch.serialization.add_safe_globals([NF4Tensor]) +torch.serialization.add_safe_globals([NF4Tensor]) diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py index ced7ec0dd8..c0f2fcdfe5 100644 --- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py +++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -16,10 +16,7 @@ register_layout, ) from torchao.dtypes.utils import Layout, PlainLayout, is_device -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_7, - TORCH_VERSION_AT_LEAST_2_8, -) +from torchao.utils import torch_version_at_least from .int4_cpu_layout import ( Int4CPUAQTTensorImpl, @@ -124,6 +121,10 @@ def from_plain( if zero_point.dim() == 1: zero_point.unsqueeze_(-1) + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available. + # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n]. + # Compensation shape = [N / block_n, K / block_k, block_n]. weight_int4, scales, qzeros, compensation = ( torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) ) @@ -242,7 +243,7 @@ def _aqt_is_uint4(aqt): def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): return ( - TORCH_VERSION_AT_LEAST_2_7 + torch_version_at_least("2.7.0") and is_device(input_tensor.device.type, "cpu") and is_device(weight_tensor.device.type, "cpu") and (bias is None or is_device(bias.device.type, "cpu")) @@ -258,11 +259,11 @@ def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert TORCH_VERSION_AT_LEAST_2_7, ( + assert torch_version_at_least("2.7.0"), ( f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" ) if _aqt_is_int8(input_tensor): - assert TORCH_VERSION_AT_LEAST_2_8, ( + assert torch_version_at_least("2.8.0"), ( f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" ) assert is_device(input_tensor.device.type, "cpu"), ( @@ -310,3 +311,9 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): y = y.reshape(*orig_act_size[:-1], orig_out_features) return y.to(orig_dtype) + + +# Register the concat linear fusion pass +from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass + +register_da8w4_concat_linear_cpu_pass() diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index da19bbc259..1ae9dca3b6 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -21,11 +22,7 @@ ZeroPointDomain, _quantize_affine_tinygemm, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - fill_defaults, -) +from torchao.utils import fill_defaults aten = torch.ops.aten @@ -82,6 +79,9 @@ def __init__( transposed: bool, _layout: Layout, ): + warnings.warn( + "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details" + ) self.packed_weight = packed_weight self.scale_and_zero = scale_and_zero self.transposed = False @@ -114,29 +114,13 @@ def from_plain( ): assert isinstance(_layout, Int4CPULayout) - if TORCH_VERSION_AT_LEAST_2_6: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( - int_data, - 1, # TODO:remove - ) - elif TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - else: - assert int_data.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - ) - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + ) + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # TODO:remove + ) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) @@ -284,8 +268,7 @@ def _is_float(dtype): def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): return ( - TORCH_VERSION_AT_LEAST_2_6 - and is_device(input_tensor.device.type, "cpu") + is_device(input_tensor.device.type, "cpu") and is_device(weight_tensor.device.type, "cpu") and (bias is None or is_device(bias.device.type, "cpu")) and not is_traceable_wrapper_subclass(input_tensor) @@ -300,9 +283,6 @@ def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert TORCH_VERSION_AT_LEAST_2_6, ( - f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" - ) assert is_device(input_tensor.device.type, "cpu"), ( f"For CPU device only but got: {input_tensor.device}" ) diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py index 955a7a8610..ff6dc68813 100644 --- a/torchao/dtypes/uintx/int4_xpu_layout.py +++ b/torchao/dtypes/uintx/int4_xpu_layout.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -20,8 +21,8 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_8, fill_defaults, + torch_version_at_least, ) aten = torch.ops.aten @@ -207,6 +208,9 @@ def __init__( scale: torch.Tensor = None, zero: torch.Tensor = None, ): + warnings.warn( + "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details" + ) self.packed_weight = packed_weight self.scale_and_zero = scale_and_zero self.transposed = False @@ -248,7 +252,7 @@ def from_plain( ): assert isinstance(_layout, Int4XPULayout) - if TORCH_VERSION_AT_LEAST_2_8: + if torch_version_at_least("2.8.0"): assert int_data.dtype == torch.int32, ( "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" ) diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index af1f8040f6..cba2428d94 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import warnings from dataclasses import dataclass import torch @@ -158,6 +159,9 @@ def __init__( group_size: int, num_bits: int, ): + warnings.warn( + "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details" + ) self.int_data = int_data self.scale_and_zero = None self.scale = scale diff --git a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py index dc7b073f32..dcae80f365 100644 --- a/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/dtypes/uintx/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging +import warnings from enum import Enum, auto from typing import Optional, Tuple, Union @@ -13,13 +14,14 @@ from torchao.dtypes.affine_quantized_tensor import register_layout from torchao.dtypes.utils import AQTTensorImpl, Layout -from torchao.experimental.op_lib_utils import _check_torchao_ops_loaded from torchao.quantization.quant_primitives import ( _DTYPE_TO_BIT_WIDTH, _DTYPE_TO_QVALUE_BOUNDS, ZeroPointDomain, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -76,6 +78,9 @@ def __init__( self, target: Union[str, Target] = "auto", ): + warnings.warn( + "Models quantized with version 1 of IntxWeightOnlyConfig/Int8DynamicActivationIntxWeightConfig are deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2967 for more details" + ) if isinstance(target, str): target = target_from_str(target) self.target = target @@ -168,11 +173,8 @@ def from_plain( ) if layout.target != Target.ATEN: - _check_torchao_ops_loaded() + assert _is_kernel_library_loaded(), "Kernel library is not loaded" else: - assert TORCH_VERSION_AT_LEAST_2_6, ( - "aten target is requires torch version > 2.6.0" - ) assert torch.backends.kleidiai.is_available(), ( "ATEN target requires torch.backends.kleidiai.is_available()" ) @@ -378,7 +380,6 @@ def _impl_2d_aten(input_tensor, weight_tensor): ) if target == Target.ATEN: - assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0" _impl_2d = _impl_2d_aten else: _impl_2d = _impl_2d_non_aten @@ -420,11 +421,6 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor( Constructs an AffineQuantizedTensor with PackedLinearInt8DynamicActivationIntxWeightLayout from plain data. """ - # TORCH_VERSION_AT_LEAST_2_6 is needed for torch.intx with x < 8 - assert TORCH_VERSION_AT_LEAST_2_6, ( - "Using PackedLinearInt8DynamicActivationIntxWeightLayout requires torch version > 2.6.0" - ) - layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target) bit_width = _DTYPE_TO_BIT_WIDTH[data_dtype] diff --git a/torchao/dtypes/uintx/q_dq_layout.py b/torchao/dtypes/uintx/q_dq_layout.py index 0ae1d865e8..be2c7fe16c 100644 --- a/torchao/dtypes/uintx/q_dq_layout.py +++ b/torchao/dtypes/uintx/q_dq_layout.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging +import warnings import torch @@ -95,6 +96,9 @@ def __init__( zero_point: Optional[torch.Tensor], _layout: Layout, ): + warnings.warn( + "Models quantized with version 1 of IntxWeightOnlyConfig/Int8DynamicActivationIntxWeightConfig are deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2967 for more details" + ) self.int_data = int_data self.scale = scale self.zero_point = zero_point diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 591d9a9be1..1961cc33c5 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import logging +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -24,7 +25,6 @@ _quantize_affine_tinygemm, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, fill_defaults, find_multiple, ) @@ -238,6 +238,9 @@ def __init__( transposed: bool, _layout: Layout, ): + warnings.warn( + "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details" + ) self.packed_weight = packed_weight self.scale_and_zero = scale_and_zero self.transposed = False @@ -274,14 +277,9 @@ def from_plain( ) def quant_2d(int_data_2d): - if TORCH_VERSION_AT_LEAST_2_5: - int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( - torch.uint8 - ) - else: - assert int_data_2d.dtype == torch.int32, ( - "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - ) + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( + torch.uint8 + ) return torch.ops.aten._convert_weight_to_int4pack( int_data_2d.contiguous(), _layout.inner_k_tiles ) diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py index 96e5401de5..3180e9f2c9 100644 --- a/torchao/dtypes/uintx/uintx_layout.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -14,7 +14,7 @@ from torchao.dtypes.utils import ( Layout, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TorchAOBaseTensor +from torchao.utils import TorchAOBaseTensor from .bitpacking import pack, unpack @@ -24,20 +24,17 @@ _DTYPE_TO_BIT_WIDTH = {} _BIT_WIDTH_TO_DTYPE = {} -if TORCH_VERSION_AT_LEAST_2_3: - _DTYPE_TO_BIT_WIDTH = { - torch.uint1: 1, - torch.uint2: 2, - torch.uint3: 3, - torch.uint4: 4, - torch.uint5: 5, - torch.uint6: 6, - torch.uint7: 7, - } - - _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} -else: - print("uintx feature requires torch 2.3+, please upgrade pytorch") +_DTYPE_TO_BIT_WIDTH = { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, +} + +_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} class UintxTensor(TorchAOBaseTensor): diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index a07188a18d..0a81172112 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -68,6 +68,9 @@ def __repr__(self): def extra_repr(self) -> str: return "" + def __post_init__(self): + torch._C._log_api_usage_once(str(type(self))) + @dataclass(frozen=True) class PlainLayout(Layout): diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index 1d3c28508e..84582f704e 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -17,12 +17,7 @@ endif() # Platform options option(TORCHAO_BUILD_ATEN_OPS "Building torchao ops for ATen." ON) -option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF) option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF) -option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF) -option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) -option(TORCHAO_ENABLE_ARM_NEON_DOT "Enable ARM Neon Dot Product extension" OFF) -option(TORCHAO_ENABLE_ARM_I8MM "Enable ARM 8-bit Integer Matrix Multiply instructions" OFF) if(NOT TORCHAO_INCLUDE_DIRS) set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..) @@ -36,96 +31,17 @@ endif() add_compile_options("-Wall" "-Werror" "-Wno-deprecated" "-Wno-shorten-64-to-32") include(CMakePrintHelpers) -include(${CMAKE_CURRENT_SOURCE_DIR}/Utils.cmake) message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") include_directories(${TORCHAO_INCLUDE_DIRS}) -# Build cpu/aarch64 kernels -if(TORCHAO_BUILD_CPU_AARCH64) - message(STATUS "Building with cpu/aarch64") - add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) - - # Set aarch64 compiler options - if (CMAKE_SYSTEM_NAME STREQUAL "Linux") - message(STATUS "Add aarch64 linux compiler options") - add_compile_options( - "-fPIC" - "-Wno-error=unknown-pragmas" - "-Wno-array-parameter" - "-Wno-maybe-uninitialized" - "-Wno-sign-compare" - ) - - # Since versions are hierarchical (each includes features from prior versions): - # - dotprod is included by default in armv8.4-a and later - # - i8mm is included by default in armv8.6-a and later - if(TORCHAO_ENABLE_ARM_I8MM) - message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)") - add_compile_options("-march=armv8.6-a") - elseif(TORCHAO_ENABLE_ARM_NEON_DOT) - message(STATUS "Using armv8.4-a (includes '+dotprod' flag)") - add_compile_options("-march=armv8.4-a") - endif() - endif() - - if(TORCHAO_ENABLE_ARM_NEON_DOT) - message(STATUS "Building with ARM NEON dot product support") - add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) - add_compile_options("-march=armv8.4-a+dotprod") - endif() - - if(TORCHAO_ENABLE_ARM_I8MM) - message(STATUS "Building with ARM I8MM support") - add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) - endif() - - if(TORCHAO_BUILD_KLEIDIAI) - message(STATUS "Building with Arm KleidiAI library") - add_compile_definitions(TORCHAO_ENABLE_KLEIDI) - endif() - - # Defines torchao_kernels_aarch64 - add_subdirectory(kernels/cpu/aarch64) -endif() - - - -if (NOT TARGET cpuinfo) - # For some reason cpuinfo package has unused functions/variables - # TODO (T215533422): fix upstream - set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "Disable unit tests" FORCE) - set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "Disable mock tests" FORCE) - set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "Disable benchmarks" FORCE) - add_compile_options(-Wno-unused-function -Wno-unused-variable) - include(FetchContent) - FetchContent_Declare(cpuinfo - GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git - GIT_TAG c61fe919607bbc534d7a5a5707bdd7041e72c5ff) - FetchContent_MakeAvailable( - cpuinfo) -endif() - # Build ATen ops if(TORCHAO_BUILD_ATEN_OPS) find_package(Torch REQUIRED) - set(_torchao_op_srcs_aten) - list(APPEND _torchao_op_srcs_aten - ops/embedding_xbit/op_embedding_xbit_aten.cpp - ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp - ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp - ) - list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") - add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten}) - target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}") - if (TORCHAO_BUILD_CPU_AARCH64) - target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64) - endif() - target_link_libraries(torchao_ops_aten PRIVATE cpuinfo) - target_include_directories(torchao_ops_aten PRIVATE "${TORCH_INCLUDE_DIRS}") - target_link_libraries(torchao_ops_aten PRIVATE "${TORCH_LIBRARIES}") - target_compile_definitions(torchao_ops_aten PRIVATE USE_ATEN=1) + + # Use the Python extension name if provided + add_library(torchao_ops_aten SHARED) # Add MPS support if enabled if (TORCHAO_BUILD_MPS_OPS) @@ -141,40 +57,3 @@ if(TORCHAO_BUILD_ATEN_OPS) DESTINATION lib ) endif() - - -# Build ExecuTorch ops -if(TORCHAO_BUILD_EXECUTORCH_OPS) - # ExecuTorch package is not required, but EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES must - # be defined and EXECUTORCH_LIBRARIES must include the following libraries installed by ExecuTorch: - # libexecutorch.a - # libextension_threadpool.a - # libcpuinfo.a - # libpthreadpool.a - if(NOT DEFINED EXECUTORCH_INCLUDE_DIRS AND NOT DEFINED EXECUTORCH_LIBRARIES) - message(WARNING "EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES are not defined. Looking for ExecuTorch.") - find_package(ExecuTorch HINTS ${CMAKE_PREFIX_PATH}/executorch/share/cmake) - endif() - set(_torchao_op_srcs_executorch) - list(APPEND _torchao_op_srcs_executorch - ops/embedding_xbit/op_embedding_xbit_executorch.cpp - ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp - ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp - ) - list(TRANSFORM _torchao_op_srcs_executorch PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") - add_library(torchao_ops_executorch STATIC ${_torchao_op_srcs_executorch}) - target_link_torchao_parallel_backend(torchao_ops_executorch executorch) - target_include_directories(torchao_ops_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}") - target_compile_definitions(torchao_ops_executorch PRIVATE USE_EXECUTORCH=1) - target_link_libraries(torchao_ops_executorch PRIVATE "${EXECUTORCH_LIBRARIES}") - if (TORCHAO_BUILD_CPU_AARCH64) - target_link_libraries(torchao_ops_executorch PRIVATE torchao_kernels_aarch64) - endif() - target_link_libraries(torchao_ops_executorch PRIVATE cpuinfo) - install( - TARGETS - torchao_ops_executorch - EXPORT _targets - DESTINATION lib - ) -endif() diff --git a/torchao/experimental/benchmark_infra/ios/output_redirect.mm b/torchao/experimental/benchmark_infra/ios/output_redirect.mm index 93c1164c16..692ec59d07 100644 --- a/torchao/experimental/benchmark_infra/ios/output_redirect.mm +++ b/torchao/experimental/benchmark_infra/ios/output_redirect.mm @@ -40,6 +40,13 @@ close(stdout_dupfd_); close(stderr_dupfd_); fclose(redirect_out_); + /* write done file to detect end of benchmark*/ + std::string file_name = + std::string(std::getenv("HOME")) + "/tmp/BENCH_DONE"; + FILE *donefile = fopen(file_name.c_str(), "w"); + std::string done_str = "DONE BENCHMARKING"; + fwrite(done_str.c_str(), 1, done_str.size(), donefile); + fclose(donefile); } } diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md deleted file mode 100644 index a178c9b328..0000000000 --- a/torchao/experimental/docs/readme.md +++ /dev/null @@ -1,141 +0,0 @@ -# TorchAO experimental - -TorchAO experimental contains lowbit ARM CPU and Metal kernels for linear and -embedding ops. - -## Building ARM CPU kernels - -To build torch ops that use the lowbit kernels, run -`sh build_torchao_ops.sh ` from torchao/experimental. - -For example, to build ATen ops, run `sh build_torchao_ops.sh aten` (this -requires PyTorch). Similarly, to build the ExecuTorch ops, run -`sh build_torchao_ops executorch` (this requires ExecuTorch). - -After running the script, the op libraries will be in - -``` -cmake-out/lib/libtorchao_ops_aten.{dylib|so} # ATen op library -cmake-out/lib/libtorchao_ops_executorch.a # ExecuTorch op library -``` - -## Quantizing models - -Once the ATen ops are built, you can quantize PyTorch models with them. The -quantized models can be run in eager model, compiled, used with AOTI, or -exported. The exported models can be lowered to ExecuTorch. - -```python -import torch -torch.ops.load_library("cmake-out/lib/libtorchao_ops_aten.dylib") # make sure this path is correct on your machine -from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer, IntxWeightEmbeddingQuantizer - -my_model = Model() - -embedding_quantizer = IntxWeightEmbeddingQuantizer( - device="cpu", - precision=torch.float32, - bitwidth=2, # bitwidth to quantize embedding weights to (values 1-7 are supported) - groupsize=32, # groupsize for embedding weights (any multiple of 32 is supported) -) -quantized_model = embedding_quantizer.quantize(my_model) - - -linear_quantizer = Int8DynActIntxWeightLinearQuantizer( - device="cpu", - precision=torch.float32, - bitwidth=4, # bitwidth to quantize linear weights to (values 1-7 are supported) - groupsize=256, # groupsize for quantization (any multiple of 16 is supported) - has_weight_zeros=False, # whether to quantize weights with scales and zeros, or scales-only -) -quantized_model = linear_quantizer.quantize(quantized_model) -``` - -If you get stuck on the above steps, working examples for both linear and -embedding are in -torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py and -torchao/experimental/tests/test_embedding_xbit_quantizer.py. For example, -running `python tests/test_linear_8bit_act_xbit_weight_quantizer.py` loads the -ops, creates a toy model, quantizes the model, and runs it in eager, compile, -AOTI, and exports the model. - -### Subclass API - -For linear, you can also use the new subclass API in torchao. First install the -kernels by running the following command from the ao directory. (Note: takeshis -will only install the kernels if run on a Mac with Apple Silicon.) - -``` -USE_CPP=1 pip install . -``` - -Once the kernels are installed, you can quantize your model as follows: - -```python -from torchao.dtypes import PlainLayout -from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, -) -from torchao.experimental.quant_api import ( - int8_dynamic_activation_intx_weight, -) -from torchao.quantization.granularity import ( - PerGroup, - PerRow, -) -from torchao.quantization.quant_api import quantize_ - -my_model = Model() - -quantize_( - my_model, - int8_dynamic_activation_intx_weight( - weight_dtype=torch.int4, - granularity=PerGroup(256), # PerRow() is also supported - has_weight_zeros=False, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() is also supported, but much slower on CPU - ), -) -``` - -KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows: - -```python -from torchao.dtypes import PlainLayout -from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, -) -from torchao.experimental.quant_api import ( - int8_dynamic_activation_intx_weight, -) -from torchao.quantization.granularity import ( - PerGroup, - PerRow, -) -from torchao.quantization.quant_api import quantize_ -from torchao.quantization.quant_primitives import MappingType - -my_model = Model() - -quantize_( - my_model, - int8_dynamic_activation_intx_weight( - weight_dtype=torch.int4, - granularity=PerGroup(32), # PerRow() is also supported - has_weight_zeros=True, # Should be True - weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR # MappingType.SYMMETRIC can also be used but increases error - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"), - ), -) -``` - -If you get stuck, consult -`torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py` -for a working example. - -## Available in torchchat - -TorchAO experimental kernels are -[available in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels), -PyTorch's solution for running LLMs locally. Torchchat integration uses similar -steps to above. diff --git a/torchao/experimental/install_requirements.sh b/torchao/experimental/install_requirements.sh deleted file mode 100644 index 96c70cfc8f..0000000000 --- a/torchao/experimental/install_requirements.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Install requirements for experimental torchao ops. -if [[ -z $PIP ]]; -then - PIP=pip -fi - -NIGHTLY_VERSION="dev20241011" -$PIP install "executorch==0.5.0.$NIGHTLY_VERSION" --extra-index-url https://download.pytorch.org/whl/nightly/cpu diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt deleted file mode 100644 index f38794d4a8..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -if (TORCHAO_BUILD_CPU_AARCH64) - add_library( - torchao_kernels_aarch64 - ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduction/compute_sum.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/quantization/quantize.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/valpacking/interleave.cpp - ) - if (TORCHAO_BUILD_KLEIDIAI) - include(FetchContent) - # KleidiAI is an open-source library that provides optimized - # performance-critical routines, also known as micro-kernels, for artificial - # intelligence (AI) workloads tailored for Arm® CPUs. - FetchContent_Declare(kleidiai - GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git - GIT_TAG v1.5.0) - FetchContent_MakeAvailable(kleidiai) - - target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) - endif() - -install( - TARGETS torchao_kernels_aarch64 - DESTINATION lib -) -endif() diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt deleted file mode 100644 index 5227ff1090..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(benchmarks) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -include(FetchContent) -FetchContent_Declare(googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main) # need main for benchmark::benchmark - -set(BENCHMARK_ENABLE_TESTING OFF) -FetchContent_MakeAvailable( - googlebenchmark) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -add_executable(benchmark_quantization benchmark_quantization.cpp) -target_link_libraries( - benchmark_quantization - PRIVATE - benchmark::benchmark - dep -) - -add_executable(benchmark_bitpacking benchmark_bitpacking.cpp) -target_link_libraries( - benchmark_bitpacking - PRIVATE - benchmark::benchmark - dep -) - -add_executable(benchmark_linear benchmark_linear.cpp) -target_link_libraries( - benchmark_linear - PRIVATE - benchmark::benchmark - dep -) diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh deleted file mode 100644 index e7fa9402e2..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -eu -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -set -eu - -if [[ $# -ne 1 ]]; then - echo "Usage: $0 "; - exit 1; -fi - -BENCHMARK_TYPE="${1}" -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) - -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks - -# Build -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \ - -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -# Run -case "${BENCHMARK_TYPE}" in - quantization) ${CMAKE_OUT}/benchmark_quantization; ;; - bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;; - linear) ${CMAKE_OUT}/benchmark_linear; ;; - *) echo "Unknown benchmark: $1. Please specify quantization, bitpacking, or linear."; exit 1; ;; -esac diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt deleted file mode 100644 index 1fd2828fc5..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(tests) -set(CMAKE_CXX_STANDARD 17) - -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) - -if (ANDROID_ABI) - # We are cross compiling, delay test discovery till runtime - set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) -endif() - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp -) - -if(NOT TORCHAO_INCLUDE_DIRS) - set(TORCHAO_INCLUDE_DIRS ${TORCHAO_LIBRARIES}) -endif() - -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - -if(TORCHAO_BUILD_KLEIDIAI) - add_compile_definitions(TORCHAO_ENABLE_KLEIDI) - add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) -endif() - -if(TORCHAO_BUILD_ARM_I8MM) - add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) -endif() - -enable_testing() - -if (ANDROID_ABI) - # Given where we are today this is sufficent. But needs to be revisited. - # This is also needed for native builds, but keeping it only for cross builds - # for now given the hacky nature. - file(GLOB DOTPROD_SRC_FILES test*.cpp) - message(SRC_FILES: ${DOTPROD_SRC_FILES}) - set_property(SOURCE - ${DOTPROD_SRC_FILES} - APPEND_STRING PROPERTY - COMPILE_FLAGS " -march=armv8.2-a+dotprod ") -endif() - -add_executable(test_quantization test_quantization.cpp) -target_link_libraries( - test_quantization - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_reduction test_reduction.cpp) -target_link_libraries( - test_reduction - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_bitpacking test_bitpacking.cpp) -target_link_libraries( - test_bitpacking - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_linear test_linear.cpp) -target_link_libraries( - test_linear - PRIVATE - GTest::gtest_main - dep - torchao_kernels_aarch64 -) - - -add_executable(test_embedding test_embedding.cpp) -target_link_libraries( - test_embedding - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_weight_packing test_weight_packing.cpp) -target_link_libraries( - test_weight_packing - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_qmatmul test_qmatmul.cpp) -target_link_libraries( - test_qmatmul - PRIVATE - GTest::gtest_main - dep -) - -include(GoogleTest) -gtest_discover_tests(test_quantization) -gtest_discover_tests(test_reduction) -gtest_discover_tests(test_bitpacking) -gtest_discover_tests(test_linear) -gtest_discover_tests(test_embedding) -gtest_discover_tests(test_weight_packing) -gtest_discover_tests(test_qmatmul) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh deleted file mode 100644 index 5d28ea01cc..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -eu -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -set -eu -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/kernel_tests - -target=${1:-"native"} - -EXTRA_ARGS="" -if [[ "${target}" == "android" ]]; then - if [[ -z ${ANDROID_NDK} ]]; then - echo "Need to set ANDROID_NDK env variable to build for Android"; - exit 1; - fi - android_abi=arm64-v8a - android_platform=28 # must be >=28 for aligned_alloc - IS_ARM64=1 - BUILD_ARM_I8MM=1 # Hardcoded for now - CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} - toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" - if [[ -z ${toolchain_file} ]]; then - echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" - exit 1; - fi - EXTRA_ARGS="\ - -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ - -DANDROID_ABI=${android_abi} \ - -DANDROID_PLATFORM=${android_platform} - " - echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" -fi - -cmake \ - ${EXTRA_ARGS} \ - -DCMAKE_BUILD_TYPE=Debug \ - -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DTORCHAO_BUILD_CPU_AARCH64=ON \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \ - -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -echo "Successfully built tests." - -if [[ "${target}" != "native" ]]; then - echo "Skip running tests when cross compiling."; - exit 0; -fi - -# Run -${CMAKE_OUT}/test_quantization -${CMAKE_OUT}/test_reduction -${CMAKE_OUT}/test_bitpacking -${CMAKE_OUT}/test_linear -${CMAKE_OUT}/test_embedding -${CMAKE_OUT}/test_weight_packing -${CMAKE_OUT}/test_qmatmul diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h index 370c6d400c..8071398eba 100644 --- a/torchao/experimental/kernels/mps/src/lowbit.h +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -73,11 +73,11 @@ using DispatchFn = void (*)(id, int32_t, int32_t, int32_t, int32_t); inline void linear_lowbit_quant_weights_mps_impl( - id a_buf, - id b_buf, - id s_buf, - id z_buf, - id out_buf, + std::pair, size_t> a_buf_offset, + std::pair, size_t> b_buf_offset, + std::pair, size_t> s_buf_offset, + std::pair, size_t> z_buf_offset, + std::pair, size_t> out_buf_offset, int32_t M, int32_t K, int32_t N, @@ -97,11 +97,11 @@ inline void linear_lowbit_quant_weights_mps_impl( metal_lowbit_quantized_lib.getPipelineStateForFunc(shader_func); const auto maxThreadsPerGroup = [cpl maxTotalThreadsPerThreadgroup]; [computeEncoder setComputePipelineState:cpl]; - [computeEncoder setBuffer:a_buf offset:0 atIndex:0]; - [computeEncoder setBuffer:b_buf offset:0 atIndex:1]; - [computeEncoder setBuffer:s_buf offset:0 atIndex:2]; - [computeEncoder setBuffer:z_buf offset:0 atIndex:3]; - [computeEncoder setBuffer:out_buf offset:0 atIndex:4]; + [computeEncoder setBuffer:a_buf_offset.first offset:a_buf_offset.second atIndex:0]; + [computeEncoder setBuffer:b_buf_offset.first offset:b_buf_offset.second atIndex:1]; + [computeEncoder setBuffer:s_buf_offset.first offset:s_buf_offset.second atIndex:2]; + [computeEncoder setBuffer:z_buf_offset.first offset:z_buf_offset.second atIndex:3]; + [computeEncoder setBuffer:out_buf_offset.first offset:out_buf_offset.second atIndex:4]; [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:5]; @@ -133,12 +133,12 @@ std::tuple get_shader_func_and_dispatch( // LowBit Quantized Weights Linear on Metal template void linear_lowbit_quant_weights_mps( - id a_buf, - id b_buf, + std::pair, size_t> a_buf_offset, + std::pair, size_t> b_buf_offset, int64_t qGroupSize, - id s_buf, - id z_buf, - id out_buf, + std::pair, size_t> s_buf_offset, + std::pair, size_t> z_buf_offset, + std::pair, size_t> out_buf_offset, int32_t M, int32_t K, int32_t N, @@ -154,11 +154,11 @@ void linear_lowbit_quant_weights_mps( const DispatchFn dispatch_fn = std::get<1>(shader_func_and_dispatch); return linear_lowbit_quant_weights_mps_impl( - a_buf, - b_buf, - s_buf, - z_buf, - out_buf, + a_buf_offset, + b_buf_offset, + s_buf_offset, + z_buf_offset, + out_buf_offset, M, K, N, diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm index 524aee738d..8481e5cef6 100644 --- a/torchao/experimental/kernels/mps/test/test_lowbit.mm +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -118,12 +118,12 @@ void pack() { void linear() { LowBitQuantWeights::linear( - buf_A, - buf_B, + {buf_A, 0}, + {buf_B, 0}, qGroupSize, - buf_S, - buf_Z, - buf_C, + {buf_S, 0}, + {buf_Z, 0}, + {buf_C, 0}, M, K, N, diff --git a/torchao/experimental/op_lib_utils.py b/torchao/experimental/op_lib_utils.py deleted file mode 100644 index 25cb8a1ed2..0000000000 --- a/torchao/experimental/op_lib_utils.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch - - -def _check_torchao_ops_loaded(): - # Check kernels are installed/loaded - try: - torch.ops.torchao._pack_8bit_act_4bit_weight - except AttributeError: - raise Exception( - "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." - + " You can also set target to 'aten' if you are using ARM CPU." - ) diff --git a/torchao/experimental/ops/benchmarks/CMakeLists.txt b/torchao/experimental/ops/benchmarks/CMakeLists.txt deleted file mode 100644 index d06526cf84..0000000000 --- a/torchao/experimental/ops/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(benchmarks) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) -add_compile_options("-Wall" "-Werror") - -set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) -set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) - -include(FetchContent) -FetchContent_Declare(googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main) # need main for benchmark::benchmark - -set(BENCHMARK_ENABLE_TESTING OFF) -FetchContent_MakeAvailable( - googlebenchmark) - -include_directories(${TORCHAO_INCLUDE_DIRS}) - -set(TORCHAO_PARALLEL_BACKEND "openmp") - -include(${TORCHAO_ROOT}/Utils.cmake) - -add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - -add_executable(benchmark_linear_8bit_act_xbit_weight - benchmark_linear_8bit_act_xbit_weight.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_torchao_parallel_backend(benchmark_linear_8bit_act_xbit_weight "${TORCHAO_PARALLEL_BACKEND}") -target_link_libraries( - benchmark_linear_8bit_act_xbit_weight - PRIVATE - benchmark::benchmark - torchao_kernels_aarch64 -) diff --git a/torchao/experimental/ops/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/ops/benchmarks/build_and_run_benchmarks.sh deleted file mode 100644 index b837b36fe4..0000000000 --- a/torchao/experimental/ops/benchmarks/build_and_run_benchmarks.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} - -export CMAKE_OUT=/tmp/cmake-out/torchao/benchmarks -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -S . \ - -B ${CMAKE_OUT} \ - -DOpenMP_ROOT=$(brew --prefix libomp) \ - -DTORCHAO_PARALLEL_OMP=ON - -cmake --build ${CMAKE_OUT} - -# Run -${CMAKE_OUT}/benchmark_linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/CMakeLists.txt b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/CMakeLists.txt deleted file mode 100644 index 7ba8d20c6d..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/CMakeLists.txt +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -project(examples) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -include(CMakePrintHelpers) - -set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) -set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) - -include_directories(${TORCHAO_INCLUDE_DIRS}) - -set(TORCHAO_PARALLEL_BACKEND "openmp") -add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - -include(${TORCHAO_ROOT}/Utils.cmake) - -add_executable(separate_function_wrappers - separate_function_wrappers.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_libraries( - separate_function_wrappers - PRIVATE - torchao_kernels_aarch64 -) -target_link_torchao_parallel_backend(separate_function_wrappers "${TORCHAO_PARALLEL_BACKEND}") - -add_executable(stateful_class_wrapper - stateful_class_wrapper.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_libraries( - stateful_class_wrapper - PRIVATE - torchao_kernels_aarch64 -) -target_link_torchao_parallel_backend(stateful_class_wrapper "${TORCHAO_PARALLEL_BACKEND}") diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h deleted file mode 100644 index 2250a60706..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once -#include -#include -#include -#include -#include - -namespace torchao::ops::linear_8bit_act_xbit_weight { - -class Linear8BitActXBitWeightOperator { - private: - torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr}; - int packed_weight_data_size_{0}; - int preferred_packed_weight_data_alignment_{0}; - - torchao::aligned_byte_ptr activation_data_buffer_{nullptr, nullptr}; - - int m_{0}; - int n_{0}; - int k_{0}; - int group_size_{0}; - - // The class does not own this data - const int8_t* weight_qvals_{nullptr}; - const float* weight_scales_{nullptr}; - const int8_t* weight_zeros_{nullptr}; - - bool initialized_{false}; - - UKernelConfig ukernel_config_; - PackWeightDataTilingParams pack_weight_tiling_params_; - LinearTilingParams linear_tiling_params_; - LinearTileSchedulingPolicy linear_scheduling_policy_; - - public: - Linear8BitActXBitWeightOperator( - UKernelConfig ukernel_config, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - int initial_m = 1, - std::optional pack_weight_tiling_params = {}, - std::optional linear_tiling_params = {}, - std::optional linear_scheduling_policy = {}) - : m_{initial_m}, - n_{n}, - k_{k}, - group_size_(group_size), - weight_qvals_{weight_qvals}, - weight_scales_{weight_scales}, - weight_zeros_{weight_zeros} { - TORCHAO_CHECK(n_ >= 1, "n must be >= 1"); - TORCHAO_CHECK(k_ >= 1, "k must be >= 1"); - TORCHAO_CHECK(group_size_ >= 1, "group_size must be >= 1"); - TORCHAO_CHECK(m_ >= 1, "initial_m must be >= 1"); - - ukernel_config_ = ukernel_config; - if (pack_weight_tiling_params.has_value()) { - pack_weight_tiling_params_ = pack_weight_tiling_params.value(); - } else { - pack_weight_tiling_params_ = get_default_pack_weight_data_tiling_params( - ukernel_config_, n_, /*target_panels_per_thread=*/1); - } - - if (linear_tiling_params.has_value()) { - linear_tiling_params_ = linear_tiling_params.value(); - } else { - linear_tiling_params_ = get_default_linear_tiling_params( - ukernel_config_, m_, n_, /*target_tiles_per_thread=*/5); - } - - if (linear_scheduling_policy.has_value()) { - linear_scheduling_policy_ = linear_scheduling_policy.value(); - } else { - linear_scheduling_policy_ = - LinearTileSchedulingPolicy::single_mc_parallel_nc; - } - } - - int get_m() { - return m_; - } - int get_n() { - return n_; - } - int get_k() { - return k_; - } - int get_group_size() { - return group_size_; - } - - void initialize() { - if (initialized_) { - return; - } - - // Pack weight data - auto packed_weight_data_size = - get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_); - auto preferred_packed_weight_data_alignment = - get_preferred_packed_weight_data_alignment(ukernel_config_); - - packed_weight_data_size_ = packed_weight_data_size; - preferred_packed_weight_data_alignment_ = preferred_packed_weight_data_alignment; - packed_weight_data_ = torchao::make_aligned_byte_ptr( - preferred_packed_weight_data_alignment, packed_weight_data_size); - - pack_weight_data_operator( - ukernel_config_, - pack_weight_tiling_params_, - packed_weight_data_.get(), - n_, - k_, - group_size_, - weight_qvals_, - weight_scales_, - weight_zeros_); - - // Pre-allocate space for quantized/packed activations - // This buffer may be resized when calling the operator if m is changed - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config_, - linear_tiling_params_, - linear_scheduling_policy_, - m_, - k_, - group_size_); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config_); - activation_data_buffer_ = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - - // Mark as initialized - initialized_ = true; - } - - void operator()( - float* output, - const float* activations, - int m, - int k, - const float* bias, - float clamp_min, - float clamp_max) { - TORCHAO_CHECK(initialized_, "kernel is not initialized."); - TORCHAO_CHECK( - k == this->k_, - "activations have incompatible size with initialized kernel."); - - // Resize activation buffer if needed - if (m > m_) { - m_ = m; - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config_, - linear_tiling_params_, - linear_scheduling_policy_, - m_, - k_, - group_size_); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config_); - activation_data_buffer_ = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - } - - // Run linear operator - linear_operator( - ukernel_config_, - linear_tiling_params_, - linear_scheduling_policy_, - activation_data_buffer_.get(), - output, - // To support dynamic shapes, we use m from args, not m_ - // Note m_ can be larger than m - m, - n_, - k_, - group_size_, - packed_weight_data_.get(), - activations, - bias, - clamp_min, - clamp_max); - } -}; -} // namespace - // torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/build_and_run_examples.sh b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/build_and_run_examples.sh deleted file mode 100644 index 01185fdd3f..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/build_and_run_examples.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" -echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torchao/examples -cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ - -S . \ - -B ${CMAKE_OUT} \ - -DOpenMP_ROOT=$(brew --prefix libomp) -cmake --build ${CMAKE_OUT} - -# Run -case "$1" in - separate_function_wrappers) ${CMAKE_OUT}/separate_function_wrappers; ;; - stateful_class_wrapper) ${CMAKE_OUT}/stateful_class_wrapper; ;; - *) echo "Unknown example: $1. Please specify one of: separate_function_wrappers, stateful_class_wrapper."; exit 1; ;; -esac diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/separate_function_wrappers.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/separate_function_wrappers.cpp deleted file mode 100644 index 961c03e985..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/separate_function_wrappers.cpp +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#include -// This file contains an example of wrapping the torchao weight packing and -// linear operators into two operators: one for weight packing and another -// for running the linear operator. Each surface (PyTorch custom class, PyTorch -// operator, ExecuTorch operator, ExecuTorch delegate) will need to write its -// own wrapper). In the example here, std::vector is used for storage, but in -// PyTorch a PyTorch Tensor would be used and in ExecuTorch, an ExecuTorch -// Tensor would be used. -// -// It is more efficient to combine weight-packing and the linear operator into -// one stateful class, but not all surfaces support this (see -// examples/stateful_class_wrapper.cpp for an example of this). - -namespace torchao::ops::linear_8bit_act_xbit_weight { - -template -UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; -} - -torchao::aligned_byte_ptr pack_weight_data_operator( - UKernelConfig ukernel_config, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - std::optional tiling_params = {}) { - PackWeightDataTilingParams tiling_params_; - if (tiling_params.has_value()) { - tiling_params_ = tiling_params.value(); - } else { - tiling_params_ = get_default_pack_weight_data_tiling_params( - ukernel_config, n, /*target_panels_per_thread=*/1); - } - - auto packed_weight_data_size = - get_packed_weight_data_size(ukernel_config, n, k, group_size); - auto preferred_packed_weight_data_alignment = - get_preferred_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_ptr( - preferred_packed_weight_data_alignment, packed_weight_data_size); - - pack_weight_data_operator( - ukernel_config, - tiling_params_, - packed_weight_data.get(), - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros); - - return packed_weight_data; -} - -void linear_operator( - UKernelConfig ukernel_config, - float* output, - int m, - int n, - int k, - int group_size, - void* packed_weight_data, - float* activations, - const float* bias, - float clamp_min, - float clamp_max, - std::optional tiling_params = {}, - std::optional scheduling_policy = {}) { - LinearTilingParams tiling_params_; - if (tiling_params.has_value()) { - tiling_params_ = tiling_params.value(); - } else { - tiling_params_ = get_default_linear_tiling_params( - ukernel_config, m, n, /*target_tiles_per_thread=*/5); - } - - LinearTileSchedulingPolicy scheduling_policy_; - if (scheduling_policy.has_value()) { - scheduling_policy_ = scheduling_policy.value(); - } else { - scheduling_policy_ = LinearTileSchedulingPolicy::single_mc_parallel_nc; - } - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, tiling_params_, scheduling_policy_, m, k, group_size); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - - linear_operator( - ukernel_config, - tiling_params_, - scheduling_policy_, - activation_data_buffer.get(), - output, - m, - n, - k, - group_size, - packed_weight_data, - activations, - bias, - clamp_min, - clamp_max); -} - -} // namespace - // torchao::ops::linear_8bit_act_xbit_weight - -int main() { - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - torchao::set_num_threads(8); - std::cout << "Using " << torchao::get_num_threads() << " threads." - << std::endl; - - constexpr int weight_nbit = 3; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - - int m = 1; - int n = 4096 + 1; - int k = 4096; - int group_size = 16; - - std::cout << "Generating random test case." << std::endl; - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp); - - auto output = std::vector(m * n); - - auto ukernel_config = - get_ukernel_config(); - - std::cout << "Running pack_weight_data_operator." << std::endl; - auto packed_weight_data = pack_weight_data_operator( - ukernel_config, - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - test_case.weight_zeros.data()); - - std::cout << "Running linear_operator." << std::endl; - linear_operator( - ukernel_config, - output.data(), - m, - n, - k, - group_size, - packed_weight_data.get(), - test_case.activations.data(), - test_case.bias.data(), - test_case.clamp_min, - test_case.clamp_max); - - std::cout << "Checking results." << std::endl; - - bool passed = true; - float tol = 0.001; - for (int i = 0; i < output.size(); i++) { - if (std::abs(test_case.expected_output[i] - output[i]) > tol) { - std::cout << "Bad result at index " << i << "."; - std::cout << " Output: " << output[i] - << ". Expected: " << test_case.expected_output[i] << "." - << std::endl; - passed = false; - } - } - if (passed) { - std::cout << "Test passed." << std::endl; - } else { - std::cout << "Test failed." << std::endl; - } - - return 0; -} diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/stateful_class_wrapper.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/stateful_class_wrapper.cpp deleted file mode 100644 index a45c32811b..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/stateful_class_wrapper.cpp +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#include - -// This file contains an example of wrapping the torchao weight packing and -// linear operators into one stateful LinearOperator class. Each surface -// (PyTorch custom class, PyTorch operator, ExecuTorch operator, ExecuTorch -// delegate) will need to write its own wrapper. In the example here, -// std::vector is used for storage, but in PyTorch a PyTorch Tensor would be -// used and in ExecuTorch, an ExecuTorch Tensor would be used. -// -// Although more efficient, not all surfaces support stateful operators. See -// examples/separate_function_wrappers.cpp for an example of how to split the -// operations into two steps. - -using namespace torchao::ops::linear_8bit_act_xbit_weight; - -template -UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; -} - -int main() { - int m = 13; - int n = 4096 + 1; - int k = 4096; - int group_size = 16; - - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - - std::cout << "Generating random test case." << std::endl; - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp); - - torchao::set_num_threads(8); - std::cout << "Using " << torchao::get_num_threads() << " threads." - << std::endl; - - std::cout << "Initializing linear_operator." << std::endl; - auto ukernel_config = - get_ukernel_config(); - - auto linear_operator = - Linear8BitActXBitWeightOperator( - ukernel_config, - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - test_case.weight_zeros.data(), - // m may be resized during call to support dynamic shapes - /*initial_m=*/1); - - linear_operator.initialize(); - - std::cout << "Calling linear_operator." << std::endl; - auto output = std::vector(m * n); - linear_operator( - output.data(), - test_case.activations.data(), - m, - k, - test_case.bias.data(), - test_case.clamp_min, - test_case.clamp_max); - - std::cout << "Checking results." << std::endl; - - bool passed = true; - float tol = 0.001; - for (int i = 0; i < output.size(); i++) { - if (std::abs(test_case.expected_output[i] - output[i]) > tol) { - std::cout << "Bad result at index " << i << "."; - std::cout << " Output: " << output[i] - << ". Expected: " << test_case.expected_output[i] << "." - << std::endl; - passed = false; - break; - } - } - if (passed) { - std::cout << "Test passed." << std::endl; - } else { - std::cout << "Test failed." << std::endl; - } - - return 0; -} diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm index 972caa039a..e8fcdb2699 100644 --- a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm @@ -97,12 +97,12 @@ Tensor linear_mps_kernel_out( auto K = A.size(1); LowBitQuantWeights::linear( - getMTLBufferStorage(A), - getMTLBufferStorage(B), + {getMTLBufferStorage(A), A.storage_offset() * A.element_size()}, + {getMTLBufferStorage(B), B.storage_offset() * B.element_size()}, group_size, - getMTLBufferStorage(S), - getMTLBufferStorage(Z), - getMTLBufferStorage(C), + {getMTLBufferStorage(S), S.storage_offset() * S.element_size()}, + {getMTLBufferStorage(Z), Z.storage_offset() * Z.element_size()}, + {getMTLBufferStorage(C), C.storage_offset() * C.element_size()}, M, K, N, diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm index f8a8ffdae9..22693b417e 100644 --- a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm @@ -95,12 +95,12 @@ bool check_linear_mps_args( auto K = A.size(1); torchao::kernels::mps::lowbit::LowBitQuantWeights::linear( - getMTLBufferStorage(A), - getMTLBufferStorage(B), + {getMTLBufferStorage(A), A.storage_offset() * A.element_size()}, + {getMTLBufferStorage(B), B.storage_offset() * B.element_size()}, group_size, - getMTLBufferStorage(S), - getMTLBufferStorage(Z), - getMTLBufferStorage(out), + {getMTLBufferStorage(S), S.storage_offset() * S.element_size()}, + {getMTLBufferStorage(Z), Z.storage_offset() * Z.element_size()}, + {getMTLBufferStorage(out), out.storage_offset() * out.element_size()}, M, K, N, diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index 04273fb1af..e7d035fb61 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -86,6 +86,42 @@ def test_export(self, nbit): == f"torchao._linear_fp_act_{nbit}bit_weight.default" ) + @parameterized.expand(BITWIDTHS) + def test_export_accuracy(self, nbit): + group_size = 32 + m = 3 + n = 12 + k = 64 + with torch.no_grad(): + activations = torch.rand(m, k, dtype=torch.float32, device="mps") + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + + # Compute expected result + weight_cpu = model[0].weight.data + weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize( + weight_cpu, group_size, nbit, True, torch.uint8 + ) + weight_zeros_cpu = -weight_zeros_cpu * weight_scales_cpu + expected = self._reference_linear_lowbit_quant_weights( + activations.cpu(), + weight_qvals_cpu, + group_size, + weight_scales_cpu, + weight_zeros_cpu, + ) + + quantized_model = self._quantize_model( + model, torch.float32, nbit, group_size + ) + + ep = torch.export.export(quantized_model, (activations,), strict=True) + path = torch._inductor.aoti_compile_and_package(ep) + compiled_model = torch._inductor.aoti_load_package(path) + result = compiled_model(activations) + + # Compare results + torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001) + @parameterized.expand(BITWIDTHS) def test_2d_output_device_and_shape(self, nbit): model, group_size, k0, n = self._model_setup() diff --git a/torchao/experimental/ops/tests/CMakeLists.txt b/torchao/experimental/ops/tests/CMakeLists.txt deleted file mode 100644 index 8245fdd746..0000000000 --- a/torchao/experimental/ops/tests/CMakeLists.txt +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(tests) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) -add_compile_options("-Wall" "-Werror") - -set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) -set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) - -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) -enable_testing() - -if(TORCHAO_BUILD_CPU_AARCH64) - add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64=1) - add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) -endif() - -if(TORCHAO_BUILD_KLEIDIAI) - add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) -endif() - -if(TORCHAO_BUILD_ARM_I8MM) - add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) -endif() - -if (ANDROID_ABI) - # We are cross compiling, delay test discovery till runtime - set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) -endif() - -include_directories(${TORCHAO_INCLUDE_DIRS}) - -set(TORCHAO_PARALLEL_BACKEND "test_dummy") - -if (TORCHAO_BUILD_CPU_AARCH64) - add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) -endif() - -include(${TORCHAO_ROOT}/Utils.cmake) - -if (ANDROID_ABI) - # Given where we are today this is sufficent. But needs to be revisited. - # This is also needed for native builds, but keeping it only for cross builds - # for now given the hacky nature. - file(GLOB DOTPROD_SRC_FILES test*.cpp) - message(SRC_FILES: ${DOTPROD_SRC_FILES}) - set_property(SOURCE - ${DOTPROD_SRC_FILES} - APPEND_STRING PROPERTY - COMPILE_FLAGS " -march=armv8.2-a+dotprod ") -endif() - -add_executable( - test_linear_8bit_act_xbit_weight - test_linear_8bit_act_xbit_weight.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_libraries( - test_linear_8bit_act_xbit_weight - PRIVATE - GTest::gtest_main -) -if (TORCHAO_BUILD_CPU_AARCH64) - target_link_libraries( - test_linear_8bit_act_xbit_weight - PRIVATE - torchao_kernels_aarch64 - ) -endif() -target_link_torchao_parallel_backend(test_linear_8bit_act_xbit_weight "${TORCHAO_PARALLEL_BACKEND}") - -include(GoogleTest) -gtest_discover_tests(test_linear_8bit_act_xbit_weight) diff --git a/torchao/experimental/ops/tests/build_and_run_tests.sh b/torchao/experimental/ops/tests/build_and_run_tests.sh deleted file mode 100644 index 6a73b91219..0000000000 --- a/torchao/experimental/ops/tests/build_and_run_tests.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -target=${1:-"native"} -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests - -export TORCH_DIR=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib() + '/torch/share/cmake/Torch')") - -IS_ARM64=0 -BUILD_ARM_I8MM=0 -EXTRA_ARGS="" -if [[ "${target}" == "android" ]]; then - if [[ -z ${ANDROID_NDK} ]]; then - echo "Need to set ANDROID_NDK env variable to build for Android"; - exit 1; - fi - android_abi=arm64-v8a - android_platform=28 # must be >=28 for aligned_alloc - IS_ARM64=1 - BUILD_ARM_I8MM=1 # Hardcoded for now - CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} - toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" - if [[ -z ${toolchain_file} ]]; then - echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" - exit 1; - fi - EXTRA_ARGS="\ - -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ - -DANDROID_ABI=${android_abi} \ - -DANDROID_PLATFORM=${android_platform} - " - echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" -fi - -hash arch; retval=$? -if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then - IS_ARM64=1 -fi - -cmake \ - ${EXTRA_ARGS} \ - -DCMAKE_BUILD_TYPE=Debug \ - -DTORCHAO_BUILD_CPU_AARCH64=${IS_ARM64} \ - -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ - -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ - -DTorch_DIR=${TORCH_DIR} \ - -S . \ - -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -echo "Successfully built tests." - -if [[ "${target}" != "native" ]]; then - echo "Skip running tests when cross compiling."; - exit 0; -fi - -# Run -${CMAKE_OUT}/test_linear_8bit_act_xbit_weight diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py deleted file mode 100644 index b6b9fcbcc5..0000000000 --- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# TODO: delete this file. -# File is kept in torchao/experimental to avoid breaking existing code -import logging - -logging.warning( - "torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout.py is deprecated and will be removed. Please use torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout.py instead." -) -from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, - Target, -) - -__all__ = [ - "PackedLinearInt8DynamicActivationIntxWeightLayout", - "Target", -] diff --git a/torchao/experimental/q_dq_layout.py b/torchao/experimental/q_dq_layout.py deleted file mode 100644 index 5eeea7f4bd..0000000000 --- a/torchao/experimental/q_dq_layout.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# TODO: delete this file. -# File is kept in torchao/experimental to avoid breaking existing code -import logging - -logging.warning( - "torchao.experimental.q_dq_layout.py is deprecated and will be removed. Please use torchao.dtypes.uintx.q_dq_layout.py instead." -) -from torchao.dtypes import QDQLayout - -__all__ = [ - "QDQLayout", -] diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 2e50587c2a..dd2168868d 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -6,7 +6,7 @@ import logging import sys -from typing import Callable, List, Mapping, Optional, Tuple, Union +from typing import Optional import torch import torch.nn as nn @@ -23,452 +23,6 @@ handler.setFormatter(formatter) logger.addHandler(handler) -from dataclasses import dataclass - -from torchao.core.config import AOBaseConfig -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, -) -from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, - Target, -) -from torchao.experimental.op_lib_utils import _check_torchao_ops_loaded -from torchao.quantization.granularity import Granularity, PerAxis, PerGroup, PerRow -from torchao.quantization.quant_api import ( - Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationIntxWeightConfig_NonExperimental, -) -from torchao.quantization.quant_api import ( - IntxWeightOnlyConfig, - MappingType, - quantize_, -) -from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH - - -@dataclass -class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): - weight_dtype: torch.dtype = torch.int4 - granularity: Union[PerRow, PerGroup] = PerRow() - has_weight_zeros: bool = False - weight_mapping_type: MappingType = MappingType.ASYMMETRIC - act_mapping_type: MappingType = MappingType.ASYMMETRIC - round_weight_scale_to_bf16: bool = True - layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=Target.AUTO) - - def __post_init__(self): - raise NotImplementedError( - "Int8DynamicActivationIntxWeightConfig has moved from torchao.experimental.quant_api to torchao.quantization.quant_api.\n" - "Please migrate to using the new version. The following args are renamed in the new version:\n" - "* granularity -> weight_granularity\n" - "* has_weight_zeros=True -> weight_mapping_type=torchao.quantization.quant_api.MappingType.ASYMMETRIC\n" - "* has_weight_zeros=False -> weight_zero_point_domain=torchao.quantization.quant_api.MappingType.SYMMETRIC\n" - "* round_weight_scale_to_bf16=True -> weight_scale_dtype=torch.bfloat16\n" - "* layout default has changed to QDQLayout(). IF YOU WANT CPU PERFORMANCE, USE layout=PackedLinearInt8DynamicActivationIntxWeightLayout()." - ) - - -# For BC -int8_dynamic_activation_intx_weight = Int8DynamicActivationIntxWeightConfig - - -class QuantizedEmbedding(nn.Module): - def __init__( - self, - bit_width, - ): - super().__init__() - self.bit_width = bit_width - - def quantize_and_pack_weights(self, weights, group_size, mapping_type): - num_embeddings, embedding_dim = weights.shape - - embedding = torch.nn.Embedding(num_embeddings, embedding_dim) - embedding.weight = weights - quantize_( - embedding, - IntxWeightOnlyConfig( - weight_dtype=getattr(torch, f"int{self.bit_width}"), - granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0), - mapping_type=mapping_type, - ), - lambda m, fqn: isinstance(m, torch.nn.Embedding), - ) - weight_qvals, weight_scales, weight_zeros = ( - embedding.weight.tensor_impl.get_plain() - ) - assert weight_zeros is not None - weight_scales = weight_scales.reshape(num_embeddings, -1) - weight_zeros = weight_zeros.reshape(num_embeddings, -1).to(torch.int8) - self.register_buffer( - "packed_weight_qvals", - getattr(torch.ops.torchao, f"_pack_embedding_{self.bit_width}bit")( - weight_qvals.to(torch.int8) - ), - ) - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.register_buffer("weight_scales", weight_scales) - self.register_buffer("weight_zeros", weight_zeros) - - def forward(self, x): - shape = x.shape - return getattr(torch.ops.torchao, f"_embedding_{self.bit_width}bit")( - self.packed_weight_qvals, - self.num_embeddings, - self.embedding_dim, - self.weight_scales, - # embedding op requires weight_zeros be passed, even if they are all 0 - self.weight_zeros, - x.reshape(-1), - ).reshape(*shape, -1) - - -class QuantizedEmbeddingFallback(nn.Module): - def __init__( - self, - bit_width, - ): - super().__init__() - self.bit_width = bit_width - - def quantize_and_pack_weights(self, weights, group_size, mapping_type): - self.embedding = torch.nn.Embedding(*weights.shape) - self.embedding.weight = weights - quantize_( - self.embedding, - IntxWeightOnlyConfig( - weight_dtype=getattr(torch, f"int{self.bit_width}"), - granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0), - mapping_type=mapping_type, - ), - lambda m, fqn: isinstance(m, torch.nn.Embedding), - ) - - def forward(self, x): - return self.embedding(x) - - -class QuantizedSharedEmbedding(nn.Module): - def __init__(self, bit_width, unembedding_packed_weights, group_size, n, k): - super().__init__() - self.bit_width = bit_width - self.register_buffer("unembedding_packed_weights", unembedding_packed_weights) - self.n = n - self.k = k - if group_size == -1: - self.group_size = k - else: - self.group_size = group_size - self.shared_embedding_op = getattr( - torch.ops.torchao, f"_shared_embedding_{bit_width}bit" - ) - - def forward(self, x): - shape = x.shape - return self.shared_embedding_op( - self.unembedding_packed_weights, - self.group_size, - self.n, - self.k, - x.reshape(-1), - ).reshape(*shape, -1) - - -def _replace_embedding_with_quantized_embedding( - module: nn.Module, - kwargs={}, - fqn: str = "", -): - group_size = kwargs.get("group_size", None) - bit_width = kwargs.get("bit_width", None) - use_fallback = kwargs.get("use_fallback", None) - mapping_type = kwargs.get("mapping_type", None) - embedding_fqn_to_quantized_unembedding = kwargs.get( - "embedding_fqn_to_quantized_unembedding", None - ) - - assert not isinstance(module, nn.Embedding) - for name, child in module.named_children(): - child_fqn = f"{fqn}.{name}" if fqn != "" else name - - if not isinstance(child, nn.Embedding): - _replace_embedding_with_quantized_embedding(child, kwargs, child_fqn) - else: - assert child.weight.device == torch.device("cpu"), "Only CPU is supported" - assert child.weight.dtype == torch.float32, "Only float32 is supported" - - if use_fallback: - qembedding = QuantizedEmbeddingFallback(bit_width) - setattr(module, name, qembedding) - getattr(module, name).quantize_and_pack_weights( - child.weight, - group_size, - mapping_type, - ) - else: - _check_torchao_ops_loaded() - if embedding_fqn_to_quantized_unembedding is None: - qembedding = QuantizedEmbedding(bit_width) - setattr(module, name, qembedding) - getattr(module, name).quantize_and_pack_weights( - child.weight, - group_size, - mapping_type, - ) - else: - if child_fqn not in embedding_fqn_to_quantized_unembedding: - continue - weight_tensor = embedding_fqn_to_quantized_unembedding[child_fqn] - n, k = weight_tensor.shape - group_size = weight_tensor.tensor_impl.get_layout().group_size - packed_weight = weight_tensor.tensor_impl.packed_weight - bit_width = weight_tensor.tensor_impl.get_layout().bit_width - - assert n == child.num_embeddings, ( - "num_embeddings must match n in shared_unembedding" - ) - assert k == child.embedding_dim, ( - "embedding_dim must match k in shared_unembedding" - ) - qembedding = QuantizedSharedEmbedding( - bit_width, - packed_weight, - group_size, - n, - k, - ) - setattr(module, name, qembedding) - - -class EmbeddingQuantizer: - def __init__( - self, - weight_dtype: torch.dtype = torch.int4, - granularity: Granularity = PerAxis(0), - mapping_type: MappingType = MappingType.ASYMMETRIC, - use_fallback: bool = False, - ): - assert weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)] - bit_width = _DTYPE_TO_BIT_WIDTH[weight_dtype] - - if isinstance(granularity, PerGroup): - group_size = granularity.group_size - elif isinstance(granularity, PerAxis): - assert granularity.axis == 0 - group_size = -1 - else: - raise ValueError(f"Unsupported granularity: {granularity}") - - self.bit_width = bit_width - self.group_size = group_size - self.use_fallback = use_fallback - self.mapping_type = mapping_type - - def quantize(self, model: nn.Module) -> nn.Module: - _replace_embedding_with_quantized_embedding( - model, - kwargs={ - "group_size": self.group_size, - "bit_width": self.bit_width, - "use_fallback": self.use_fallback, - "mapping_type": self.mapping_type, - }, - ) - return model - - -def _get_fqns_with_filter( - module: nn.Module, - filter_fn: Callable[Tuple[str, nn.Module], bool], - fqn: str, - fqns: List[str], -): - for name, child in module.named_children(): - child_fqn = f"{fqn}.{name}" if fqn != "" else name - if filter_fn(child, child_fqn): - fqns.append(child_fqn) - else: - _get_fqns_with_filter(child, filter_fn, child_fqn, fqns) - - -def get_fqns_with_filter( - module: nn.Module, filter_fn: Callable[Tuple[str, nn.Module], bool] -) -> List[str]: - fqns = [] - _get_fqns_with_filter(module, filter_fn, "", fqns) - return fqns - - -class QuantizedLinear(nn.Module): - def __init__(self, packed_weight, n, k, group_size, bit_width, bias): - super().__init__() - self.register_buffer("packed_weight", packed_weight) - self.n = n - self.k = k - self.group_size = group_size - self.bit_width = bit_width - self.bias = bias - - def _forward_2d(self, x): - assert x.dim() == 2 - m, k = x.shape - assert k == self.k - return getattr( - torch.ops.torchao, f"_linear_8bit_act_{self.bit_width}bit_weight" - )(x, self.packed_weight, self.group_size, self.n, self.k) - - def forward(self, x): - if x.dim() == 2: - res = self._forward_2d(x) - else: - assert x.dim() >= 3 - lead_shape = x.shape[0:-2] - m, k = x.shape[-2], x.shape[-1] - assert k == self.k - res = self._forward_2d(x.reshape(-1, k)) - res = res.reshape(*lead_shape, m, self.n) - - if self.bias is not None: - res = res + self.bias - return res - - -def quantized_linear_from_aqt( - weight: Optional[torch.Tensor], bias: Optional[torch.Tensor] -): - n, k = weight.shape - group_size = weight.tensor_impl.get_layout().group_size - bit_width = weight.tensor_impl.get_layout().bit_width - packed_weight = weight.tensor_impl.packed_weight - if weight.tensor_impl.get_layout().has_bias: - assert bias is None - return QuantizedLinear(packed_weight, n, k, group_size, bit_width, bias) - - -def replace_linear_tensor_subclass_with_module(module: nn.Module): - assert not isinstance(module, nn.Linear) - for name, child in module.named_children(): - if not isinstance(child, nn.Linear): - replace_linear_tensor_subclass_with_module(child) - else: - if not isinstance(child.weight, AffineQuantizedTensor): - continue - if not isinstance( - child.weight.tensor_impl.get_layout(), - PackedLinearInt8DynamicActivationIntxWeightLayout, - ): - continue - if child.weight.tensor_impl.get_layout().target == Target.ATEN: - continue - setattr(module, name, quantized_linear_from_aqt(child.weight, child.bias)) - - -class SharedEmbeddingQuantizer: - def __init__( - self, - weight_dtype: torch.dtype = torch.int4, - granularity: Granularity = PerAxis(0), - mapping_type: MappingType = MappingType.ASYMMETRIC, - ): - self.weight_dtype = weight_dtype - self.granularity = granularity - self.mapping_type = mapping_type - - def quantize( - self, - model: nn.Module, - embedding_to_unembedding: Optional[Mapping[str, str]] = None, - ): - embedding_fqns = get_fqns_with_filter( - model, lambda m, fqn: isinstance(m, nn.Embedding) - ) - linear_fqns = get_fqns_with_filter( - model, lambda m, fqn: isinstance(m, nn.Linear) - ) - state_dict = model.state_dict() - - # If embedding_to_unembedding is not provided, automatically detect shared embeddings and unembeddings - if embedding_to_unembedding is None: - embedding_to_unembedding = {} - for embedding_fqn in embedding_fqns: - embedding_w = state_dict[embedding_fqn + ".weight"] - for linear_fqn in linear_fqns: - linear_w = state_dict[linear_fqn + ".weight"] - if embedding_w.shape == linear_w.shape and torch.allclose( - embedding_w, linear_w - ): - print( - f"Found shared embedding {embedding_fqn} and unembedding {linear_fqn}" - ) - if embedding_fqn not in embedding_to_unembedding: - embedding_to_unembedding[embedding_fqn] = linear_fqn - else: - raise ValueError( - f"Found multiple candidate unembeddings ({embedding_to_unembedding[embedding_fqn]}, {linear_fqn}) for embedding {embedding_fqn}. This is not supported yet. Please explicitly define the input embedding_to_unembedding." - ) - - # Construct reverse mapping - unembedding_to_embedding = {} - for v, k in embedding_to_unembedding.items(): - if k not in unembedding_to_embedding: - unembedding_to_embedding[k] = v - else: - raise ValueError( - f"Found multiple candidate embeddings ({unembedding_to_embedding[k]}, {v}) for unembedding {k}. This is not supported yet." - ) - - # Check that embeddings are shared, embeddings are embeddings, and unembeddings are linear ops - for embedding_fqn, unembedding_fqn in embedding_to_unembedding.items(): - assert embedding_fqn in embedding_fqns, ( - f"Embedding {embedding_fqn} is not found in model" - ) - assert unembedding_fqn in linear_fqns, ( - f"Unembedding {unembedding_fqn} is not found in model" - ) - assert torch.allclose( - state_dict[embedding_fqn + ".weight"], - state_dict[unembedding_fqn + ".weight"], - ), ( - f"Embedding {embedding_fqn} does not share weights with unembedding {unembedding_fqn}" - ) - - # Quantize unembeddings - quantize_( - model, - Int8DynamicActivationIntxWeightConfig_NonExperimental( - weight_dtype=self.weight_dtype, - weight_granularity=self.granularity, - weight_mapping_type=self.mapping_type, - # Only universal layout is supported for shared embedding - layout=PackedLinearInt8DynamicActivationIntxWeightLayout( - target="universal" - ), - ), - filter_fn=lambda m, fqn: isinstance(m, nn.Linear) - and fqn in list(embedding_to_unembedding.values()), - ) - - embedding_fqn_to_quantized_unembedding = {} - for fqn, t in model.state_dict().items(): - if ( - fqn.endswith(".weight") - and fqn[: -len(".weight")] in unembedding_to_embedding - ): - embedding_fqn = unembedding_to_embedding[fqn[: -len(".weight")]] - embedding_fqn_to_quantized_unembedding[embedding_fqn] = t - - _replace_embedding_with_quantized_embedding( - model, - kwargs={ - "embedding_fqn_to_quantized_unembedding": embedding_fqn_to_quantized_unembedding, - }, - ) - - # Remove subclasses. Otherwise there are two packed_weight objects in exported model, - # even though they have the same id in eager mode - replace_linear_tensor_subclass_with_module(model) - def _quantize( vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, signed=True diff --git a/torchao/experimental/quant_passes.py b/torchao/experimental/quant_passes.py deleted file mode 100644 index a7189d792b..0000000000 --- a/torchao/experimental/quant_passes.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import itertools -from collections import defaultdict -from typing import Callable, Optional - -import torch - -# import this for pt2e_quant.dequantize_affine op definition -# should be removed after removing dep on `torch._export.passes.constant_folding` -import torch.ao.quantization.pt2e._affine_quantization # noqa: F401 - -# TODO: remove dependency on ConstantFolder -from torch._export.passes.constant_folding import ( - ConstantFolder, - replace_node_with_constant, -) -from torch.fx import subgraph_rewriter - - -def constant_fold( - gm: torch.fx.GraphModule, - constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, - skip_constructors: bool = False, -): - with torch.utils._python_dispatch._disable_current_modes(): - # The ConstantFolder has a bug where it throws if dequantize_affine is not defined - # TODO: fix upstream - try: - getattr(torch.ops.torchao, "dequantize_affine") - except AttributeError: - setattr(torch.ops.torchao, "dequantize_affine", None) - - cf = ConstantFolder(gm, skip_constructors) - cf.run() - - for node, constant in cf.node_replacements.items(): - if constraint_fn is not None and not constraint_fn(node): - continue - replace_node_with_constant(gm, node, constant) - - erased_params = [] - # Get all attr users by looking up the graph instead from node.users, because in this case - # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor. - - # opcode name target args kwargs - # ------------- ------------------- ---------------- --------------------------- -------- - # placeholder arg0_1 arg0 () {} - # get_attr _tensor_constant0 state () {} - # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {} - # get_attr _tensor_constant0_1 state () {} - # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {} - # output output output ([add],) {} - - get_attr_node_users = defaultdict(list) - for node in gm.graph.nodes: - if node.op == "get_attr": - get_attr_node_users[node.target].extend(node.users.keys()) - for node in gm.graph.find_nodes(op="get_attr"): - if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0: - if hasattr(gm, node.target): - delattr(gm, node.target) - erased_params.append(node) - for node in erased_params: - gm.graph.erase_node(node) - - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - - -def _get_q_dq_linear_patterns_replacements_and_filters( - weight_bit_width, has_weight_zeros, target -): - glbs = globals() - glbs["weight_bit_width"] = weight_bit_width - glbs["target"] = target - glbs["w_quant_min"] = -(1 << (weight_bit_width - 1)) - glbs["w_quant_max"] = (1 << (weight_bit_width - 1)) - 1 - glbs["a_target_dtype"] = torch.int8 - glbs["a_quant_min"] = None - glbs["a_quant_max"] = None - glbs["a_mapping_type"] = "ASYMMETRIC" - glbs["a_scale_dtype"] = torch.float32 - glbs["a_eps"] = torch.finfo(torch.float32).eps - - lcls = {} - - pattern_str = """ -def pattern( - a, a_block_size, a_zero_point_dtype, - w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype, - bias): - a_scale, a_zero_point = torch.ops.torchao.choose_qparams_affine.default( - a, - a_mapping_type, - a_block_size, - a_target_dtype, - a_quant_min, - a_quant_max, - a_eps, - a_scale_dtype, - a_zero_point_dtype, - ) - a_int_data = torch.ops.torchao.quantize_affine.default( - a, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max, - ) - dq_a = torch.ops.torchao.dequantize_affine.default( - a_int_data, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max - ) - dq_w = torch.ops.torchao.dequantize_affine.default( - w_int_data, - w_block_size, - w_scale, - w_zero_point, - w_target_dtype, - w_quant_min, - w_quant_max, - ) - return torch.ops.aten.linear.default(dq_a, dq_w, bias) -""" - exec(pattern_str, glbs, lcls) - pattern = lcls["pattern"] - - replacement_str = f""" -def replacement( - a, a_block_size, a_zero_point_dtype, - w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype, - bias,): - n = w_int_data.size(0) - k = a_block_size[-1] - group_size = w_block_size[-1] - out_shape = a.shape[:-1] + (n,) - packed_weight = getattr( - torch.ops.torchao, - f"_pack_8bit_act_{weight_bit_width}bit_weight", - )( - w_int_data.to(torch.int8), - w_scale.reshape(-1), - {"w_zero_point.reshape(-1).to(torch.int8)" if has_weight_zeros else "None"}, - group_size, - bias, - target, - ) - return getattr( - torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight" - )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape) -""" - - exec(replacement_str, glbs, lcls) - replacement = lcls["replacement"] - - def match_filter(match, x, y): - def get_val(name): - node = [n for n in match.nodes_map if n.name == name][0] - return match.nodes_map[node] - - int_types = [torch.int8, torch.int16, torch.int32, torch.int64] - - a_zero_point_dtype = get_val("a_zero_point_dtype") - if a_zero_point_dtype not in int_types: - return False - - # We only want a_block_size with shape [1, ..., 1, k] - a_block_size = get_val("a_block_size") - for d in a_block_size[0:-1]: - if d != 1: - print("a_block_size not [1, ..., 1, k]") - return False - - # We only want w_block_size with shape [1, group_size] - w_block_size = get_val("w_block_size") - if len(w_block_size) != 2 or w_block_size[0] != 1: - return False - - return True - - return pattern, replacement, match_filter - - -def replace_q_dq_patterns_with_quantized_linear_ops_pass( - ep: torch.export.ExportedProgram, - target=None, -) -> torch.export.ExportedProgram: - """ - This replaces Q/DQ patterns with torchao quantized linear ops. - It is intended for converting Q/DQ nodes exported with QDQLayout to using - the lowbit quantized linear ops. - """ - # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) - # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ - assert len(ep.range_constraints) == 0, ( - "ExportedProgram with range constraints are not supported" - ) - - # ep.module() unlifts the weight inputs, which we need for constant folding - gm = ep.module() - for weight_bit_width, has_weight_zeros in itertools.product( - range(1, 9), [True, False] - ): - pattern, replacement, match_filter = ( - _get_q_dq_linear_patterns_replacements_and_filters( - weight_bit_width, has_weight_zeros, target - ) - ) - subgraph_rewriter.replace_pattern_with_filters( - gm, pattern, replacement, match_filters=[match_filter] - ) - - # Constant fold evaluates and removes the packing ops - constant_fold(gm) - - # Re-export - return torch.export.export(gm, *ep.example_inputs) - - -def _get_q_dq_embedding_patterns_replacements_and_filters( - weight_bit_width, -): - w_quant_min = -(1 << (weight_bit_width - 1)) - w_quant_max = (1 << (weight_bit_width - 1)) - 1 - w_target_dtype = torch.int8 - - def pattern( - indices, - w_int_data, - w_block_size, - w_scale, - w_zero_point, - ): - dq_w = torch.ops.torchao.dequantize_affine.default( - w_int_data, - w_block_size, - w_scale, - w_zero_point, - w_target_dtype, - w_quant_min, - w_quant_max, - ) - return torch.ops.aten.embedding.default(dq_w, indices) - - def replacement( - indices, - w_int_data, - w_block_size, - w_scale, - w_zero_point, - ): - num_embeddings, embedding_dim = w_int_data.size() - packed_weight_qvals = getattr( - torch.ops.torchao, f"_pack_embedding_{weight_bit_width}bit" - )(w_int_data) - out_shape = indices.shape + (embedding_dim,) - group_size = w_block_size[-1] - n_groups = embedding_dim // group_size - w_scale = w_scale.reshape(-1, n_groups) - w_zero_point = w_zero_point.reshape(-1, n_groups) - return getattr(torch.ops.torchao, f"_embedding_{weight_bit_width}bit")( - packed_weight_qvals, - num_embeddings, - embedding_dim, - w_scale, - w_zero_point, - indices.reshape(-1), - ).reshape(out_shape) - - def match_filter(match, x, y): - def get_val(name): - node = [n for n in match.nodes_map if n.name == name][0] - return match.nodes_map[node] - - # We only want w_block_size with shape [1, group_size] - w_block_size = get_val("w_block_size") - if len(w_block_size) != 2 or w_block_size[0] != 1: - return False - - return True - - return pattern, replacement, match_filter - - -def replace_q_dq_patterns_with_quantized_embedding_ops_pass( - ep: torch.export.ExportedProgram, -) -> torch.export.ExportedProgram: - """ - This replaces Q/DQ patterns with torchao quantized embedding ops. - It is intended for converting Q/DQ nodes exported with QDQLayout to using - the lowbit quantized embedding ops. - """ - # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export) - # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/ - assert len(ep.range_constraints) == 0, ( - "ExportedProgram with range constraints are not supported" - ) - - # ep.module() unlifts the weight inputs, which we need for constant folding - gm = ep.module() - for weight_bit_width in range(1, 9): - pattern, replacement, match_filter = ( - _get_q_dq_embedding_patterns_replacements_and_filters( - weight_bit_width, - ) - ) - subgraph_rewriter.replace_pattern_with_filters( - gm, pattern, replacement, match_filters=[match_filter] - ) - - # Constant fold evaluates and removes the packing ops - constant_fold(gm) - - # Re-export - return torch.export.export(gm, *ep.example_inputs) diff --git a/torchao/experimental/temp_build.py b/torchao/experimental/temp_build.py deleted file mode 100644 index 3195e24581..0000000000 --- a/torchao/experimental/temp_build.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import glob -import subprocess -import tempfile - -import torch - - -def cmake_build_torchao_ops(cmake_lists_path, temp_build_dir): - from distutils.sysconfig import get_python_lib - - print("Building torchao ops for ATen target") - cmake_prefix_path = get_python_lib() - subprocess.run( - [ - "cmake", - "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, - "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name, - "-S " + cmake_lists_path, - "-B " + temp_build_dir.name, - ] - ) - subprocess.run( - [ - "cmake", - "--build", - temp_build_dir.name, - "-j 16", - "--target install", - "--config Release", - ] - ) - - -def temp_build_and_load_torchao_ops(cmake_lists_path): - temp_build_dir = tempfile.TemporaryDirectory() - cmake_build_torchao_ops(cmake_lists_path, temp_build_dir) - libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*") - libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) - assert len(libs) == 1 - torch.ops.load_library(libs[0]) - print(f"TorchAO ops are loaded from {libs[0]}") diff --git a/torchao/experimental/tests/test_load_libtorchao_ops.py b/torchao/experimental/tests/test_load_libtorchao_ops.py deleted file mode 100644 index 4fec52f494..0000000000 --- a/torchao/experimental/tests/test_load_libtorchao_ops.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from pathlib import Path -from unittest.mock import MagicMock, patch - - -class TestLibTorchAoOpsLoader(unittest.TestCase): - def test_find_and_load_success(self): - mock_paths = [Path("/test/path1")] - mock_lib = MagicMock() - mock_lib.__str__.return_value = "/test/path1/libtorchao_ops_aten.so" - - with patch("pathlib.Path.glob", return_value=[mock_lib]): - with patch("torch.ops.load_library") as mock_load: - from ..op_lib import find_and_load_libtorchao_ops - - find_and_load_libtorchao_ops(mock_paths) - - mock_load.assert_called_once_with("/test/path1/libtorchao_ops_aten.so") - - def test_no_library_found(self): - mock_paths = [Path("/test/path1"), Path("/test/path2")] - - with patch("pathlib.Path.glob", return_value=[]): - from ..op_lib import find_and_load_libtorchao_ops - - with self.assertRaises(FileNotFoundError): - find_and_load_libtorchao_ops(mock_paths) - - def test_multiple_libraries_error(self): - mock_paths = [Path("/test/path1")] - mock_lib1 = MagicMock() - mock_lib2 = MagicMock() - mock_libs = [mock_lib1, mock_lib2] - - with patch("pathlib.Path.glob", return_value=mock_libs): - from ..op_lib import find_and_load_libtorchao_ops - - try: - find_and_load_libtorchao_ops(mock_paths) - self.fail("Expected AssertionError was not raised") - except AssertionError as e: - expected_error_msg = f"Expected to find one libtorchao_ops_aten.* library at {mock_paths[0]}, but found 2" - self.assertIn(expected_error_msg, str(e)) - - -if __name__ == "__main__": - unittest.main() diff --git a/torchao/experimental/tests/test_quant_passes.py b/torchao/experimental/tests/test_quant_passes.py deleted file mode 100644 index b133e1ee01..0000000000 --- a/torchao/experimental/tests/test_quant_passes.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import itertools -import unittest - -import torch -from parameterized import param, parameterized -from torch.testing import FileCheck - -from torchao.dtypes import QDQLayout -from torchao.experimental.quant_passes import ( - replace_q_dq_patterns_with_quantized_embedding_ops_pass, - replace_q_dq_patterns_with_quantized_linear_ops_pass, -) -from torchao.quantization.granularity import PerAxis, PerGroup -from torchao.quantization.quant_api import ( - Int8DynamicActivationIntxWeightConfig, - IntxWeightOnlyConfig, - MappingType, - quantize_, -) - - -class TestQuantPasses(unittest.TestCase): - def test_replace_q_dq_patterns_with_quantized_linear_ops_pass(self): - layers = [] - layer_to_weight_dtype = {} - layer_to_weight_mapping_type = {} - layer_to_weight_granularity = {} - for ( - weight_dtype, - weight_mapping_type, - weight_granularity, - has_bias, - ) in itertools.product( - [getattr(torch, f"int{i}") for i in range(1, 9)], - [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], - [PerAxis(0), PerGroup(32)], - [True, False], - ): - idx = len(layers) - layer_to_weight_dtype[idx] = weight_dtype - layer_to_weight_mapping_type[idx] = weight_mapping_type - layer_to_weight_granularity[idx] = weight_granularity - layers.append(torch.nn.Linear(64, 64, bias=has_bias)) - - activations = torch.randn(2, 1, 64, dtype=torch.float32) - model = torch.nn.Sequential(*layers) - for idx in range(len(layers)): - quantize_( - model, - Int8DynamicActivationIntxWeightConfig( - weight_dtype=layer_to_weight_dtype[idx], - weight_mapping_type=layer_to_weight_mapping_type[idx], - weight_granularity=layer_to_weight_granularity[idx], - layout=QDQLayout(), - ), - lambda m, fqn: fqn == str(idx), - ) - - eager_results = model(activations) - exported = torch.export.export(model, (activations,), strict=True) - exported = replace_q_dq_patterns_with_quantized_linear_ops_pass( - exported, target="universal" - ) - - # We should not find pack op because it gets constant folded - FileCheck().check_not("torch.ops.torchao._pack_8bit_act").run( - exported.graph_module.code - ) - - # We should find len(layers) torchao linear ops - FileCheck().check_count( - "torch.ops.torchao._linear_8bit_act_", count=len(layers), exactly=True - ).run(exported.graph_module.code) - - # We should not find Q/DQ ops - FileCheck().check_not("torch.ops.torchao.quantize_affine.default").run( - exported.graph_module.code - ) - FileCheck().check_not("torch.ops.torchao.dequantize_affine.default").run( - exported.graph_module.code - ) - FileCheck().check_not("torch.ops.torchao.choose_qparams_affine.default").run( - exported.graph_module.code - ) - - # Numerics should match - exported_results = exported.module()(activations) - self.assertTrue(torch.allclose(exported_results, eager_results)) - - @parameterized.expand( - [ - param(weight_dtype=weight_dtype, granularity=granularity) - for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)] - for granularity in [PerAxis(0), PerGroup(32)] - ], - name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}", - ) - def test_replace_q_dq_patterns_with_quantized_embedding_ops_pass( - self, weight_dtype, granularity - ): - # Calling torch.export many times in a parametrized test causes - # torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached error - # Setting cache_size_limit to a large number to avoid this error - torch._dynamo.config.cache_size_limit = 10000 - - mapping_type = MappingType.ASYMMETRIC - - model = torch.nn.Sequential( - *[torch.nn.Embedding(5000, 512), torch.nn.Linear(512, 512)] - ) - indices = torch.randint(0, 5000, (4, 5, 17), dtype=torch.int32) - - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=weight_dtype, - granularity=granularity, - mapping_type=mapping_type, - layout=QDQLayout(), - ), - lambda m, fqn: isinstance(m, torch.nn.Embedding), - ) - eager_results = model(indices) - - exported = torch.export.export(model, (indices,), strict=True) - exported = replace_q_dq_patterns_with_quantized_embedding_ops_pass(exported) - - # We should not find pack op because it gets constant folded - FileCheck().check_not("torch.ops.torchao._pack_embedding").run( - exported.graph_module.code - ) - - # We should find - FileCheck().check_count( - "torch.ops.torchao._embedding", count=1, exactly=True - ).run(exported.graph_module.code) - - # We should not find Q/DQ ops - FileCheck().check_not("torch.ops.torchao.dequantize_affine.default").run( - exported.graph_module.code - ) - - # Numerics should match - exported_results = exported.module()(indices) - self.assertTrue(torch.allclose(exported_results, eager_results)) - - -if __name__ == "__main__": - unittest.main() diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 8533a05779..9747070ac6 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -6,26 +6,68 @@ and up to [**1.25x at 8 GPU / 8B parameter count scale**](#training-benchmarks). The codebase strives to stay small, hackable, debuggable with native PyTorch tooling and composable with key systems such as autograd, ```torch.compile``` and distributed. +## Key features + +* e2e pretraining speedups of up to [**1.5x at 512 GPU / 405B parameter count scale**](https://pytorch.org/blog/training-using-float8-fsdp2/), +and up to [**1.25x at 8 GPU / 8B parameter count scale**](#training-benchmarks), with performance and accuracy validated on up to [**2k GPUs**](https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/), via [torchtitan's float8 integration](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md) +* seamless composability with [torch.compile](https://docs.pytorch.org/docs/stable/torch.compiler.html), [DTensor](https://docs.pytorch.org/docs/stable/distributed.tensor.html), [FSDP2 with float8 weight all-gather](https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359), [Async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487), and [PyTorch AC](https://pytorch.org/blog/activation-checkpointing-techniques/) +* three recipes to trade off performance vs accuracy: `tensorwise` (fastest), `rowwise`, `rowwise_with_gw_hp` (most accurate) +* supports both NVIDIA and AMD hardware + ℹ️ See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features. -ℹ️ These APIs are training-only and float8-only, and we plan to [unify them with the rest of torchao](https://github.com/pytorch/ao/issues/894) in the future. +# e2e training benchmarks -# Single GPU User API +[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance. + +#### NVIDIA H100 + +- Single-node training on 8xH100 GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC +- pytorch version: `2.7.0a0+gitb98af95`, torchao version: `0.10.0+git890e0ac8`, torchtitan version: `0.0.2` + +| Model | Scaling | Peak Memory (GB) | Median tokens/second | Speedup over baseline +| ------------- | ---------------------------------- | ------------------| -------------------- | --------------------- +| Llama3-8b | none (bfloat16) | 47.65 | 6150 | - +| Llama3-8b | tensorwise with float8 all-gather | 47.77 | 7689.5 | 25.03% +| Llama3-8b | rowwise with bfloat16 all-gather | 47.79 | 6768 | 10.05% + +#### AMD MI300x -## float8 linear with dynamic tensorwise scaling +- Single-node training on 8xMI300X GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC +- pytorch version: `2.9.0.dev20250811+rocm6.4`, torchao version `0.13.0+git4fc4068d6`, torchtitan commit `2c8b5947991239913d67e2f7d22a255c3e2a9694` -This is the default recipe, with a good balance of performance and accuracy. +| Model | Scaling | Peak Memory (GB) | Median tokens/second | Speedup over baseline +| ------------- | ---------------------------------- | ------------------| -------------------- | --------------------- +| Llama3-8b | none (bfloat16) | 39.09 | 5376.5 | - +| Llama3-8b | tensorwise with float8 all-gather | 39.07 | 6166.0 | 14.68% +| Llama3-8b | rowwise_with_gw_hp with bfloat16 all-gather | 39.32 | 6100.0 | 13.46% +| Llama3-8b | rowwise with bfloat16 all-gather | 39.32 | 5891.0 | 9.57% + +**Important notes**: +- E2E speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)). +- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve. + +**Reproducing training benchmarks** +To reproduce these benchmarks, you can follow these steps: + +1. On a machine with compatible GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation), +including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer). +2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation). +3. From the `torchao/` directory, you can run the following commands to reproduce the benchmarks above: + - bf16 + compile: `TORCHTITAN_ROOT= ./benchmarks/float8/training/llama3.sh` + - float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./benchmarks/float8/training/llama3.sh` + - float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./benchmarks/float8/training/llama3.sh` + +See the float8 training benchmarking [guide](.torchao/benchmarks/float8/training/README.md) for more details. + +# Single GPU User API ```python import time import torch import torch.nn as nn -from torchao.float8 import convert_to_float8_training -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") +from torchao.float8 import convert_to_float8_training, Float8LinearConfig # create model and sample input M, K, N = 4096, 8192, 4096 @@ -47,8 +89,15 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str): return False return True +# configure float8 recipe +# valid recipe names: "tensorwise", "rowwise", "rowwise_with_gw_hp" +config = Float8LinearConfig.from_recipe_name("tensorwise") + # convert specified `torch.nn.Linear` modules to `Float8Linear` -convert_to_float8_training(m, module_filter_fn=module_filter_fn) +convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn) + +# display converted model +print(m) # enable torch.compile for competitive performance m = torch.compile(m) @@ -75,55 +124,6 @@ end_time = time.time() print("Training time:", end_time - start_time) ``` -## float8 linear with rowwise scaling - -This is a more accurate recipe compared to tensorwise, with more granular scaling. - -```python -import torch -import torch.nn as nn -from torchao.float8 import convert_to_float8_training, Float8LinearConfig -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") - -# create model and sample input -m = nn.Sequential( - nn.Linear(2048, 4096), - nn.Linear(4096, 128), -).bfloat16().cuda() -x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) -optimizer = torch.optim.SGD(m.parameters(), lr=0.1) - -# optional: filter modules from being eligible for float8 conversion -def module_filter_fn(mod: torch.nn.Module, fqn: str): - # don't convert the last module - if fqn == "1": - return False - # don't convert linear modules with weight dimensions not divisible by 16 - if isinstance(mod, torch.nn.Linear): - if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: - return False - return True - -# configure rowwise scaling -config = Float8LinearConfig.from_recipe_name("rowwise") - -# convert specified `torch.nn.Linear` modules to `Float8Linear` -convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn) - -# enable torch.compile for competitive performance -m = torch.compile(m) - -# toy training loop -for _ in range(10): - optimizer.zero_grad() - y = m(x) - y.sum().backward() - optimizer.step() -``` - # Multi GPU User API We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html), @@ -134,30 +134,32 @@ on using `torchao.float8` in a distributed setting. A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the tables below for a microbenchmark based speedup estimate on NVIDIA H100: -### Tensorwise scaling +### tensorwise scaling -float8_speedup +Image -Example 1 (small shapes): -* forward input tensor size 1024x2048, linear weight size 2048x1024; M, K, N = 1024, 2048, 1024 -* benchmark speedup is 0.80 -* recommendation: leave this linear in bfloat16, the shapes are too small to benefit from float8 compute +```lang=shell +# reproduction: run the script below +python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep +``` -Example 2 (large shapes): -* forward input tensor size 4096x8192, linear weight size 8192x16384; M, K, N = 4096, 8192, 16384 -* benchmark speedup is 1.39 -* recommendation: enable float8 for this linear to get a speedup +### rowwise scaling -To reproduce the raw data for table above, you can run the following script +Image ```lang=shell -python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep +# reproduction: run the script below +python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep --float8_recipe_name rowwise ``` -### Rowwise scaling +### rowwise_with_gw_hp scaling -float8_rowwise_speedup +Image +```lang=shell +# reproduction: run the script below +python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep --float8_recipe_name rowwise_with_gw_hp +``` ## Derivation @@ -205,56 +207,6 @@ python test/float8/test_fsdp2/test_fsdp2.py ./test/float8/test_everything.sh ``` -# Benchmarking - -```bash -# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes -./benchmarks/float8/bench_matmul.py - -# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes -# make sure to turn on torch.compile to get the best performance -./benchmarks/float8/bench_linear_float8.py -o ../tmp/test.txt --compile -``` - -### Training benchmarks - -[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance, for both rowwise -and tensorwise scaling. The training benchmarks were all run using: - -- Single-node training on 8xH100 GPUs -- Batch size 1 -- Sequence length 8192 -- Steps 100 -- `torch.compile` -- FSDP2 -- pytorch version: `2.7.0a0+gitb98af95` -- torchao version: `0.10.0+git890e0ac8` -- torchtitan version: `0.0.2` - - -| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline -| ------------- | ---------------------------------- | ------------------------ | ------------------| -------------------- | --------------------- -| Llama3-8b | none (bfloat16) | per op SAC | 47.65 | 6150 | - -| Llama3-8b | tensorwise with float8 all-gather | per op SAC | 47.77 | 7689.5 | 25.03% -| Llama3-8b | rowwise with bfloat16 all-gather | per op SAC | 47.79 | 6768 | 10.05% - -**Important notes**: -- E2E speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)). -- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve. - -**Reproducing training benchmarks** -To reproduce these benchmarks, you can follow these steps: - -1. On a machine with 8 H100 GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation), -including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer). -2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation). -3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above: - - bf16 + compile: `TORCHTITAN_ROOT= ./float8_training_benchmark.sh` - - float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh` - - float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT= FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh` - -See the float8 training benchmarking [guide](.torchao/float8/benchmarking/README.md) for more details. - # E2E training + inference flow The first step in the E2E is to train your model and save a checkpoint. The second step is to load the checkpoint and optionally apply inference quantization before serving the model. @@ -267,10 +219,6 @@ import torch.nn.functional as F from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_linear import Float8Linear from torchao.float8 import convert_to_float8_training -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = nn.Sequential( diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 4f90292918..04589312a2 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -1,4 +1,7 @@ # Lets define a few top level things here +# Needed to load Float8TrainingTensor with weights_only = True +from torch.serialization import add_safe_globals + from torchao.float8.config import ( CastConfig, Float8GemmConfig, @@ -10,8 +13,8 @@ _auto_filter_for_recipe, convert_to_float8_training, ) -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, ScaledMMConfig, @@ -19,22 +22,17 @@ from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from torchao.float8.inference import Float8MMConfig from torchao.float8.types import FP8Granularity -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if TORCH_VERSION_AT_LEAST_2_5: - # Needed to load Float8Tensor with weights_only = True - from torch.serialization import add_safe_globals - add_safe_globals( - [ - Float8Tensor, - ScaledMMConfig, - GemmInputRole, - LinearMMConfig, - Float8MMConfig, - ScalingGranularity, - ] - ) +add_safe_globals( + [ + Float8TrainingTensor, + ScaledMMConfig, + GemmInputRole, + LinearMMConfig, + Float8MMConfig, + ScalingGranularity, + ] +) __all__ = [ # configuration @@ -50,5 +48,5 @@ "_auto_filter_for_recipe", # types "FP8Granularity", - # note: Float8Tensor and Float8Linear are not public APIs + # note: Float8TrainingTensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 939f68e59a..b362390946 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -333,6 +333,7 @@ def from_recipe_name( cast_config_input_for_grad_weight=cc_i_gw, cast_config_weight_for_grad_input=cc_w_gi, cast_config_grad_output_for_grad_weight=cc_go_gw, + round_scales_to_power_of_2=True, ) else: diff --git a/torchao/float8/distributed_utils.py b/torchao/float8/distributed_utils.py index cd1560fabd..a278640af8 100644 --- a/torchao/float8/distributed_utils.py +++ b/torchao/float8/distributed_utils.py @@ -8,7 +8,7 @@ import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import DTensor -from torchao.float8.float8_tensor import Float8Tensor +from torchao.float8.float8_training_tensor import Float8TrainingTensor def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: @@ -16,7 +16,7 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: Check if the tensor is already casted to fp8, works if the local tensor is wrapped in DTensor. """ - if isinstance(tensor, Float8Tensor): + if isinstance(tensor, Float8TrainingTensor): return True elif isinstance(tensor, DTensor): # TODO: shall we stick to public API and directly use tensor.to_local() here? diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index fbafc1a393..a946835a4d 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -11,51 +11,20 @@ import torch -from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType +from torchao.float8.config import Float8LinearConfig, ScalingType from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( get_maybe_axiswise_dim, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( +from torchao.float8.float8_training_tensor import ( GemmInputRole, LinearMMConfig, ScaledMMConfig, - hp_tensor_and_scale_to_float8, ) -from torchao.float8.float8_utils import tensor_to_scale from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -def _get_weight_scale( - weight: torch.Tensor, - scaling_type_weight: ScalingType, - config: Float8LinearConfig, -) -> Optional[torch.Tensor]: - if tensor_already_casted_to_fp8(weight): - return None - assert scaling_type_weight is ScalingType.DYNAMIC - return tensor_to_scale(weight, config.cast_config_weight.target_dtype) - - -def _cast_weight_to_float8_t( - weight: torch.Tensor, - config: Float8LinearConfig, - linear_mm_config: LinearMMConfig, - weight_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - if tensor_already_casted_to_fp8(weight): - return weight.t() - weight_fp8 = hp_tensor_and_scale_to_float8( - weight, - weight_scale, - config.cast_config_weight.target_dtype, - linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ) - return weight_fp8.t() - - @torch._dynamo.allow_in_graph class matmul_with_hp_or_float8_args(torch.autograd.Function): """ @@ -159,21 +128,6 @@ def backward(ctx, grad_output): elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED: weight_t_maybe_fp8_dim0 = weight_hp_t else: - if ( - c.cast_config_weight_for_grad_input.scaling_granularity - is ScalingGranularity.AXISWISE - ): - # workaround from https://github.com/pytorch/pytorch/issues/141881 - # to avoid saving float8 weight from forward to backward when - # FSDP is on: add a fake dependency on `grad_output`. - g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0 - zero = g_reshaped[:1] * 0 - weight_hp_t = weight_hp_t + zero - - # Note: we need https://github.com/pytorch/pytorch/issues/136267 - # to be solved to have a chance to reuse max(abs(weight, dim=...)) - # from the forward to get max(abs(weight)) here without reading - # the entire tensor. weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, c.cast_config_weight_for_grad_input.target_dtype, @@ -307,39 +261,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: autocast_dtype = torch.get_autocast_gpu_dtype() input = input.to(autocast_dtype) - has_any_axiswise_scaling = any( - cc.scaling_granularity is ScalingGranularity.AXISWISE - for cc in [ - self.config.cast_config_input, - self.config.cast_config_weight, - self.config.cast_config_grad_output, - self.config.cast_config_input_for_grad_weight, - self.config.cast_config_weight_for_grad_input, - self.config.cast_config_grad_output_for_grad_weight, - ] - ) - - weight_maybe_fp8_t = self.weight.t() - - # TODO(future PR): check for axiswise scaling for input, weight, - # grad_output separately instead of together - if not has_any_axiswise_scaling: - # TODO(future PR): now that `force_recompute_fp8_weight_in_bwd` is - # deprecated, we can simplify the below code and unify the per-tensor - # and per-axis paths further. - weight_scale = _get_weight_scale( - self.weight, self.scaling_type_weight, self.config - ) - weight_maybe_fp8_t = _cast_weight_to_float8_t( - self.weight, - self.config, - self.linear_mm_config, - weight_scale, - ) - output = matmul_with_hp_or_float8_args.apply( input, - weight_maybe_fp8_t, + self.weight.t(), self.linear_mm_config, self.config, ) diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 0d9674e6c3..e0def790b8 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -7,6 +7,7 @@ from functools import partial from typing import Callable, List, Optional, Union +import torch import torch.nn as nn from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName @@ -101,6 +102,7 @@ def convert_to_float8_training( Returns: nn.Module: The modified module with swapped linear layers. """ + torch._C._log_api_usage_once("torchao.float8.convert_to_float8_training") if config is None: config = Float8LinearConfig() diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 4071d83e4f..3f244ddadf 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -8,7 +8,10 @@ import torch from torch.utils._pytree import tree_map -from torchao.float8.float8_tensor import Float8Tensor, choose_scaled_mm_config +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, + choose_scaled_mm_config, +) from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul aten = torch.ops.aten @@ -18,7 +21,7 @@ # [Note] Usage of scales -# The meaning of scale in this library can be found in the definition of the Float8Tensor +# The meaning of scale in this library can be found in the definition of the Float8TrainingTensor # Cublas defines scale to always mean a multiplicative factor for the respective matrices # For a,b going from fp8 -> fp32 we multiple by the inverse of the scale # For output going from fp32 -> fp8 we multiply by the scale @@ -33,7 +36,7 @@ def addmm_float8_unwrapped( use_fast_accum: bool = False, ) -> torch.Tensor: """ - This is the unwrapped version of addmm_float8, which does not take in Float8Tensors + This is the unwrapped version of addmm_float8, which does not take in Float8TrainingTensors as inputs. This is used to standardize the logic between subclassed and non subclassed versions of the linear module. """ @@ -54,6 +57,12 @@ def addmm_float8_unwrapped( a_inverse_scale = a_inverse_scale.new_ones(()) b_inverse_scale = a_inverse_scale.new_ones(()) + # work around torch._scaled_mm not having float32 output type + # TODO(pytorch/pytorch#156771): remove this once torch._scaled_mm supports float32 output + orig_dtype = output_dtype + if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling: + output_dtype = torch.bfloat16 + post_bias = None if output_dtype == torch.float32: # Bias is not supported by _scaled_mm when output is fp32 @@ -76,6 +85,9 @@ def addmm_float8_unwrapped( if post_bias is not None: output += post_bias + if orig_dtype in (torch.float16, torch.float32) and is_rowwise_scaling: + output = output.to(orig_dtype) + return output @@ -115,7 +127,7 @@ def decorator(func): def float8_desugar_op(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) new_data = aten_op(args[0]._data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, args[0]._scale, args[0]._orig_dtype, @@ -132,7 +144,7 @@ def float8_desugar_op(aten_op, args, kwargs=None): def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, args[0]._orig_dtype, @@ -165,7 +177,7 @@ def float8_transpose(aten_op, args, kwargs=None): else: new_axiswise_dim == 0 - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, args[0]._orig_dtype, @@ -183,7 +195,7 @@ def float8_view(aten_op, args, kwargs=None): # note that we have to create a new wrapper to make PyTorch internals happy if new_shape == list(t._data.shape): new_data = aten_op(args[0]._data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, args[0]._scale, args[0]._orig_dtype, @@ -203,7 +215,7 @@ def float8_view(aten_op, args, kwargs=None): new_data = aten_op(t._data, new_shape, **kwargs) new_scale_shape = [1, new_shape[-1]] new_scale = aten_op(t._scale, new_scale_shape, **kwargs) - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, t._orig_dtype, @@ -216,7 +228,7 @@ def float8_view(aten_op, args, kwargs=None): new_scale_shape = [new_shape[0], 1] new_scale = aten_op(t._scale, new_scale_shape, **kwargs) new_axiswise_dim = -1 - return Float8Tensor( + return Float8TrainingTensor( new_data, new_scale, t._orig_dtype, @@ -236,7 +248,7 @@ def float8_split(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) def make_float8(data): - return Float8Tensor( + return Float8TrainingTensor( data, args[0]._scale, args[0]._orig_dtype, @@ -251,7 +263,7 @@ def make_float8(data): # Errors cant `cat_cuda float8 e4m3fn` @implements([aten.cat.default]) def float8_cat(aten_op, args, kwargs=None): - chunked_tensors: Tuple[Float8Tensor] = args[0] + chunked_tensors: Tuple[Float8TrainingTensor] = args[0] orig_dtype = chunked_tensors[0]._orig_dtype scale = chunked_tensors[0]._scale @@ -260,8 +272,8 @@ def float8_cat(aten_op, args, kwargs=None): gemm_input_role = chunked_tensors[0]._gemm_input_role chunk_data = [] for chunk in chunked_tensors: - assert isinstance(chunk, Float8Tensor), ( - "Expecting all chunks to be of type Float8Tensor" + assert isinstance(chunk, Float8TrainingTensor), ( + "Expecting all chunks to be of type Float8TrainingTensor" ) assert chunk._orig_dtype == orig_dtype, ( "Expecting all chunks to be of the same dtype" @@ -283,7 +295,7 @@ def float8_cat(aten_op, args, kwargs=None): new_data = aten_op(chunk_data, *args[1:], **kwargs) new_data = new_data.view(fp8_dtype) - return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) + return Float8TrainingTensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) @implements([aten.sum.dim_IntList]) @@ -298,7 +310,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) def unwrap(x): - if isinstance(x, Float8Tensor): + if isinstance(x, Float8TrainingTensor): return x.to_original_precision() return x @@ -307,7 +319,7 @@ def unwrap(x): return aten_op(*new_args, **new_kwargs) -def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): +def preprocess_addmm(a: Float8TrainingTensor, b: Float8TrainingTensor): a_data = a._data a_scale = a._scale b_data = b._data @@ -353,10 +365,10 @@ def float8_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] - assert isinstance(a, Float8Tensor) and isinstance(b, Float8Tensor), ( - "Expecting both Float8Tensor for mm inputs but found {} and {}".format( - type(a), type(b) - ) + assert isinstance(a, Float8TrainingTensor) and isinstance( + b, Float8TrainingTensor + ), "Expecting both Float8TrainingTensor for mm inputs but found {} and {}".format( + type(a), type(b) ) a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) output_dtype = a._orig_dtype @@ -387,8 +399,8 @@ def float8_mm(aten_op, args, kwargs=None): def float8_addmm(aten_op, args, kwargs=None): assert ( isinstance(args[0], torch.Tensor) - and isinstance(args[1], Float8Tensor) - and isinstance(args[2], Float8Tensor) + and isinstance(args[1], Float8TrainingTensor) + and isinstance(args[2], Float8TrainingTensor) ) bias = args[0] a = args[1] @@ -429,24 +441,24 @@ def float8_is_same_size(aten_op, args, kwargs=None): @implements([aten._to_copy.default]) def autocast_to_copy(aten_op, args, kwargs=None): """This gets called when running matmul under autocast - when the input is a Float8Tensor, presenting as a fp32 + when the input is a Float8TrainingTensor, presenting as a fp32 tensor. """ - _assert_tensorwise_scale(aten_op, args[0]._scale) - assert isinstance(args[0], Float8Tensor) + assert isinstance(args[0], Float8TrainingTensor) assert len(kwargs) == 1 and "dtype" in kwargs, ( "Only support dtype kwarg for autocast" ) assert kwargs["dtype"] in { torch.float16, torch.bfloat16, - }, "Only support floating point conversion for autocast w/ Float8Tensor" - return Float8Tensor( + }, "Only support floating point conversion for autocast w/ Float8TrainingTensor" + return Float8TrainingTensor( args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._linear_mm_config, args[0]._gemm_input_role, + args[0]._axiswise_dim, ) @@ -462,14 +474,14 @@ def allgather_fp8(aten_op, args, kwargs=None): """ _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] - assert isinstance(fp8_input, Float8Tensor), ( - f"expecting a Float8Tensor for allgather but found {type(fp8_input)}" + assert isinstance(fp8_input, Float8TrainingTensor), ( + f"expecting a Float8TrainingTensor for allgather but found {type(fp8_input)}" ) fp8_data = fp8_input._data fp8_data = fp8_data.contiguous() fp8_out = aten_op(fp8_data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, fp8_input._scale, fp8_input._orig_dtype, @@ -482,11 +494,11 @@ def allgather_fp8(aten_op, args, kwargs=None): def wait_tensor_fp8(aten_op, args, kwargs=None): _assert_tensorwise_scale(aten_op, args[0]._scale) fp8_input = args[0] - assert isinstance(fp8_input, Float8Tensor) + assert isinstance(fp8_input, Float8TrainingTensor) fp8_data = fp8_input._data fp8_out = aten_op(fp8_data, *args[1:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, fp8_input._scale, fp8_input._orig_dtype, @@ -499,8 +511,8 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): def index_put_fp8(aten_op, args, kwargs=None): fp8_self = args[0] fp8_values = args[2] - assert isinstance(fp8_self, Float8Tensor) - assert isinstance(fp8_values, Float8Tensor) + assert isinstance(fp8_self, Float8TrainingTensor) + assert isinstance(fp8_values, Float8TrainingTensor) _assert_tensorwise_scale(fp8_self, args[0]._scale) assert fp8_self._scale == fp8_values._scale assert fp8_self.dtype == fp8_values.dtype @@ -509,7 +521,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_data = fp8_self._data fp8_values_data = fp8_values._data fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, fp8_self._scale, fp8_self._orig_dtype, @@ -520,39 +532,43 @@ def index_put_fp8(aten_op, args, kwargs=None): @implements([aten.copy_.default]) def copy_fp8(aten_op, args, kwargs=None): - # For a copy op with Float8Tensors involved, only the following combinations are allowed: - # 1. self is a high precision (hp) tensor, src is a Float8Tensor: + # For a copy op with Float8TrainingTensors involved, only the following combinations are allowed: + # 1. self is a high precision (hp) tensor, src is a Float8TrainingTensor: # in this case src is upcasted and unscaled to go into the hp tensor - # 2. self and src are Float8Tensors: - # the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat) + # 2. self and src are Float8TrainingTensors: + # the copy is only allowed if all the Float8TrainingTensor properties are equal (a la torch.cat) # Every other combination is banned as the semantics are not well defined self = args[0] src = args[1] - if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + if not isinstance(self, Float8TrainingTensor) and isinstance( + src, Float8TrainingTensor + ): src_hp = src.to_original_precision() _assert_tensorwise_scale(aten_op, src._scale) return aten_op(self, src_hp, *args[2:], **kwargs) - elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + elif isinstance(self, Float8TrainingTensor) and isinstance( + src, Float8TrainingTensor + ): _assert_tensorwise_scale(aten_op, src._scale) assert self._orig_dtype == src._orig_dtype, ( - "Expecting both Float8Tensors to be of the same dtype" + "Expecting both Float8TrainingTensors to be of the same dtype" ) assert self._scale == src._scale, ( - "Expecting both Float8Tensors to have thee same scale" + "Expecting both Float8TrainingTensors to have thee same scale" ) assert self._linear_mm_config == src._linear_mm_config, ( - "Expecting both Float8Tensors to have thee same mm config" + "Expecting both Float8TrainingTensors to have thee same mm config" ) assert self._data.dtype == src._data.dtype, ( - "Expecting both Float8Tensors to be of the same dtypet" + "Expecting both Float8TrainingTensors to be of the same dtypet" ) assert self._gemm_input_role == src._gemm_input_role, ( - "Expecting both Float8Tensors to have the same gemm_input_role" + "Expecting both Float8TrainingTensors to have the same gemm_input_role" ) fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) - return Float8Tensor( + return Float8TrainingTensor( fp8_out, self._scale, self._orig_dtype, @@ -560,4 +576,4 @@ def copy_fp8(aten_op, args, kwargs=None): self._gemm_input_role, ) else: - raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") + raise RuntimeError("Unsupported semantics for copy_ in Float8TrainingTensor") diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 31f2db6b4e..5a9138a1e9 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -14,8 +14,8 @@ from torchao.float8.config import ScalingGranularity from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, hp_tensor_and_scale_to_float8, @@ -36,10 +36,10 @@ def hp_tensor_to_float8_dynamic( scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, round_scales_to_power_of_2: bool = False, -) -> Float8Tensor: +) -> Float8TrainingTensor: """ Given a high precision tensor `hp_tensor`, - scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result. + scales `hp_tensor` dynamically and returns a `Float8TrainingTensor` of the result. Args: hp_tensor: the tensor to convert diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index 36ae6d587e..175712c231 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -19,7 +19,7 @@ NoopFwToFloat8BwDynamic, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import GemmInputRole +from torchao.float8.float8_training_tensor import GemmInputRole # subclass the ColwiseParallel and RowwiseParallel classes # to add the float8 support @@ -62,7 +62,7 @@ def _prepare_input_fn( mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - ) # DTensor(Float8Tensor) + ) # DTensor(Float8TrainingTensor) # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: @@ -79,7 +79,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me placements=output_layouts, async_op=True ) # DTensor(torch.Tensor) - # fwd noop bwd cast to DTensor(Float8Tensor) + # fwd noop bwd cast to DTensor(Float8TrainingTensor) outputs = NoopFwToFloat8BwDynamic.apply( outputs, mod.linear_mm_config, @@ -126,7 +126,7 @@ def _prepare_input_fn( mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - ) # DTensor(Float8Tensor) + ) # DTensor(Float8TrainingTensor) if input_layouts != desired_input_layouts: input_tensor = input_tensor.redistribute( @@ -142,7 +142,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me if outputs.placements != output_layouts: outputs = outputs.redistribute(placements=output_layouts, async_op=True) - # fwd noop bwd cast to DTensor(Float8Tensor) + # fwd noop bwd cast to DTensor(Float8TrainingTensor) outputs = NoopFwToFloat8BwDynamic.apply( outputs, mod.linear_mm_config, @@ -173,7 +173,7 @@ class PrepareFloat8ModuleInput(PrepareModuleInput): currently assumes tensorwise scaling. The only difference from `PrepareModuleInput` is that - after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) + after we prepare the input DTensor, we cast the input to DTensor(Float8TrainingTensor) This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) so that if there are multiple float8 users of the input activation, we perform fp8 allgather only once. @@ -234,7 +234,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): e4m3_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, - ) # DTensor(Float8Tensor) + ) # DTensor(Float8TrainingTensor) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_training_tensor.py similarity index 94% rename from torchao/float8/float8_tensor.py rename to torchao/float8/float8_training_tensor.py index 6b5177e1fe..568721a3d7 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_training_tensor.py @@ -66,7 +66,7 @@ class LinearMMConfig(NamedTuple): Configuration for different gemm operations in LinearMM. This configuration is not user-facing and exists for convenience, - allowing Float8Tensor to use the right config based on which gemm + allowing Float8TrainingTensor to use the right config based on which gemm from gemms with outputs `output`, `grad_input`, `grad_weight` is being called. Attributes: @@ -82,7 +82,7 @@ class LinearMMConfig(NamedTuple): class GemmInputRole(enum.Enum): """ - Given a Float8Tensor, the enum below describes the expected role of this + Given a Float8TrainingTensor, the enum below describes the expected role of this tensor in the three gemms present in the fw + bw pass of a Linear layer. This is used to choose the right config for a float8 gemm when the gemm is performed. @@ -138,7 +138,7 @@ def forward( axiswise_dim: Optional[int] = None, ): """ - This function will apply the scaling, and then convert to a Float8Tensor + This function will apply the scaling, and then convert to a Float8TrainingTensor Note: We will call this function with a DTensor subclass. Ideally this would be an aten OP @@ -161,7 +161,7 @@ def forward( bits_placements = bits_fp8.placements local_bits = bits_fp8.to_local() local_scale = scale.to_local() - inner_float8_tensor = Float8Tensor( + inner_float8_tensor = Float8TrainingTensor( local_bits, local_scale, tensor.dtype, @@ -178,7 +178,7 @@ def forward( stride=bits_fp8.stride(), ) - return Float8Tensor( + return Float8TrainingTensor( bits_fp8, scale, tensor.dtype, @@ -219,10 +219,10 @@ def hp_tensor_and_scale_to_float8( ): """ Given a high precision tensor `hp_tensor` and a precalculated scale `s`, - scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result. + scales `hp_tensor` by `s` and returns a `Float8TrainingTensor` of the result. Autograd-aware, the derivative is pass-through. - DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor). + DTensor-aware, if the input is a DTensor the output will be DTensor(Float8TrainingTensor). Args: hp_tensor: the tensor to convert @@ -239,7 +239,7 @@ def hp_tensor_and_scale_to_float8( ) -class Float8Tensor(torch.Tensor): +class Float8TrainingTensor(torch.Tensor): """ Note: this is **not** a public API and is only intended to be used inside of this repository. Please file an issue if you would benefit @@ -319,7 +319,7 @@ def __new__( return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" + return f"Float8TrainingTensor(lp_dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { @@ -333,7 +333,7 @@ def __tensor_flatten__(self): @staticmethod def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride): assert len(inner_tensors) == 2 - return Float8Tensor( + return Float8TrainingTensor( inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"], @@ -355,7 +355,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Lazy import to avoid circular dependency from torchao.float8.float8_ops import FLOAT8_OPS_TABLE - # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs + # All ops in the FLOAT8_OPS_TABLE expect Float8TrainingTensor as inputs # And don't support mixed tensor subclasses. This will trigger the handler for # the next type in the dispatch list def allowed_subclasses(type): @@ -374,5 +374,5 @@ def allowed_subclasses(type): return FLOAT8_OPS_TABLE[func](func, args, kwargs) raise NotImplementedError(f"attempting to run {func}, this is not supported") - # Do not force the Float8Tensor type on the returned tensor + # Do not force the Float8TrainingTensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 625fb29235..5cb93ac0a0 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -144,8 +144,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x: The original tensor. y: The tensor to compare to the original tensor. """ - Ps = torch.norm(x) - Pn = torch.norm(x - y) + Ps = torch.linalg.vector_norm(x) + Pn = torch.linalg.vector_norm(x - y) return 20 * torch.log10(Ps / Pn) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 7b24dc2b53..79e62c7e10 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -15,8 +15,8 @@ from torchao.float8.float8_scaling_utils import ( hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( - Float8Tensor, +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, GemmInputRole, LinearMMConfig, hp_tensor_and_scale_to_float8, @@ -39,6 +39,10 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: from torchao.float8.float8_linear import Float8Linear + torch._C._log_api_usage_once( + "torchao.float8.precompute_float8_dynamic_scale_for_fsdp" + ) + float8_linears: List[Float8Linear] = [ m for m in module.modules() @@ -217,7 +221,7 @@ def __repr__(self): def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: - float8_tensor = hp_tensor_and_scale_to_float8( + float8_training_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._precomputed_scale, self._dtype, @@ -225,7 +229,7 @@ def fsdp_pre_all_gather(self, mesh): GemmInputRole.WEIGHT, ) else: - float8_tensor = hp_tensor_to_float8_dynamic( + float8_training_tensor = hp_tensor_to_float8_dynamic( self._tensor, self._dtype, self._linear_mm_config, @@ -233,7 +237,7 @@ def fsdp_pre_all_gather(self, mesh): gemm_input_role=GemmInputRole.WEIGHT, device_mesh=mesh, ) - return (float8_tensor._data,), (float8_tensor._scale,) + return (float8_training_tensor._data,), (float8_training_tensor._scale,) def fsdp_post_all_gather( self, @@ -248,21 +252,25 @@ def fsdp_post_all_gather( if out is not None: from torch.distributed._tensor import DTensor - if isinstance(out, Float8Tensor): + if isinstance(out, Float8TrainingTensor): out._scale = scale elif isinstance(out, DTensor) and isinstance( - out._local_tensor, Float8Tensor + out._local_tensor, Float8TrainingTensor ): out._local_tensor._scale = scale else: raise RuntimeError( - f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}" + f"out must be a Float8TrainingTensor or DTensor(_local_tensor=Float8TrainingTensor), but got {out}" ) return - return Float8Tensor( + return Float8TrainingTensor( data, scale, param_dtype, self._linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ), (data,) + + +# Needed to allowlist this subclass for deserialization used for restoring checkpoints. +torch.serialization.add_safe_globals([WeightWithDynamicFloat8CastTensor]) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 144f1fa6f2..f15d38576c 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -7,7 +7,7 @@ Defines an nn module designed to be used during inference """ -from typing import NamedTuple, Optional, Tuple, Union +from typing import List, NamedTuple, Optional, Tuple, Union import torch @@ -67,6 +67,24 @@ def preprocess_data( return a_data, b_data +def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]): + """Ensures input tensor is correctly formatted for _scaled_mm""" + + # For PerTensor quantization, scale should be a scalar or have shape [1] + if input_scale.numel() == 1: + # Already a scalar, ensure it has the right shape for _scaled_mm + return input_scale.reshape(1, 1) + + # For per-row/block quantization, we need to handle the reshaping + input_scale = input_scale.unsqueeze(-1) + + # Match: #input_data.reshape(-1, input_data.shape[-1]) + if input_scale.dim() > 2: + input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + + return input_scale + + def addmm_float8_unwrapped_inference( a_data: Tensor, a_scale: Tensor, @@ -78,7 +96,7 @@ def addmm_float8_unwrapped_inference( use_fast_accum: bool = False, ) -> Tensor: """ - This is the unwrapped version of addmm_float8, which does not take in Float8Tensors + This is the unwrapped version of addmm_float8, which does not take in Float8TrainingTensors as inputs. This is used to standardize the logic between subclassed and non subclassed versions of the linear module. """ @@ -107,12 +125,75 @@ def addmm_float8_unwrapped_inference( ) -def _is_rowwise_scaled(x) -> bool: - """Checks if an AQT tensor is rowwise scaled +def _slice_scale_for_dimension( + scale: torch.Tensor, + data_shape: List[int], + dim: int, + start: int, + end: int, + step: int, +) -> torch.Tensor: + """ + Slice the scale tensor appropriately based on the data tensor slicing. + This function calculates how the scale should be sliced when the data tensor + is sliced along a given dimension, taking into account the block structure. + """ + aten = torch.ops.aten + + # Unsupported case for now, this would be 1 scale per data element + if scale.shape == data_shape: + return aten.slice.Tensor(scale, dim, start, end, step) + + # Reconstruct block sizes based on data shape and scale shape + block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape))) + + if dim >= len(block_sizes): + # Slicing beyond the dimensions we care about + return scale + + block_size_for_dim = block_sizes[dim] + + if block_size_for_dim == 1: + # Scale is per-element along this dimension + # Slice away as normal + return aten.slice.Tensor(scale, dim, start, end, step) + else: + # There is blocking in this dimension + # Calculate which scale elements correspond to the sliced data + scale_start = start // block_size_for_dim if start is not None else None + scale_end = ( + (end + block_size_for_dim - 1) // block_size_for_dim + if end is not None + else None + ) + + # Error on Step > 1 + if step > 1: + raise NotImplementedError( + "Slicing with step > 1 is not implemented for scale tensors." + ) + + return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) + + +def _is_rowwise_scaled(x: torch.Tensor) -> bool: + """Checks if a quantized tensor is rowwise scaled + Args: + x: quantized tensor (should have `block_size` attribute) + """ + assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" + return tuple(x.block_size) == (1,) * (x.dim() - 1) + (x.shape[-1],) + + +def _is_tensorwise_scaled(x: torch.Tensor) -> bool: + """Checks if a quantized tensor is rowwise scaled Args: - x: AffineQuantizedTensor tensor + x: quantized tensor (should have `block_size` attribute) """ - return x.block_size == (1,) * (x.dim() - 1) + (x.shape[-1],) + assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" + return all( + x.block_size[i] == -1 or x.block_size[i] == x.shape[i] for i in range(x.ndim) + ) def _normalize_granularity( diff --git a/torchao/kernel/bsr_triton_ops.py b/torchao/kernel/bsr_triton_ops.py index 18cfba9ad9..4d80c4c577 100644 --- a/torchao/kernel/bsr_triton_ops.py +++ b/torchao/kernel/bsr_triton_ops.py @@ -9,15 +9,7 @@ from typing import Optional import torch - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - -if TORCH_VERSION_AT_LEAST_2_4: - from torch._dynamo.utils import warn_once -else: - import warnings - - warn_once = warnings.warn +from torch._dynamo.utils import warn_once from torch.sparse._triton_ops import ( broadcast_batch_dims, launch_kernel, diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 2f064b3f2f..292b67380d 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -7,18 +7,16 @@ import os import torch +from torch._dynamo import is_compiling as dynamo_is_compiling +from torch._higher_order_ops.out_dtype import out_dtype -from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, check_cpu_version +from torchao.utils import check_cpu_version logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) try: - # Only works for torch2.2 or newer. - if TORCH_VERSION_AT_LEAST_2_2: - from torchao.kernel import intmm_triton - else: - intmm_triton = None + from torchao.kernel import intmm_triton except ImportError: logger.warning( "Warning: Detected no triton, on systems without Triton certain kernels will not work" @@ -28,85 +26,63 @@ AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0))) -# torch._int_mm doesn't exist before 2.2 -if TORCH_VERSION_AT_LEAST_2_2: - from torch._dynamo import is_compiling as dynamo_is_compiling - from torch._higher_order_ops.out_dtype import out_dtype - - def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - """ - Performs a safe integer matrix multiplication, considering different paths for - torch.compile, cublas, and fallback cases. - - Args: - input (torch.Tensor): The input tensor of shape [i, j]. - mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. - - Returns: - torch.Tensor: The result of the matrix multiplication. - - Raises: - AssertionError: If the tensors are not on the same device. - """ - # torch.compile path - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - if input.device.type == "cpu": - # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend - return out_dtype( - torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float() - ) - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - - # error checking for cublas path - assert mat2.device == input.device, ( - f"need both tensors to be on the same device but got {mat2.device} and {input.device}" - ) - device_cpu = "cpu" in [mat2.device.type, input.device.type] - # with input.shape = [i,j] and mat2.shape = [j,k] - j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) - k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) - bad_dimensions_for_cublas = not ( - j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 - ) - if device_cpu or bad_dimensions_for_cublas: - # fallback path - return torch.matmul( - input.cpu().to(torch.int32), mat2.cpu().to(torch.int32) - ).to(input.device.type) - - # cublas paths - if not mat2.is_contiguous(): # silently gives incorrect result without this - mat2 = mat2.contiguous() - if (not input.is_contiguous()) and ( - input.shape[0] % 8 != 0 - ): # gives cryptic error without this - input = ( - input.contiguous() - ) # (it seems the transpose makes cublas check the above j constraint on i) - try: - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - except Exception: - # fallback path, would run on H100 for float8 dtypes - # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' - return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to( - torch.int32 +def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + """ + Performs a safe integer matrix multiplication, considering different paths for + torch.compile, cublas, and fallback cases. + + Args: + input (torch.Tensor): The input tensor of shape [i, j]. + mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. + + Returns: + torch.Tensor: The result of the matrix multiplication. + + Raises: + AssertionError: If the tensors are not on the same device. + """ + # torch.compile path + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + if input.device.type == "cpu": + # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend + return out_dtype( + torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float() ) -else: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - """ - Performs a fallback integer matrix multiplication for torch versions before 2.2. + # error checking for cublas path + assert mat2.device == input.device, ( + f"need both tensors to be on the same device but got {mat2.device} and {input.device}" + ) + device_cpu = "cpu" in [mat2.device.type, input.device.type] + # with input.shape = [i,j] and mat2.shape = [j,k] + j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) + k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) + bad_dimensions_for_cublas = not ( + j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 + ) - Args: - input (torch.Tensor): The input tensor of shape [i, j]. - mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. + if device_cpu or bad_dimensions_for_cublas: + # fallback path + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + input.device.type + ) - Returns: - torch.Tensor: The result of the matrix multiplication in int32. - """ - # We can improve on this by writing Triton code that works for older versions of Triton - # that ship with 2.1 or 2.0. + # cublas paths + if not mat2.is_contiguous(): # silently gives incorrect result without this + mat2 = mat2.contiguous() + if (not input.is_contiguous()) and ( + input.shape[0] % 8 != 0 + ): # gives cryptic error without this + input = ( + input.contiguous() + ) # (it seems the transpose makes cublas check the above j constraint on i) + try: + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + except Exception: + # fallback path, would run on H100 for float8 dtypes + # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to( torch.int32 ) diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 1a516a7163..6f657cdfd8 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -10,7 +10,6 @@ import triton.language as tl from torchao.kernel.autotuner import get_best_config_fn -from torchao.utils import TORCH_VERSION_AFTER_2_5 # TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option int8_mm_kernel_configs = sum( @@ -38,16 +37,15 @@ [], ) -if TORCH_VERSION_AFTER_2_5: - if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": - int8_mm_kernel_configs = [ - (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) - for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( - [16, 32, 64, 128, 256], repeat=3 - ) - for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] - for num_warps in [2, 4, 8] - ] +if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": + int8_mm_kernel_configs = [ + (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] + for num_warps in [2, 4, 8] + ] # Baseline configs from pytorch/pytorch diff --git a/torchao/ops.py b/torchao/ops.py index babe5506c0..b6348f90a5 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -9,8 +9,6 @@ import torch from torch import Tensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - lib = torch.library.Library("torchao", "FRAGMENT") lib.define( "quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor" @@ -70,24 +68,21 @@ lib.define( "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor" ) +lib.define( + "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor" +) def register_custom_op(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.register_fake(f"{name}")(func) - else: - return torch.library.impl_abstract(f"{name}")(func) + return torch.library.register_fake(f"{name}")(func) return decorator def register_custom_op_impl(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.custom_op(f"{name}", mutates_args=())(func) - else: - return torch.library.impl(f"{name}", "CUDA")(func) + return torch.library.custom_op(f"{name}", mutates_args=())(func) return decorator @@ -1106,3 +1101,19 @@ def _( assert weight.dim() == 4 N = weight.size(0) * weight.size(3) * 2 return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) + + +@register_custom_op("torchao::_scaled_embedding_bag") +def _( + qweight: Tensor, + indices: Tensor, + offsets: Tensor, + w_scales: Tensor, + o_scale: float, + mode: int, + include_last_offset: bool, +) -> Tensor: + # Only support include_last_offset == True + assert include_last_offset == True + batch_size = offsets.shape[0] - 1 + return qweight.new_empty(batch_size, qweight.shape[1], dtype=qweight.dtype) diff --git a/torchao/optim/adam.py b/torchao/optim/adam.py index ddbdc8b12f..8beaffb627 100644 --- a/torchao/optim/adam.py +++ b/torchao/optim/adam.py @@ -39,7 +39,7 @@ def __init__( if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) defaults = dict( - lr=torch.tensor(lr), + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, @@ -50,6 +50,14 @@ def __init__( self.bf16_stochastic_round = bf16_stochastic_round self.is_adamw = is_adamw + def add_param_group(self, param_group: dict) -> None: + super().add_param_group(param_group) + + # convert LR to a tensor + group = self.param_groups[-1] + if not isinstance(group["lr"], Tensor): + group["lr"] = torch.tensor(group["lr"], dtype=torch.float32) + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: @@ -225,6 +233,7 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=False, ) + torch._C._log_api_usage_once("torchao.optim.Adam8bit") @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -255,6 +264,7 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=False, ) + torch._C._log_api_usage_once("torchao.optim.Adam4bit") @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -285,6 +295,7 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=False, ) + torch._C._log_api_usage_once("torchao.optim.AdamFp8") @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -315,6 +326,7 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=True, ) + torch._C._log_api_usage_once("torchao.optim.AdamW8bit") @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -345,6 +357,7 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=True, ) + torch._C._log_api_usage_once("torchao.optim.AdamW4bit") @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -375,6 +388,7 @@ def __init__( bf16_stochastic_round=bf16_stochastic_round, is_adamw=True, ) + torch._C._log_api_usage_once("torchao.optim.AdamWFp8") @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): diff --git a/torchao/optim/cpu_offload.py b/torchao/optim/cpu_offload.py index cca55749db..53acd4057f 100644 --- a/torchao/optim/cpu_offload.py +++ b/torchao/optim/cpu_offload.py @@ -8,7 +8,7 @@ import torch from torch.optim.optimizer import Optimizer, ParamsT -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, get_available_devices +from torchao.utils import get_available_devices # NOTE: We make this inherit Optimizer so it works with PyTorch's built-in LR @@ -36,11 +36,7 @@ def __init__( kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`. """ # default to fused CPU AdamW - if ( - optimizer_class is torch.optim.AdamW - and TORCH_VERSION_AT_LEAST_2_4 - and "fused" not in kwargs - ): + if optimizer_class is torch.optim.AdamW and "fused" not in kwargs: kwargs.update(fused=True) param_groups = list(params) diff --git a/torchao/optim/subclass_4bit.py b/torchao/optim/subclass_4bit.py index bc5fd33414..82bb6a3788 100644 --- a/torchao/optim/subclass_4bit.py +++ b/torchao/optim/subclass_4bit.py @@ -7,13 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor from .quant_utils import ( create_dynamic_map, @@ -113,25 +110,6 @@ def __repr__(self): ) -# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when -# dtype is the same but device is different. thus, we must override .to() method instead. -if not TORCH_VERSION_AT_LEAST_2_4: - - def _to(self, *args, **kwargs): - # ignore other args/kwargs - device = kwargs.pop("device", None) - return OptimState4bit( - self.codes.to(device), - self.scale.to(device), - self.qmap.to(device), - self.signed, - self.shape, - ) - - OptimState4bit.to = _to - del _to # make sure to not re-use - - @OptimState4bit.implements(aten.copy_.default) def _(func, types, args, kwargs): dst = args[0] @@ -268,7 +246,4 @@ def _(func, types, args, kwargs): return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimState4bit]) +add_safe_globals([OptimState4bit]) diff --git a/torchao/optim/subclass_8bit.py b/torchao/optim/subclass_8bit.py index d3f7634526..bbc6cfa958 100644 --- a/torchao/optim/subclass_8bit.py +++ b/torchao/optim/subclass_8bit.py @@ -7,13 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor from .quant_utils import ( create_dynamic_map, @@ -101,24 +98,6 @@ def __repr__(self): ) -# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when -# dtype is the same but device is different. thus, we must override .to() method instead. -if not TORCH_VERSION_AT_LEAST_2_4: - - def _to(self, *args, **kwargs): - # ignore other args/kwargs - device = kwargs.pop("device", None) - return OptimState8bit( - self.codes.to(device), - self.scale.to(device), - self.qmap.to(device), - self.signed, - ) - - OptimState8bit.to = _to - del _to # make sure to not re-use - - @OptimState8bit.implements(aten.copy_.default) def _(func, types, args, kwargs): dst = args[0] @@ -237,7 +216,4 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimState8bit]) +add_safe_globals([OptimState8bit]) diff --git a/torchao/optim/subclass_fp8.py b/torchao/optim/subclass_fp8.py index 1ae670dd6d..e898932138 100644 --- a/torchao/optim/subclass_fp8.py +++ b/torchao/optim/subclass_fp8.py @@ -7,9 +7,10 @@ import torch from torch import Tensor +from torch.serialization import add_safe_globals from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor +from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -192,7 +193,4 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - from torch.serialization import add_safe_globals - - add_safe_globals([OptimStateFp8]) +add_safe_globals([OptimStateFp8]) diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md index 18f3663427..a67b3be9f0 100644 --- a/torchao/prototype/autoround/README.md +++ b/torchao/prototype/autoround/README.md @@ -78,7 +78,7 @@ multi_t_input_ids = MultiTensor(input_ids_lst) out = model(multi_t_input_ids) ``` #### Step 3: Finalize Quantization -After obtaining optimized `zero_point` and `scale` values, create the `AffineQuantizedTensor` +After obtaining optimized `zero_point` and `scale` values, create the `AffineQuantizedTensor` for each target weight to select the right low-bits kernel. ```python @@ -114,7 +114,7 @@ quantize_(model, apply_auto_round(), is_target_module) | autoround-4bit* | 0.6338 | 0.4566 | 0.7661 | 0.6646 | 0.5688 | 0.7130 | > [!NOTE] -> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`int4_weight_only(group_size=128)`) while leaving the `lm-head` unquantized.
+> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`Int4WeightOnlyConfig(group_size=128, version=1)`) while leaving the `lm-head` unquantized.
> - `auto-round-4bit` uses the deafult configuration from [quick start](#quick-start).
> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `batch_size=4`, which accumulating two batches(4 samples per batch) before performing the backward pass.
> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`. diff --git a/torchao/prototype/autoround/autoround_llm.py b/torchao/prototype/autoround/autoround_llm.py index 822ee6554b..8d29fe3388 100644 --- a/torchao/prototype/autoround/autoround_llm.py +++ b/torchao/prototype/autoround/autoround_llm.py @@ -88,7 +88,7 @@ def main(args): # Get the model, tokenizer, and decoder_cls model_name_or_path = args.model_name_or_path model, tokenizer, decoder_cls = ar_utils.get_float_model_info( - model_name_or_path, torch_dtype=torch.bfloat16 + model_name_or_path, dtype=torch.bfloat16 ) # Disable the `use_cache` for calibration stage. model.config.use_cache = False diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index 16c1736843..4846f919cc 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -12,7 +12,6 @@ import torchao import torchao.prototype.autoround.utils as ar_utils import torchao.quantization -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 logger = logging.getLogger(__name__) @@ -87,7 +86,7 @@ def main(args): with torch.no_grad(): model_name_or_path = args.model_name_or_path model, tokenizer, decoder_cls = ar_utils.get_float_model_info( - model_name_or_path, torch_dtype=torch.bfloat16 + model_name_or_path, dtype=torch.bfloat16 ) model.eval() model_device = args.model_device @@ -102,25 +101,28 @@ def main(args): # Evaluate the quantized model if args.woq_int4: msg += " (int4wo)" - from torchao.quantization import int4_weight_only, quantize_ + from torchao.quantization import Int4WeightOnlyConfig, quantize_ quantize_( model, - int4_weight_only(group_size=args.group_size), + Int4WeightOnlyConfig(group_size=args.group_size, version=1), filter_fn=filter_fn, device=model_device, ) elif args.uintx: msg += f" (uintx {args.bits} bits)" from torchao.dtypes.uintx.uintx import _BIT_WIDTH_TO_DTYPE - from torchao.quantization.quant_api import quantize_, uintx_weight_only + from torchao.quantization.quant_api import ( + UIntXWeightOnlyConfig, + quantize_, + ) bits = args.bits assert bits in _BIT_WIDTH_TO_DTYPE, f"Invalid bits: {bits}" dtype = _BIT_WIDTH_TO_DTYPE[bits] quantize_( model, - uintx_weight_only(dtype=dtype, group_size=args.group_size), + UIntXWeightOnlyConfig(dtype=dtype, group_size=args.group_size), filter_fn=filter_fn, device=model_device, ) @@ -165,7 +167,7 @@ def main(args): bench_accuracy(model, tokenizer, tasks=args.tasks, msg=msg) -if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available(): +if __name__ == "__main__" and torch.cuda.is_available(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) diff --git a/torchao/prototype/autoround/utils.py b/torchao/prototype/autoround/utils.py index 0ca0d83fd3..bac1c494ed 100644 --- a/torchao/prototype/autoround/utils.py +++ b/torchao/prototype/autoround/utils.py @@ -140,11 +140,11 @@ def _auto_detect_decoder_cls(model): return type(first_module) -def get_float_model_info(model_name_or_path, torch_dtype=torch.float32): +def get_float_model_info(model_name_or_path, dtype=torch.float32): import transformers model = transformers.AutoModelForCausalLM.from_pretrained( - model_name_or_path, torch_dtype=torch_dtype + model_name_or_path, dtype=dtype ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) decoder_cls = _auto_detect_decoder_cls(model) diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py index 570b0821d4..cd5c447d4c 100644 --- a/torchao/prototype/awq/__init__.py +++ b/torchao/prototype/awq/__init__.py @@ -1,8 +1,8 @@ -from .api import awq_uintx, insert_awq_observer_ -from .core import AWQObservedLinear +from .api import AWQConfig +from .core import AWQObservedLinear, AWQStep __all__ = [ - "awq_uintx", - "insert_awq_observer_", "AWQObservedLinear", + "AWQConfig", + "AWQStep", ] diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 5806c29ce6..918b7a1817 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -3,185 +3,114 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import logging import types from dataclasses import dataclass -from typing import Optional import torch -import torchao from torchao.core.config import AOBaseConfig -from torchao.dtypes import ( - Int4XPULayout, - Layout, - TensorCoreTiledLayout, - to_affine_quantized_intx, -) -from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout -from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata -from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_api import ( _linear_extra_repr, - _replace_with_custom_fn_if_matches_filter, -) -from torchao.quantization.quant_primitives import ( - _DTYPE_TO_QVALUE_BOUNDS, - MappingType, - ZeroPointDomain, ) +from torchao.quantization.quantize_.common import SupportsActivationPreScaling from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) +from torchao.utils import DummyModule from .core import ( AWQObservedLinear, AWQObserver, + AWQStep, ) -assert len(_DTYPE_TO_BIT_WIDTH) > 0, ( - "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+" -) - - -def insert_awq_observer_( - model: torch.nn.Module, - n_validation_examples: int, - validation_sequence_len: int, - quant_dtype: torch.dtype = torch.uint4, - scale_search_space_size: int = 20, - group_size: int = 128, -): - """ - Inserts AWQObserver into Linear layers of a given model. - - Args: - model: The model to be modified (in place). Ensure model is on the desired device for calibration - n_validation_examples: Number of examples used to validate scale options - validation_sequence_len: Number of tokens in each validation example - quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 - scale search space size: how many different scale options to try. Original AWQ implementation uses 20. A larger size can lead to better results but takes longer to calibrate - group_size: Quantization granularity. Use -1 for channel wise quantization - """ - _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, ( - "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" - ) - # AQT config - mapping_type = MappingType.ASYMMETRIC - quantization_granularity = PerGroup(group_size) - quant_min = 0 - quant_max = ( - 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1 - ) - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - - def replace_with_observer(layer): - # creates observer and replaces linear layers with AWQObservedLinear layers - observer = AWQObserver( - layer.weight, - layer.bias, - quantization_granularity, - mapping_type, - quant_dtype, - n_validation_examples, - validation_sequence_len, - scale_search_space_size, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - zero_point_dtype=zero_point_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - ) - return AWQObservedLinear.from_float(layer, observer) - - _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) +logger = logging.getLogger(__name__) @dataclass -class AWQUIntXConfig(AOBaseConfig): +class AWQConfig(AOBaseConfig): """ Configuration for quantizing linear layers when passed into quantize_() Args: - quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8 - `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)` - group_size: Quantization granularity. Use -1 for channel wise quantization - weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used - set_inductor_config: if True, adjusts `torchinductor` settings to recommended values. + base_config (AOBaseConfig): The quantization config that we can apply awq on top of, e.g. 8da4w, int4 weight only + step (AWQStep): specifies the step for AWQ, one of PREPARE, CONVERT and PREPARE_FOR_LOADING indicating the step of AWQ process + PREPARE: insert AWQ Observers to linear + CONVERT: convert the observed linear modules to linear modules with awq quantized weights + PREPARE_FOR_LOADING: convert the floating point model to a dummy awq quantized model, so we can + load the quantized weights through copy_ later + can use the corresponding string "prepare", "convert", "prepare_for_loading" for simplicity + scale_search_space_size (int): the number of scales to search for """ - quant_dtype: torch.dtype = torch.uint4 - layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8) - group_size: int = 64 - use_hqq: bool = False - set_inductor_config: bool = True - + base_config: AOBaseConfig + step: AWQStep + scale_search_space_size: int = 20 -# for bc -awq_uintx = AWQUIntXConfig + def __post_init__(self): + self.step = self.step.lower() + all_step_values = [s.value for s in AWQStep] + if self.step not in all_step_values: + raise ValueError(f"{self.step} is not one of {all_step_values}") -@register_quantize_module_handler(AWQUIntXConfig) -def _awq_uintx_transform( +@register_quantize_module_handler(AWQConfig) +def _awq_transform( module: torch.nn.Module, - config: AWQUIntXConfig, + config: AWQConfig, ) -> torch.nn.Module: - quant_dtype = config.quant_dtype - group_size = config.group_size - use_hqq = config.use_hqq - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - observed_linear = module - - assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, ( - "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8" - ) + step = config.step + scale_search_space_size = config.scale_search_space_size + observed_linear = None + base_config = config.base_config - equalization_scale = observed_linear.act_obs.calculate_qparams() - # AQT config - if quant_dtype == torch.uint4: - target_dtype = torch.int32 - eps = 1e-6 - preserve_zero = False - _layout = config.layout - if isinstance(_layout, Int4XPULayout): - zero_point_dtype = torch.int8 - zero_point_domain = ZeroPointDomain.INT - else: - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT + if step == AWQStep.PREPARE: + observer = AWQObserver( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + return AWQObservedLinear.from_float(module, observer) + elif step == AWQStep.PREPARE_FOR_LOADING: + # loading from pre-quantized checkpoint + observer = AWQObserver( + module.weight, + module.bias, + base_config, + scale_search_space_size, + ) + observed_linear = AWQObservedLinear.from_float(module, observer) + example_input = torch.randn( + (1, module.weight.shape[1]), + device=module.weight.device, + dtype=module.weight.dtype, + ) + observed_linear(example_input) else: - target_dtype = torch.uint8 - eps = torch.finfo(torch.float32).eps - preserve_zero = True - zero_point_dtype = torch.int64 - zero_point_domain = ZeroPointDomain.INT - _layout = UintxLayout(quant_dtype) + assert step == AWQStep.CONVERT, f"Unexpected step: {step}" + if not isinstance(module, AWQObservedLinear): + logger.info( + f"convert: module is not AWQObservedLinear, skipping: {type(module)}" + ) + return module + observed_linear = module + + assert observed_linear is not None + equalization_scale = observed_linear.act_obs.calculate_qparams() - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0] - quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1] - qw = to_affine_quantized_intx( - observed_linear.weight * equalization_scale, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - _layout=_layout, - use_hqq=use_hqq, + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(observed_linear.weight * equalization_scale) + quant_mod = base_config_handler(dummy_mod, config.base_config) + qw = quant_mod.weight + assert isinstance(qw, SupportsActivationPreScaling), ( + "weight must support activation scaling through implementing `SupportsActivationPreScaling`" ) - - qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) + # since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the + # reciprocal of the `equalization_scale` + qw.act_pre_scale = 1.0 / equalization_scale linear = torch.nn.Linear( observed_linear.in_features, @@ -191,6 +120,6 @@ def _awq_uintx_transform( dtype=observed_linear.weight.dtype, ) linear.weight = torch.nn.Parameter(qw, requires_grad=False) - linear.extra_repr = types.MethodType(_linear_extra_repr, module) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) linear.bias = observed_linear.bias return linear diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index e5ee96fea2..c26a036733 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -3,145 +3,94 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from enum import Enum from typing import Optional import torch import torch.nn.functional as F -from torchao.dtypes import to_affine_quantized_intx -from torchao.dtypes.uintx.uintx_layout import UintxLayout -from torchao.quantization.granularity import Granularity -from torchao.quantization.observer import ( - AffineQuantizedObserverBase, -) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, +from torchao.core.config import AOBaseConfig +from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, ) +from torchao.utils import DummyModule + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class AWQStep(str, Enum): + PREPARE = "prepare" + CONVERT = "convert" + PREPARE_FOR_LOADING = "prepare_for_loading" + +@torch.no_grad() +def get_act_scale(x): + return x.abs().view(-1, x.shape[-1]).mean(0) -class AWQObserver(AffineQuantizedObserverBase): + +class AWQObserver(torch.nn.Module): def __init__( self, weight: torch.Tensor, - bias: torch.Tensor, - quantization_granularity: Granularity, - mapping_type: MappingType, - target_dtype: torch.dtype, - n_validation_examples: int, - validation_sequence_len: int, + bias: Optional[torch.Tensor], + base_config: AOBaseConfig, scale_search_space_size: int = 20, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero: Optional[bool] = True, - zero_point_domain=ZeroPointDomain.INT, ): """ A custom observer for Activation aware Weight Quantization (AWQ) + Note: this only applies to weight only quantization: https://github.com/pytorch/ao/issues/2388#issuecomment-3062863647 Args: - weight: The weight tensor to be observed. - bias: The bias tensor to be observed. - quantization_granularity: Granularity which specifies how many weights share the same scale/zero point - input_dtype: The data type of the input tensor. - mapping_type: Always set to asymmetric - target_dtype: The target data type of the quantized tensor - n_validation_examples: Number of examples used to calibrate observer - validation_sequence_len: Number of tokens in each example - scale_search_space_size: The number of scales to search for. - quant_min: The minimum quantized value - quant_max: The maximum quantized value - eps: The minimum scale. - scale_dtype: The data type of the scale tensor. - zero_point_dtype: The data type of the zero point tensor. - preserve_zero: A flag to indicate whether we need zero to be exactly - representable or not. - zero_point_domain: The domain of the zero point. + weight (torch.Tensor: The weight tensor to be observed. + bias (Optional[torch.Tensor]): The bias tensor to be observed. + config (AOBaseConfig): the configuration for quantize_, that we'll use to apply awq on top of + scale_search_space_size (int): search space size for searching the best scale for weight and input activation """ - super().__init__( - mapping_type, - target_dtype, - quantization_granularity, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - ) - self.quantization_granularity = quantization_granularity + super().__init__() + self.base_config = base_config self.weight = weight self.bias = bias - self.n_validation_examples = n_validation_examples - self.validation_sequence_len = validation_sequence_len - self.calibration_token_count = 0 self.inputs = [] - self.outputs = [] self.scale_options = scale_search_space_size self.device = self.weight.device - self.average = torch.zeros((1, weight.shape[1]), device=self.device) if self.bias is not None: self.bias.to(self.device) @torch.no_grad() def forward(self, input: torch.Tensor, output: torch.Tensor): - # import pdb - # pdb.set_trace() - # print(input.shape, input.abs().sum(1).shape, self.average.shape) - if len(self.inputs) < self.n_validation_examples: - self.inputs.append(input.to("cpu")) - self.outputs.append(output.to("cpu")) - self.calibration_token_count += input.shape[-2] - self.average += input.abs().sum(-2) + self.inputs.append(input.to("cpu")) def calculate_qparams(self): - # import pdb - # pdb.set_trace() - assert self.outputs != None, ( + assert self.inputs != None, ( "calibrate observer first by running model on exemplar data" ) - self.average /= self.calibration_token_count - for i in range(self.n_validation_examples): + for i in range(len(self.inputs)): self.inputs[i] = self.inputs[i].to(self.device) - self.outputs[i] = self.outputs[i].to(self.device) + if self.bias is not None: + self.bias = self.bias.to(self.device) + + acc = torch.cat(self.inputs, dim=-2) + x_max = get_act_scale(acc) best_loss = float("inf") best_scales = None for i in range(self.scale_options): ratio = i * 1 / self.scale_options - scales = self.average.pow(ratio).to(self.weight.dtype) + scales = x_max.pow(ratio).to(self.weight.dtype).clamp(min=1e-4).view(-1) + if best_scales is None: + best_scales = scales scales = scales / (scales.max() * scales.min()).sqrt() - layout = UintxLayout(self.target_dtype) - # regardless of weight dtype, we have to store as packed uint8 tensors - tensor_dtype = torch.uint8 - w = to_affine_quantized_intx( - self.weight * scales, - self.mapping_type, - (1, self.quantization_granularity.group_size), - tensor_dtype, - quant_min=self.quant_min, - quant_max=self.quant_max, - eps=self.eps, - scale_dtype=self.scale_dtype, - zero_point_dtype=self.zero_point_dtype, - preserve_zero=self.preserve_zero, - zero_point_domain=self.zero_point_domain, - _layout=layout, - ) - loss = 0 - for i in range(self.n_validation_examples): - q_out = F.linear(self.inputs[i] / scales, w, self.bias) - loss += (self.outputs[i] - q_out).pow(2).mean().item() + config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.base_config)] + dummy_mod = DummyModule(self.weight * scales) + quant_mod = config_handler(dummy_mod, self.base_config) + w = quant_mod.weight + orig_out = F.linear(acc, self.weight, self.bias) + q_out = F.linear(acc / scales, w, self.bias) + loss = (orig_out - q_out).pow(2).mean().item() if loss < best_loss: best_scales = scales best_loss = loss - for i in range(self.n_validation_examples): - self.inputs[i].to("cpu") - self.outputs[i].to("cpu") return best_scales.detach() diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 7ff6092b05..2750c42b3a 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -6,14 +6,18 @@ import argparse import time +import lm_eval import torch from datasets import load_dataset +from lm_eval import evaluator +from lm_eval.models.huggingface import HFLM from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig -from torchao.dtypes import Int4XPULayout -from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_ -from torchao.quantization import int4_weight_only, quantize_ +from torchao.prototype.awq import ( + AWQConfig, +) +from torchao.quantization import Int4WeightOnlyConfig, quantize_ # adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255 @@ -90,8 +94,9 @@ def wiki2_eval( # adapted from Hicham Badri (@mobicham) -def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): - import lm_eval +def benchmark( + model, tokenizer, max_length, tasks=None, evaluation_limit=None, device="cuda" +): import numpy as np model.eval() @@ -100,7 +105,7 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): lm_eval.tasks.initialize_tasks() except: pass - model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer) + model_eval = HFLM(pretrained=model, tokenizer=tokenizer) eval_batch_size = 1 # 8 if tasks is None: tasks = [ @@ -111,6 +116,7 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): "hellaswag", "gsm8k", "mmlu", + "bbh", ] results = {} if "PPL" in tasks: @@ -121,22 +127,34 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): if "truthfulqa_mc2" in tasks: for task in [("truthfulqa_mc2", 0)]: tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate( - model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + results[tag] = evaluator.simple_evaluate( + model_eval, + tasks=[tag], + num_fewshot=fewshot, + batch_size=eval_batch_size, + limit=evaluation_limit, )["results"] print(tag, results[tag]) if "winogrande" in tasks: for task in [("winogrande", 5)]: tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate( - model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + results[tag] = evaluator.simple_evaluate( + model_eval, + tasks=[tag], + num_fewshot=fewshot, + batch_size=eval_batch_size, + limit=evaluation_limit, )["results"] print(tag, results[tag]) if "arc_challenge" in tasks: for task in [("arc_challenge", 25)]: tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate( - model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + results[tag] = evaluator.simple_evaluate( + model_eval, + tasks=[tag], + num_fewshot=fewshot, + batch_size=eval_batch_size, + limit=evaluation_limit, )["results"] print(tag, results[tag]) @@ -144,15 +162,23 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): if "hellaswag" in tasks: for task in [("hellaswag", 10)]: tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate( - model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + results[tag] = evaluator.simple_evaluate( + model_eval, + tasks=[tag], + num_fewshot=fewshot, + batch_size=eval_batch_size, + limit=evaluation_limit, )["results"] print(tag, results[tag]) if "gsm8k" in tasks: for task in [("gsm8k", 5)]: tag, fewshot = task - results[tag] = lm_eval.evaluator.simple_evaluate( - model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + results[tag] = evaluator.simple_evaluate( + model_eval, + tasks=[tag], + num_fewshot=fewshot, + batch_size=eval_batch_size, + limit=evaluation_limit, )["results"] print(tag, results[tag]) # ############################################ @@ -162,8 +188,12 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): results_mmlu = {} for task in [("mmlu", 5)]: tag, fewshot = task - results_mmlu[tag] = lm_eval.evaluator.simple_evaluate( - model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size + results_mmlu[tag] = evaluator.simple_evaluate( + model_eval, + tasks=[tag], + num_fewshot=fewshot, + batch_size=eval_batch_size, + limit=evaluation_limit, )["results"] print(tag, results_mmlu[tag]) @@ -180,20 +210,34 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): print("MMLU avg acc", np.mean(k)) results["mmlu"] = np.mean(k) + if "bbh" in tasks: + for task in [("leaderboard_bbh", 3)]: + tag, fewshot = task + results[tag] = evaluator.simple_evaluate( + model_eval, + tasks=[tag], + num_fewshot=fewshot, + batch_size=eval_batch_size, + limit=evaluation_limit, + )["results"] + print(tag, results[tag]) + results["bbh"] = results[tag] + return results -def wikitext2_ppl( +def quantize_and_eval( repo_id: str, quant: str, tasks: list[str], - calibration_size: int, - validation_size: int, + max_seq_length: int, + calibration_limit: int, + evaluation_limit: int, device: str, precision: torch.dtype, - sequence_length: int, compile: bool, model_save_path: str, + model_save_hf_hub_path: str, ): print(f"Loading model on {device}...") torch.manual_seed(34) @@ -201,65 +245,89 @@ def wikitext2_ppl( # load any model with torch.nn.linear layers tokenizer = AutoTokenizer.from_pretrained(repo_id) model = ( - AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision) - .eval() - .to(device) + AutoModelForCausalLM.from_pretrained(repo_id, dtype=precision).eval().to(device) ) print(f"Time to load model: {time.time() - t0:.02f} seconds") - if quant.startswith("awq"): - quant_dtype = quant.split("-")[1] + if quant.startswith("awq-int4wo"): group_size = int(quant.split("-")[2]) - quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) - print(f"running {quant_dtype} calibration") + print(f"running {quant} quantization with group size {group_size}") + + if device == "cuda": + base_config = Int4WeightOnlyConfig(group_size=group_size) + elif device == "xpu": + base_config = Int4WeightOnlyConfig( + group_size=group_size, int4_packing_format="plain_int32" + ) + elif device == "cpu": + base_config = Int4WeightOnlyConfig( + group_size=group_size, int4_packing_format="opaque" + ) + else: + assert False, "Unsupported device: {}".format(device) + print(f"running {quant} prepare and calibrate") t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_awq_observer_( + quant_config = AWQConfig(base_config, step="prepare") + + quantize_( model, - validation_size, - sequence_length, - quant_dtype=quant_dtype, - group_size=group_size, + quant_config, ) - calibration_data = get_calib_dataset( - tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length + from torchao._models._eval import TransformerEvalWrapper + + TransformerEvalWrapper( + model=model.to(device), + tokenizer=tokenizer, + max_seq_length=max_seq_length, + device=device, + ).run_eval( + tasks=tasks, + limit=calibration_limit, ) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") - print(f"time for calibration: {time.time() - t0:.02f} seconds") - - is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) - use_hqq = "hqq" in quant - print(f"running {quant_dtype} quantization") + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + print(f"running {quant} convert") t0 = time.time() - awq_uintx_config = awq_uintx( - quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq - ) - if "xpu" in device: - awq_uintx_config.layout = Int4XPULayout() - quantize_( - model, - awq_uintx_config, - is_observed_linear, - ) - print(f"time for quantization: {time.time() - t0:.02f} seconds") - if model_save_path is not None: - print(f"Saving model to {model_save_path}") - torch.save(model, model_save_path) + quant_config = AWQConfig(base_config, step="convert") + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + quant_config = AWQConfig(base_config, step="prepare_for_loading") + model.config.quantization_config = TorchAoConfig(quant_config) + elif quant.startswith("int4wo"): group_size = int(quant.split("-")[1]) - use_hqq = "hqq" in quant print(f"running {quant} quantization with group size {group_size}") - int4_weight_only_config = int4_weight_only( - group_size=group_size, use_hqq=use_hqq - ) - if "xpu" in device: - int4_weight_only_config.layout = Int4XPULayout() - quantize_(model, int4_weight_only_config) + # TODO: enable after migration: https://github.com/pytorch/ao/issues/2752 + # use_hqq = "hqq" in quant + if device == "cuda": + base_config = Int4WeightOnlyConfig(group_size=group_size) + elif device == "cpu": + base_config = Int4WeightOnlyConfig( + group_size=group_size, int4_packing_format="opaque" + ) + else: + assert False, "Unsupported device: {}".format(device) + quantize_(model, base_config) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + if compile: model = torch.compile(model) - return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device) + return benchmark( + model, + tokenizer, + max_seq_length, + tasks=tasks, + evaluation_limit=evaluation_limit, + device=device, + ) if __name__ == "__main__": @@ -268,26 +336,30 @@ def wikitext2_ppl( ) # Optional arguments with default values - parser.add_argument("repo", type=str, help="Repository ID of the model.") + parser.add_argument("--repo", type=str, help="Repository ID of the model.") parser.add_argument( - "quant", + "--quant", type=str, - help="Quantization method. Options are either awq-uint- for x =[1..8], int4wo-, or int4wo--hqq.", + help="Quantization method. Options are either awq-int4wo-, or int4wo-.", ) parser.add_argument( "--tasks", - type=list[str], - help="Task to benchmark model on. Either PPL or QA", - default=["PPL"], + nargs="+", + type=str, + help="Task to benchmark model on. Here is the list of tasks you can use: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md", + default=["hellaswag"], ) parser.add_argument( - "--calibration_samples", + "--calibration_limit", type=int, default=10, help="Number of samples to use for calibration. Default is 10.", ) parser.add_argument( - "--validation_size", type=int, default=1, help="Validation size. Default is 1." + "--evaluation_limit", + type=int, + default=None, + help="Number of samples to use for evaluation. Default is None (all).", ) parser.add_argument( "--device", @@ -302,10 +374,10 @@ def wikitext2_ppl( help="Precision type. Default is 'bfloat16'.", ) parser.add_argument( - "--seq_len", + "--max_seq_length", type=int, - default=512, - help="Length of examples to calibrate and evaluate model on. Default is 512", + default=2048, + help="Maximum sequence length of examples to calibrate and evaluate model on. Default is 2048", ) parser.add_argument( "--compile", @@ -318,22 +390,29 @@ def wikitext2_ppl( default=None, help="Path to store the scale values.", ) + parser.add_argument( + "--model_save_hf_hub_path", + type=str, + default=None, + help="Huggingface hub path to store the quantized model and tokenizer.", + ) args = parser.parse_args() # Convert precision argument to torch dtype precision_dtype = getattr(torch, args.precision, torch.bfloat16) - ppl = wikitext2_ppl( + result = quantize_and_eval( args.repo, args.quant, args.tasks, - args.calibration_samples, - args.validation_size, + args.max_seq_length, + args.calibration_limit, + args.evaluation_limit, args.device, args.precision, - args.seq_len, args.compile, args.model_save_path, + args.model_save_hf_hub_path, ) - print(f"{args.quant} Results: {ppl}") + print(f"{args.quant} Results: {result}") diff --git a/torchao/prototype/blockwise_fp8/README.md b/torchao/prototype/blockwise_fp8_inference/README.md similarity index 100% rename from torchao/prototype/blockwise_fp8/README.md rename to torchao/prototype/blockwise_fp8_inference/README.md diff --git a/torchao/prototype/blockwise_fp8/__init__.py b/torchao/prototype/blockwise_fp8_inference/__init__.py similarity index 100% rename from torchao/prototype/blockwise_fp8/__init__.py rename to torchao/prototype/blockwise_fp8_inference/__init__.py diff --git a/torchao/prototype/blockwise_fp8/blockwise_linear.py b/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py similarity index 96% rename from torchao/prototype/blockwise_fp8/blockwise_linear.py rename to torchao/prototype/blockwise_fp8_inference/blockwise_linear.py index c25b946732..ebed3a84a4 100644 --- a/torchao/prototype/blockwise_fp8/blockwise_linear.py +++ b/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py @@ -7,7 +7,7 @@ import torch from torch import nn -from torchao.prototype.blockwise_fp8.blockwise_quantization import ( +from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, ) diff --git a/torchao/prototype/blockwise_fp8/blockwise_quantization.py b/torchao/prototype/blockwise_fp8_inference/blockwise_quantization.py similarity index 100% rename from torchao/prototype/blockwise_fp8/blockwise_quantization.py rename to torchao/prototype/blockwise_fp8_inference/blockwise_quantization.py diff --git a/torchao/prototype/blockwise_fp8_training/__init__.py b/torchao/prototype/blockwise_fp8_training/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/blockwise_fp8_training/kernels.py b/torchao/prototype/blockwise_fp8_training/kernels.py new file mode 100644 index 0000000000..3f82407d40 --- /dev/null +++ b/torchao/prototype/blockwise_fp8_training/kernels.py @@ -0,0 +1,901 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +from torchao.prototype.moe_training.utils import ( + _is_column_major, + _is_row_major, +) + +fp8_gemm_configs_max_autotune = [ + triton.Config( + {"BLOCK_SIZE_M": block_size, "BLOCK_SIZE_N": block_size}, + num_warps=num_warps, + num_stages=num_stages, + ) + for block_size in [64, 128, 256] + for num_warps in [4, 8] + for num_stages in [2] +] + +EPS = 1e-12 + + +@triton.autotune(configs=fp8_gemm_configs_max_autotune, key=["N", "K", "BLOCK_SIZE_K"]) +@triton.jit +def triton_fp8_gemm_1x128_128x128_kernel( + a_ptr, # (M, K) + a_stride_dim_0, + a_stride_dim_1, + b_ptr, # (K, N) + b_stride_dim_0, + b_stride_dim_1, + c_ptr, + c_stride_dim_0, + c_stride_dim_1, + a_s_ptr, # (M, K // block_size) reciprocals of scales + a_s_stride_dim_0, + a_s_stride_dim_1, + b_s_ptr, # (K // block_size, N // block_size) reciprocals of scales + b_s_stride_dim_0, + b_s_stride_dim_1, + M, + N: tl.constexpr, + K: tl.constexpr, + out_dtype: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + ( + offs_m[:, None] * a_stride_dim_0 + offs_k[None, :] * a_stride_dim_1 + ) + b_ptrs = b_ptr + ( + offs_k[:, None] * b_stride_dim_0 + offs_n[None, :] * b_stride_dim_1 + ) + + k_num_blocks = tl.cdiv(K, BLOCK_SIZE_K) + + # Scale base pointers start at row offsets for A, and column offsets for B. + a_s_base_ptr = a_s_ptr + offs_m * a_s_stride_dim_0 + b_s_base_ptr = b_s_ptr + (offs_n // BLOCK_SIZE_K) * b_s_stride_dim_1 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) + b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N) + for k in range(0, k_num_blocks): + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + # Reciprocal scales to scale back to dynamic range of output dtype + a_s = tl.load(a_s_base_ptr + k * a_s_stride_dim_1) + b_s = tl.load(b_s_base_ptr + k * b_s_stride_dim_0) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s + + a_ptrs += BLOCK_SIZE_K * a_stride_dim_1 + b_ptrs += BLOCK_SIZE_K * b_stride_dim_0 + + c = accumulator.to(c_ptr.dtype.element_ty) + c_ptrs = c_ptr + offs_m[:, None] * c_stride_dim_0 + offs_n[None, :] * c_stride_dim_1 + c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def triton_fp8_gemm_1x128_128x128( + a: torch.Tensor, # (M, K) + b: torch.Tensor, # (K, N) + a_s: torch.Tensor, # (M, K // block_size) + b_s: torch.Tensor, # (K // block_size, N // block_size) + block_size: int = 128, + out_dtype: torch.dtype = torch.float32, +): + # 'a' must be in row-major layout, 'b' must be in column-major layout + assert _is_row_major(a), "a must be row-major" + assert _is_column_major(b), "b must be column-major" + + # a_scales must be col-major, b_scales must be row-major + assert _is_column_major(a_s), "a_s must be column-major" + assert _is_column_major(b_s), "b_s must be column-major" + + M = a.size(0) + K = a.size(1) + N = b.size(1) + c = a.new_empty(M, N, dtype=out_dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + wrap_triton(triton_fp8_gemm_1x128_128x128_kernel)[grid]( + a, + a.stride(0), + a.stride(1), + b, + b.stride(0), + b.stride(1), + c, + c.stride(0), + c.stride(1), + a_s, + a_s.stride(0), + a_s.stride(1), + b_s, + b_s.stride(0), + b_s.stride(1), + M, + N, + K, + out_dtype=out_dtype, + BLOCK_SIZE_K=block_size, + ) + return c + + +@triton.autotune( + configs=fp8_gemm_configs_max_autotune, key=["M", "N", "K", "BLOCK_SIZE_K"] +) +@triton.jit +def triton_fp8_gemm_1x128_128x1_kernel( + a_ptr, # (M, K) + a_stride_dim_0, + a_stride_dim_1, + b_ptr, # (K, N) + b_stride_dim_0, + b_stride_dim_1, + c_ptr, + a_s_ptr, # (M, K // block_size) + a_s_stride_dim_0, + a_s_stride_dim_1, + b_s_ptr, # (K // block_size, N) + b_s_stride_dim_0, + b_s_stride_dim_1, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + ( + offs_m[:, None] * a_stride_dim_0 + offs_k[None, :] * a_stride_dim_1 + ) + b_ptrs = b_ptr + ( + offs_k[:, None] * b_stride_dim_0 + offs_n[None, :] * b_stride_dim_1 + ) + + k_num_blocks = tl.cdiv(K, BLOCK_SIZE_K) + a_s_base_ptr = a_s_ptr + offs_m * a_s_stride_dim_0 + b_s_base_ptr = b_s_ptr + offs_n * b_s_stride_dim_1 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, k_num_blocks): + a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + # Reciprocal scales to scale back to dynamic range of output dtype + a_s = tl.load(a_s_base_ptr + k * a_s_stride_dim_1) + b_s = tl.load(b_s_base_ptr + k * b_s_stride_dim_0) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + + a_ptrs += BLOCK_SIZE_K * a_stride_dim_1 + b_ptrs += BLOCK_SIZE_K * b_stride_dim_0 + + c = accumulator.to(c_ptr.dtype.element_ty) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def triton_fp8_gemm_1x128_128x1( + a: torch.Tensor, # (M, K) + b: torch.Tensor, # (K, N) + a_s: torch.Tensor, # (M, K // block_size) reciprocals of scales + b_s: torch.Tensor, # (K // block_size, N) reciprocals of scales + block_size: int = 128, + out_dtype: torch.dtype = torch.float32, +): + # 'a' must be in row-major layout, 'b' must be in column-major layout + assert _is_row_major(a), "a must be row-major" + assert _is_column_major(b), "b must be column-major" + + # a_scales must be col-major, b_scales must be row-major + assert _is_column_major(a_s), "a_s must be column-major" + assert _is_row_major(b_s), "b_s must be row-major" + + M = a.size(0) + K = a.size(1) + N = b.size(1) + c = a.new_empty(M, N, dtype=out_dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + wrap_triton(triton_fp8_gemm_1x128_128x1_kernel)[grid]( + a, + a.stride(0), + a.stride(1), + b, + b.stride(0), + b.stride(1), + c, + a_s, + a_s.stride(0), + a_s.stride(1), + b_s, + b_s.stride(0), + b_s.stride(1), + M, + N, + K, + BLOCK_SIZE_K=block_size, + ) + return c + + +# Quantization kernels autotuner configs +quant_kernel_configs = [ + triton.Config( + {}, + num_warps=warps, + num_stages=stages, + ) + for warps in [4, 8] + for stages in [2, 4] +] + +quant_kernel_configs_with_groups = [ + triton.Config( + {"NUM_GROUPS": groups}, + num_warps=warps, + num_stages=stages, + ) + for groups in [2, 16, 32, 64, 128] + for warps in [2, 4, 8] + for stages in [2, 4, 6] +] + + +@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"]) +@triton.jit +def triton_fp8_blockwise_act_quant_lhs_kernel( + x_ptr, + x_stride_dim_0, + x_stride_dim_1, + y_ptr, + y_stride_dim_0, + y_stride_dim_1, + s_ptr, + s_stride_dim_0, + s_stride_dim_1, + M, + K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, + EPS: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + # Load (num_groups x block_size) tile of x, where input is row major + m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS) + k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 + x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) + x = tl.load(x_ptr + x_offs, mask=x_mask) + + # Perform scaling + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + + # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1) + amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64) + scale = (max_fp8_e4m3 / amax).to(tl.float32)[:, None] + y = x * scale + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + + # Write output to column major fomrat + y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1 + y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) + tl.store(y_ptr + y_offs, y, mask=y_mask) + + # Write reciprocal scales + scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1 + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale)) + + +@triton_op("torchao::triton_fp8_blockwise_act_quant_lhs", mutates_args={}) +def triton_fp8_blockwise_act_quant_lhs( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Input: row-major high-precision tensor + Output: row-major, with reciprocal scales for (1 x block_size) groups stored in col-major. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + ], "dtype must be torch.float8_e4m3fn" + M, K = x.size() + y = torch.empty_like(x, dtype=dtype) + # Write scales to column-major format to align with torch._scaled_mm requirements. + s = x.new_empty(M, K // block_size, dtype=torch.float32).as_strided( + (M, K // block_size), + (1, M), + ) + grid = lambda meta: ( + triton.cdiv(M, meta["NUM_GROUPS"]), + triton.cdiv(K, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_act_quant_lhs_kernel)[grid]( + x, + x.stride(0), + x.stride(1), + y, + y.stride(0), + y.stride(1), + s, + s.stride(0), + s.stride(1), + M, + K=K, + BLOCK_SIZE=block_size, + EPS=EPS, + ) + return y, s + + +@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"]) +@triton.jit +def triton_fp8_blockwise_act_quant_rhs_kernel( + x_ptr, + x_stride_dim_0, + x_stride_dim_1, + y_ptr, + y_stride_dim_0, + y_stride_dim_1, + s_ptr, + s_stride_dim_0, + s_stride_dim_1, + M, + K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, + EPS: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + # Load (block_size x block_size) tile of x, where input is row major. + # Each scaling group is (block_size x 1), but we load (block_size x block_size) + # to facilitate coalesced gmem accesses and improve efficiency. + m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS) + x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 + x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) + x = tl.load(x_ptr + x_offs, mask=x_mask) + + # Perform scaling + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + + # Column-wise scales for RHS operand, shape (1, block_size) + amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64) + scale = (max_fp8_e4m3 / amax).to(tl.float32)[None, :] + y = x * scale + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + + # Write output to column major format + y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1 + y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) + tl.store(y_ptr + y_offs, y, mask=y_mask) + + # Write scales + scale_offs = pid_m * s_stride_dim_0 + k_offs[None, :] * s_stride_dim_1 + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale)) + + +@triton_op("torchao::triton_fp8_blockwise_act_quant_rhs", mutates_args={}) +def triton_fp8_blockwise_act_quant_rhs( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Input: row-major + Output: column-major, with scales for (block_size x 1) groups stored in row-major. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + ], "dtype must be torch.float8_e4m3fn" + M, K = x.size() + M_blocks = triton.cdiv(M, block_size) + y = torch.empty_like(x, dtype=dtype) + y = y.as_strided(y.size(), (1, y.size(0))) + s = x.new_empty(M_blocks, K, dtype=torch.float32) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(K, meta["NUM_GROUPS"]), + ) + wrap_triton(triton_fp8_blockwise_act_quant_rhs_kernel)[grid]( + x, + x.stride(0), + x.stride(1), + y, + y.stride(0), + y.stride(1), + s, + s.stride(0), + s.stride(1), + M=M, + K=K, + BLOCK_SIZE=block_size, + EPS=EPS, + ) + return y, s + + +@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"]) +@triton.jit +def triton_fp8_blockwise_act_quant_transposed_lhs_kernel( + x_ptr, + x_stride_dim_0, + x_stride_dim_1, + y_ptr, + y_stride_dim_0, + y_stride_dim_1, + s_ptr, + s_stride_dim_0, + s_stride_dim_1, + M, + K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # For scaling groups, not for grid/parallelization + NUM_GROUPS: tl.constexpr, # For grid/parallelization, not for scaling groups + EPS: tl.constexpr, +): + # This kernel reads data in row-major format, and writes to an output tensor with + # transposed dims and in column major format. To facilitate this, given that for a + # LHS operator the scales must be rowwise, we will be computing colwise scales on the + # original data, then writing the scaled data rowwise. + pid_m = tl.program_id(axis=0) + pid_k = tl.program_id(axis=1) + + # Load (block_size x num_groups) block of input, where input is row major. + # We will be computing (block_size x 1) scaling factors (columns), and computing + # `num_groups` at a time, so we aren't parallelizing with 1 thread per column, + # which will fail to launch for large tensors, due to max block number of 65535. + m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS) + x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1 + x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) + x = tl.load(x_ptr + x_offs, mask=x_mask) + + # Perform scaling + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + + # Compute amax across dim 0 (column-wise). + amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64) + scale = (max_fp8_e4m3 / amax).to(tl.float32) + y = x * scale + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + + # Write output to column major fomrat + y_offs = k_offs[:, None] * y_stride_dim_0 + m_offs[None, :] * y_stride_dim_1 + y_mask = (k_offs[:, None] < K) & (m_offs[None, :] < M) + tl.store(y_ptr + y_offs, y.trans(1, 0), mask=y_mask) + + # Scales are one per column (block_size x 1). + scale_m_off = pid_m + scale_k_offs = k_offs + + # Scale tensor size is (K, M // SCALE_BLOCK_SIZE) + scale_offs = scale_k_offs * s_stride_dim_0 + scale_m_off * s_stride_dim_1 + scale_mask = (scale_k_offs < K) & (scale_m_off < M // BLOCK_SIZE) + + # Write out reciprocal scales + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask) + + +@triton_op("torchao::triton_fp8_blockwise_act_quant_transposed_lhs", mutates_args={}) +def triton_fp8_blockwise_act_quant_transposed_lhs( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(0) % block_size == 0, ( + f"First dimension size must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + ], "dtype must be torch.float8_e4m3fn" + + # Output should have transposed dims and be in row major format + M, K = x.shape + y = torch.empty(K, M, dtype=dtype, device=x.device) + M_blocks = triton.cdiv(M, block_size) + + # Column major scales required for torch._scaled_mm + s = x.new_empty(K, M_blocks, dtype=torch.float32).as_strided( + (K, M_blocks), # shape + (1, K), # stride + ) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(K, meta["NUM_GROUPS"]), + ) + + wrap_triton(triton_fp8_blockwise_act_quant_transposed_lhs_kernel)[grid]( + x, + x.stride(0), + x.stride(1), + y, + y.stride(0), + y.stride(1), + s, + s.stride(0), + s.stride(1), + M, + K=K, + BLOCK_SIZE=block_size, # Scaling group size + EPS=EPS, + ) + return y, s + + +@triton.autotune(configs=quant_kernel_configs, key=["M", "N"]) +@triton.jit +def triton_fp8_blockwise_weight_quant_rhs_kernel( + x_ptr, + x_stride_dim_0, + x_stride_dim_1, + y_ptr, + y_stride_dim_0, + y_stride_dim_1, + s_ptr, + s_stride_dim_0, + s_stride_dim_1, + M: tl.constexpr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + EPS: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + # Load (block_size x block_size) block of x, where input is row major + x_offs = offs_m[:, None] * x_stride_dim_0 + offs_n[None, :] * x_stride_dim_1 + x_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + x_offs, mask=x_mask) + + # Scale the data + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64) + scale = (max_fp8_e4m3 / amax).to(tl.float32) + y = x * scale + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + + # Store output in column major format + y_offs = offs_m[:, None] * y_stride_dim_0 + offs_n[None, :] * y_stride_dim_1 + y_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(y_ptr + y_offs, y, mask=y_mask) + + # Write reciprocal scale (scalar value) + scale_m_off = pid_m * s_stride_dim_0 + scale_n_off = pid_n * s_stride_dim_1 + tl.store(s_ptr + scale_m_off + scale_n_off, tl.div_rn(1.0, scale)) + + +@triton_op("torchao::triton_fp8_blockwise_weight_quant_rhs", mutates_args={}) +def triton_fp8_blockwise_weight_quant_rhs( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dim() == 2, "Input tensor must have 2 dimensions" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + f"Both dimensions of x must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + ], "dtype must be torch.float8_e4m3fn" + M, N = x.size() + y = torch.empty_like(x, dtype=dtype) + y = y.as_strided(y.size(), (1, y.size(0))) # Column major + M_blocks, N_blocks = triton.cdiv(M, block_size), triton.cdiv(N, block_size) + s = x.new_empty(M_blocks, N_blocks, dtype=torch.float32).as_strided( + (M_blocks, N_blocks), # shape + (1, M_blocks), # stride + ) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_weight_quant_rhs_kernel)[grid]( + x, + x.stride(0), + x.stride(1), + y, + y.stride(0), + y.stride(1), + s, + s.stride(0), + s.stride(1), + M, + N, + BLOCK_SIZE=block_size, + EPS=EPS, + ) + return y, s + + +@triton.autotune(configs=quant_kernel_configs, key=["M", "N"]) +@triton.jit +def triton_fp8_blockwise_weight_quant_transposed_rhs_kernel( + x_ptr, + x_stride_dim_0, + x_stride_dim_1, + y_ptr, + y_stride_dim_0, + y_stride_dim_1, + s_ptr, + s_stride_dim_0, + s_stride_dim_1, + M: tl.constexpr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + EPS: tl.constexpr, +): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factors in `s_ptr`. + + Writes output with transposed dims in column-major format. + + Args: + x_ptr (tl.pointer): Pointer to the input tensor. + y_ptr (tl.pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (tl.pointer): Pointer to the output tensor where scaling factors will be stored. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + # Load (block_size x block_size) block of input, where input is row major + m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + n_offs = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x_offs = m_offs[:, None] * x_stride_dim_0 + n_offs[None, :] * x_stride_dim_1 + x_mask = (m_offs[:, None] < M) & (n_offs[None, :] < N) + x = tl.load(x_ptr + x_offs, mask=x_mask).to(tl.float32) + + # Perform scaling + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64) + scale = (max_fp8_e4m3 / amax).to(tl.float32) + y = x * scale + y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty) + + # Write output to column major fomrat + y_offs = n_offs[:, None] * y_stride_dim_0 + m_offs[None, :] * y_stride_dim_1 + y_mask = (n_offs[:, None] < N) & (m_offs[None, :] < M) + tl.store(y_ptr + y_offs, y.trans(1, 0), mask=y_mask) + + # Write reciprocal scales + scale_m = pid_m + scale_k = pid_n + scale_offs = scale_k[:, None] * s_stride_dim_0 + scale_m[None, :] * s_stride_dim_1 + scale_mask = (scale_k[:, None] < N // BLOCK_SIZE) & ( + scale_m[None, :] < M // BLOCK_SIZE + ) + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask) + + +@triton_op("torchao::triton_fp8_blockwise_weight_quant_transposed_rhs", mutates_args={}) +def triton_fp8_blockwise_weight_quant_transposed_rhs( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dim() == 2, "Input tensor must have 2 dimensions" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + f"Both dimensions of x must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + ], "dtype must be torch.float8_e4m3fn" + M, N = x.size() + y = torch.empty(N, M, dtype=dtype, device=x.device) + y = y.as_strided(y.size(), (1, y.size(0))) # Column major + n_blocks, m_blocks = triton.cdiv(N, block_size), triton.cdiv(M, block_size) + s = x.new_empty(n_blocks, m_blocks, dtype=torch.float32).as_strided( + (n_blocks, m_blocks), # shape + (1, n_blocks), # stride + ) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + wrap_triton(triton_fp8_blockwise_weight_quant_transposed_rhs_kernel)[grid]( + x, + x.stride(0), + x.stride(1), + y, + y.stride(0), + y.stride(1), + s, + s.stride(0), + s.stride(1), + M, + N, + BLOCK_SIZE=block_size, + EPS=EPS, + ) + return y, s + + +def torch_blockwise_scale_act_quant_lhs(x, tile_size=128): + """ + Input: weight tensor in high precision + Output: weight tensor in float8, and scale, tiled 1 by tile_size + """ + assert x.is_contiguous(), "input tensor must be contiguous" + orig_shape = x.shape + + # Reshape 2D+ input tensor into 2D tensor with shape (leading_dims, tile_size) + x = x.reshape(-1, tile_size) + + # Compute amax along last dim (i.e., the block) + x_amax = x.abs().max(dim=1, keepdim=True).values.to(torch.float64) + x_amax = torch.clamp(x_amax, min=EPS, max=float("inf")) + + # Convert amax to scale + fp8_dtype_max, fp8_dtype_min = ( + torch.finfo(torch.float8_e4m3fn).max, + torch.finfo(torch.float8_e4m3fn).min, + ) + s = (fp8_dtype_max / x_amax).to(torch.float32) + + # Apply scale and clamp + x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) + + # Reshape quantized output back to original shape and reshape scales accordingly + x = x.reshape(*orig_shape) + s = s.reshape(orig_shape[0], -1).to(torch.float) + + # Return output tensor and reciprocal scale + return x, 1.0 / s + + +def torch_blockwise_scale_act_quant_rhs( + x: torch.Tensor, + block_size: int = 128, + dtype: torch.dtype = torch.float8_e4m3fn, + eps: float = 1e-12, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [torch.float8_e4m3fn], "dtype must be torch.float8_e4m3fn" + + M, K = x.size() + max_fp8_e4m3 = 448.0 + min_fp8_e4m3 = -448.0 + + # Reshape input to work with blocks of size (block_size, 1) along dimension 0 + num_blocks_m = M // block_size + + # Reshape to (num_blocks_m, block_size, K) for block processing + x_blocks = x.view(num_blocks_m, block_size, K) + + # Initialize output tensors + y_blocks = torch.empty_like(x_blocks, dtype=dtype) + scales = torch.empty(num_blocks_m, K, dtype=torch.float32, device=x.device) + + # Process each column (K dimension) separately + for k in range(K): + # Extract column k from all blocks: shape (num_blocks_m, block_size) + x_col = x_blocks[:, :, k] # (num_blocks_m, block_size) + + # Compute absolute max for each block + amax = torch.abs(x_col).max(dim=1, keepdim=True)[0] # (num_blocks_m, 1) + + # Clamp to avoid division by zero + amax = torch.clamp(amax, min=eps).to(torch.float64) + + # Compute scales + scale = (max_fp8_e4m3 / amax).to(torch.float32) # (num_blocks_m, 1) + + # Apply scaling + y_col = x_col * scale # (num_blocks_m, block_size) + + # Clamp to FP8 range + y_col = torch.clamp(y_col, min=min_fp8_e4m3, max=max_fp8_e4m3) + + # Store results + y_blocks[:, :, k] = y_col.to(dtype) + scales[:, k] = scale.squeeze(-1) # (num_blocks_m,) + + # Reshape back to original shape (removing padding if any) + y = y_blocks.view(-1, K)[:M, :] # (M, K) + + # Convert to column-major format + y = y.t().contiguous().t() + + # Return output tensor and reciprocal scales + return y, 1.0 / scales + + +def torch_blockwise_scale_weight_quant(x, tile_size=128): + """ + Input: weight tensor in high precision + Output: weight tensor in float8, and scale, tiled tile_size by tile_size + """ + assert len(x.shape) == 2, "input shape must be 2D" + assert x.is_contiguous(), "input tensor must be contiguous" + height, width = x.shape + + # Compute block sizes + t_h = height // tile_size + t_w = width // tile_size + + # Reshape 2D input tensor into 4D tensor with shape (t_h, t_w, tile_size * tile_size) + x = x.reshape(t_h, tile_size, t_w, tile_size) + x = x.permute(0, 2, 1, 3) + x = x.reshape(-1, tile_size * tile_size) + + # Compute amax along last dim (i.e., the block) + x_amax = x.abs().max(dim=1).values.unsqueeze(1).to(torch.float64) + x_amax = torch.clamp(x_amax, min=EPS, max=float("inf")) + + # Convert amax to scale + fp8_dtype_max, fp8_dtype_min = ( + torch.finfo(torch.float8_e4m3fn).max, + torch.finfo(torch.float8_e4m3fn).min, + ) + s = (fp8_dtype_max / x_amax).to(torch.float32) + + # Apply scale and clamp + x = (x * s).clamp(min=fp8_dtype_min, max=fp8_dtype_max).to(torch.float8_e4m3fn) + + # Reshape quantized output and scales back to 2D + x = x.reshape(t_h, t_w, tile_size, tile_size) + x = x.permute(0, 2, 1, 3) + x = x.reshape(height, width) + s = s.reshape(t_h, t_w).to(torch.float) + + # Return output tensor and reciprocal scale + return x, 1.0 / s diff --git a/torchao/prototype/blockwise_fp8_training/linear.py b/torchao/prototype/blockwise_fp8_training/linear.py new file mode 100644 index 0000000000..95dc6762d0 --- /dev/null +++ b/torchao/prototype/blockwise_fp8_training/linear.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + +from torchao.core.config import AOBaseConfig +from torchao.prototype.blockwise_fp8_training.kernels import ( + triton_fp8_blockwise_act_quant_lhs, + triton_fp8_blockwise_act_quant_rhs, + triton_fp8_blockwise_act_quant_transposed_lhs, + triton_fp8_blockwise_weight_quant_rhs, + triton_fp8_blockwise_weight_quant_transposed_rhs, + triton_fp8_gemm_1x128_128x1, + triton_fp8_gemm_1x128_128x128, +) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) +from torchao.utils import is_sm_at_least_90 + + +class fp8_blockwise_mm(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, block_size, out_dtype=torch.bfloat16, use_triton=False): + assert block_size == 128, "Only support block_size=128" + + # Temporarily reshape x to 2D tensor + x_orig_shape = x.shape + x = x.reshape(-1, x_orig_shape[-1]) + + # Cast inputs to fp8 blockwise using (1, block_size) scaling granularity in row major format. + x_fp8, x_scale = triton_fp8_blockwise_act_quant_lhs(x, block_size) + + # Cast weight to fp8 blockwise using (block_size, block_size) scaling granularity, with transposed dims in column major format. + weight_t_fp8, weight_t_scale = triton_fp8_blockwise_weight_quant_transposed_rhs( + weight, + block_size=block_size, + ) + + # out = input @ weight.T + fp8_gemm = triton_fp8_gemm_1x128_128x128 if use_triton else torch._scaled_mm + out = fp8_gemm( + x_fp8, + weight_t_fp8, + x_scale, + weight_t_scale, + out_dtype=out_dtype, + ) + out = out.reshape(*x_orig_shape[:-1], out.shape[-1]) + ctx.save_for_backward(x, weight) + ctx.block_size = block_size + ctx.out_dtype = out_dtype + ctx.use_triton = use_triton + return out + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + block_size = ctx.block_size + out_dtype = ctx.out_dtype + use_triton = ctx.use_triton + + # Reshape input to 2D + x_orig_shape = x.shape + x = x.reshape(-1, x_orig_shape[-1]) + + # Reshape grad_output to 2D + grad_output_orig_shape = grad_output.shape + grad_output = grad_output.reshape(-1, grad_output_orig_shape[-1]).contiguous() + assert grad_output.shape[1] % 128 == 0, "unsupported" + + # Cast grad_output to fp8 blockwise 1x128 since it is the grad of the output activation. + grad_output_fp8, grad_output_scale = triton_fp8_blockwise_act_quant_lhs( + grad_output, + block_size, + ) + + # Cast weight to fp8 blockwise to 128x128 in column major format. + weight_fp8, weight_scale = triton_fp8_blockwise_weight_quant_rhs( + weight, + block_size=block_size, + ) + + # grad_x = grad_output @ weight + fp8_gemm_1x128_128x128 = ( + triton_fp8_gemm_1x128_128x128 if use_triton else torch._scaled_mm + ) + grad_x = fp8_gemm_1x128_128x128( + grad_output_fp8, + weight_fp8, + grad_output_scale, + weight_scale, + out_dtype=out_dtype, + ) + + # Cast grad_output_t to fp8 blockwise with (1 x block_size) scaling groups, since it is + # the grad of the output activation. + # Write directly with transposed dims in row major format, as needed for dW calc. + grad_output_t_fp8, grad_output_t_scale = ( + triton_fp8_blockwise_act_quant_transposed_lhs( + grad_output, + block_size, + ) + ) + + # Cast x to fp8 blockwise with (block_size x 1) scaling groups, in column major format. + # RHS should have groupwise scales calculated colwise, so scaling groups do not cross the + # contracting (K) dim. + x_fp8, x_scale = triton_fp8_blockwise_act_quant_rhs(x, block_size) + + # grad_weight = grad_output.T @ x + fp8_gemm_1x128_128x1 = ( + triton_fp8_gemm_1x128_128x1 if use_triton else torch._scaled_mm + ) + grad_weight = fp8_gemm_1x128_128x1( + grad_output_t_fp8, + x_fp8, + grad_output_t_scale, + x_scale, + out_dtype=out_dtype, + ) + + # Reshape grad_x to expected potentially 3D+ shape + grad_x = grad_x.reshape(*grad_output_orig_shape[:-1], grad_x.shape[-1]) + return grad_x, grad_weight, None, None, None + + +class Float8BlockwiseLinear(nn.Linear): + """ + Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + block_size (int): Block size for quantization. Defaults to 128. + dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn. + """ + + supported_dtypes = [ + torch.bfloat16, + ] + + def __init__( + self, + *args, + block_size: int = 128, + dtype=torch.bfloat16, + use_triton=False, + **kwargs, + ): + super().__init__(*args, **kwargs) + + assert dtype in self.supported_dtypes, ( + f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}" + ) + assert is_sm_at_least_90(), "Only support SM90" + self.block_size = block_size + self.dtype = dtype + self.use_triton = use_triton + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. + """ + return fp8_blockwise_mm.apply( + x, self.weight, self.block_size, self.dtype, self.use_triton + ) + + @classmethod + def from_float( + cls, + mod, + ): + assert mod.bias is None, "unsupported" + assert mod.in_features % 128 == 0, "unsupported" + assert mod.out_features % 128 == 0, "unsupported" + with torch.device("meta"): + new_mod = cls( + mod.in_features, + mod.out_features, + bias=False, + ) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + return new_mod + + +class Float8BlockwiseLinearConfig(AOBaseConfig): + pass + + +@register_quantize_module_handler(Float8BlockwiseLinearConfig) +def _float8_blockwise_transform(module, config): + return Float8BlockwiseLinear.from_float(module) diff --git a/torchao/prototype/float8nocompile/examples/example.py b/torchao/prototype/float8nocompile/examples/example.py index 97d42eee90..1351e2c938 100644 --- a/torchao/prototype/float8nocompile/examples/example.py +++ b/torchao/prototype/float8nocompile/examples/example.py @@ -9,10 +9,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") # create model and sample input m = ( diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear.py b/torchao/prototype/float8nocompile/float8nocompile_linear.py index 7e0eb85022..b7ee306066 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear.py @@ -11,7 +11,11 @@ import torch from torchao.float8.config import Float8LinearConfig -from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig +from torchao.float8.float8_training_tensor import ( + GemmInputRole, + LinearMMConfig, + ScaledMMConfig, +) from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import ( ToFP8ColumnMajor, ToFP8ColumnMajorT, diff --git a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py index 7b6a25e3f9..1e55c0c2e9 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py +++ b/torchao/prototype/float8nocompile/float8nocompile_scaling_utils.py @@ -10,7 +10,7 @@ import torch -from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig +from torchao.float8.float8_training_tensor import GemmInputRole, LinearMMConfig from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( KernelAlgorithm, hp_to_fp8_col_major, diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py index 3786b52eb5..37c7611980 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py @@ -14,7 +14,11 @@ import triton import triton.language as tl -from torchao.float8.float8_tensor import Float8Tensor, GemmInputRole, LinearMMConfig +from torchao.float8.float8_training_tensor import ( + Float8TrainingTensor, + GemmInputRole, + LinearMMConfig, +) EPS = 1e-12 @@ -487,7 +491,7 @@ def _scale_atomic( tl.float32 ) - # store scale for use in Float8Tensor constructor + # store scale for use in Float8TrainingTensor constructor scale_off = tl.arange(0, 1) tl.store(scale_out_ptr + scale_off, scale) @@ -541,7 +545,7 @@ def hp_to_fp8_row_major( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -576,8 +580,8 @@ def hp_to_fp8_row_major( EPS=EPS, ) - # wrap output tensor in Float8Tensor - fp8_tensor_row_major = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_row_major = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -593,7 +597,7 @@ def hp_to_fp8_row_major_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -641,8 +645,8 @@ def hp_to_fp8_row_major_t( EPS=EPS, ) - # wrap output tensor in Float8Tensor - fp8_tensor_row_major_t = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_row_major_t = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -658,7 +662,7 @@ def hp_to_fp8_col_major( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -705,8 +709,8 @@ def hp_to_fp8_col_major( col_major_strides = (1, num_rows) output_buffer = output_buffer.as_strided(output_buffer.size(), col_major_strides) - # wrap output tensor in Float8Tensor - fp8_tensor_col_major = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_col_major = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -722,7 +726,7 @@ def hp_to_fp8_col_major_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() @@ -757,8 +761,8 @@ def hp_to_fp8_col_major_t( EPS=EPS, ) - # wrap output tensor in Float8Tensor - fp8_tensor_col_major_t = Float8Tensor( + # wrap output tensor in Float8TrainingTensor + fp8_tensor_col_major_t = Float8TrainingTensor( output_buffer, scale, orig_dtype=hp_tensor.dtype, @@ -774,7 +778,7 @@ def hp_to_fp8_row_and_col_major( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -830,15 +834,15 @@ def hp_to_fp8_row_and_col_major( fp8_output_col_major.size(), col_major_strides ) - # wrap outputs in Float8Tensors - fp8_tensor_row_major = Float8Tensor( + # wrap outputs in Float8TrainingTensors + fp8_tensor_row_major = Float8TrainingTensor( fp8_output_row_major, scale, orig_dtype=hp_tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, ) - fp8_tensor_col_major = Float8Tensor( + fp8_tensor_col_major = Float8TrainingTensor( fp8_output_col_major, scale, orig_dtype=hp_tensor.dtype, @@ -854,7 +858,7 @@ def hp_to_fp8_row_major_t_and_non_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -912,15 +916,15 @@ def hp_to_fp8_row_major_t_and_non_t( EPS=EPS, ) - # wrap outputs in Float8Tensors - fp8_tensor_row_major = Float8Tensor( + # wrap outputs in Float8TrainingTensors + fp8_tensor_row_major = Float8TrainingTensor( fp8_output_row_major, scale, orig_dtype=hp_tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, ) - fp8_tensor_row_major_t = Float8Tensor( + fp8_tensor_row_major_t = Float8TrainingTensor( fp8_output_row_major_t, scale, orig_dtype=hp_tensor.dtype, @@ -936,7 +940,7 @@ def hp_to_fp8_col_major_t_and_non_t( linear_mm_config: LinearMMConfig, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, -) -> Float8Tensor: +) -> Float8TrainingTensor: assert hp_tensor.is_contiguous(), "input tensor must be contiguous" tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -999,15 +1003,15 @@ def hp_to_fp8_col_major_t_and_non_t( fp8_output_col_major.size(), col_major_strides ) - # wrap outputs in Float8Tensors - fp8_tensor_col_major = Float8Tensor( + # wrap outputs in Float8TrainingTensors + fp8_tensor_col_major = Float8TrainingTensor( fp8_output_col_major, scale, orig_dtype=hp_tensor.dtype, linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, ) - fp8_tensor_col_major_t = Float8Tensor( + fp8_tensor_col_major_t = Float8TrainingTensor( fp8_output_col_major_t, scale, orig_dtype=hp_tensor.dtype, diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py index 2348877d5c..0d7a20fae7 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py @@ -7,7 +7,7 @@ import torch from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.float8_tensor import LinearMMConfig +from torchao.float8.float8_training_tensor import LinearMMConfig from torchao.float8.float8_utils import is_row_major from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( KernelAlgorithm, diff --git a/torchao/prototype/float8nocompile/test/fsdp_test.py b/torchao/prototype/float8nocompile/test/fsdp_test.py index 4e73fb9b97..375e48311d 100644 --- a/torchao/prototype/float8nocompile/test/fsdp_test.py +++ b/torchao/prototype/float8nocompile/test/fsdp_test.py @@ -22,10 +22,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") class TestModel(nn.Module): diff --git a/torchao/prototype/float8nocompile/test/train_test.py b/torchao/prototype/float8nocompile/test/train_test.py index 3f2ee47cd7..aceca5b400 100644 --- a/torchao/prototype/float8nocompile/test/train_test.py +++ b/torchao/prototype/float8nocompile/test/train_test.py @@ -11,10 +11,6 @@ from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( convert_to_float8_nocompile_training, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") class TestModel(nn.Module): diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py index 46fae4bfe9..cda96f6b3c 100644 --- a/torchao/prototype/hqq/example.py +++ b/torchao/prototype/hqq/example.py @@ -108,15 +108,15 @@ print("Quant API example") print("-------------------------------------------------------------------") -from torchao.quantization.quant_api import int4_weight_only +from torchao.quantization.quant_api import Int4WeightOnlyConfig nbits = 4 target_dtype = torch.int32 inner_k_tiles = 8 _layout = TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles) -int4_weight_only_patch_fct = int4_weight_only( - group_size=group_size, inner_k_tiles=inner_k_tiles +int4_weight_only_patch_fct = Int4WeightOnlyConfig( + group_size=group_size, inner_k_tiles=inner_k_tiles, version=1 ) linear_layer_default = torch.nn.Linear( in_features, out_features, bias=False, device=device diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index f15c9a8104..8f049b431b 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -17,7 +17,7 @@ from torch import Tensor, nn from torchao.dtypes.utils import is_device -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, check_cpu_version +from torchao.utils import check_cpu_version class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -209,9 +209,8 @@ def hqq_quants_to_torch_quants( .reshape(shape) .contiguous() ) - if TORCH_VERSION_AT_LEAST_2_5: - if not is_device(W_q.device.type, "cpu"): - W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) + if not is_device(W_q.device.type, "cpu"): + W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) # group_dequantize_tensor_from_qparams # W_r = W_q*scales + min_val diff --git a/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py b/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py index 1f8865356a..145516ddf3 100644 --- a/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py +++ b/torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py @@ -13,7 +13,7 @@ from .utils import expand USEFUL_FUNCTIONS = r""" -inline float {{kernel_name}}_calculate_scale( +inline float calculate_scale( int64_t headSize, std::optional scale) { return scale.has_value() @@ -22,7 +22,7 @@ } template -inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) { +inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { const int32_t vec_size = at::vec::Vectorized::size(); auto data_vec = at::vec::Vectorized(val); int64_t d = 0; @@ -35,13 +35,13 @@ } template -inline void {{kernel_name}}_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { +inline void store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { src.store(dst, size); } template inline typename std::enable_if_t || std::is_same_v, void> -{{kernel_name}}_store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { +store(scalar_t* dst, at::vec::Vectorized src, int size=at::vec::Vectorized::size()) { auto res = at::vec::convert(src); res.store(dst, size); } @@ -52,7 +52,7 @@ 3. max reduce for softmax */ template -inline void {{kernel_name}}_dequant_mask_max_fusion_kernel( +inline void dequant_mask_max_fusion_kernel( const int32_t* in, const mask_t* mask_ptr, const int32_t* sum_a_ptr, @@ -90,7 +90,7 @@ auto tmp7 = at::vec::convert(tmp6); auto tmp8 = tmp5 + tmp7; vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp8); - {{kernel_name}}_store(tmp_out + col, tmp8); + store(tmp_out + col, tmp8); } if (col < N) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); @@ -103,7 +103,7 @@ auto tmp6 = at::vec::Vectorized::loadu(mask_data_ptr + col, N - col); auto tmp7 = at::vec::convert(tmp6); auto tmp8 = tmp5 + tmp7; - {{kernel_name}}_store(tmp_out + col, tmp8, N - col); + store(tmp_out + col, tmp8, N - col); vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp8), N - col); } sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); @@ -114,7 +114,7 @@ 1. dequant 2. max reduce for softmax */ -inline void {{kernel_name}}_dequant_max_fusion_kernel( +inline void dequant_max_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -146,7 +146,7 @@ auto tmp4 = at::vec::convert(tmp3); auto tmp5 = tmp4 * vec_alpha; vec_tmp_max = at::vec::clamp_min(vec_tmp_max, tmp5); - {{kernel_name}}_store(tmp_out + col, tmp5); + store(tmp_out + col, tmp5); } if (col < N) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); @@ -156,7 +156,7 @@ auto tmp3 = tmp2 + vec_beta; auto tmp4 = at::vec::convert(tmp3); auto tmp5 = tmp4 * vec_alpha; - {{kernel_name}}_store(tmp_out + col, tmp5, N - col); + store(tmp_out + col, tmp5, N - col); vec_tmp_max = at::vec::Vectorized::set(vec_tmp_max, at::vec::clamp_min(vec_tmp_max, tmp5), N - col); } sfm_max_ptr[row] = std::max(sfm_max_ptr[row], vec_tmp_max.reduce_max()); @@ -169,7 +169,7 @@ 3. sum for attention */ template -inline void {{kernel_name}}_sub_exp_sum_div_quant_sum_fusion_kernel( +inline void sub_exp_sum_div_quant_sum_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -214,13 +214,13 @@ auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum += tmp2; - {{kernel_name}}_store(tmp_out + col, tmp2); + store(tmp_out + col, tmp2); } if (col < kvBlockSize) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); - {{kernel_name}}_store(tmp_out + col, tmp2, kvBlockSize - col); + store(tmp_out + col, tmp2, kvBlockSize - col); vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); } sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); @@ -243,7 +243,7 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4); + store(tmp_out + col, tmp4); auto tmp6 = at::vec::convert(tmp4); vec_tmp_sum += tmp6; } @@ -253,7 +253,7 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4, kvBlockSize - col); + store(tmp_out + col, tmp4, kvBlockSize - col); auto tmp6 = at::vec::convert(tmp4); vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp6, kvBlockSize - col); } @@ -261,10 +261,10 @@ // set zero col = kvBlockSize; for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - {{kernel_name}}_store(tmp_out + col, vec_zero); + store(tmp_out + col, vec_zero); } if (col < av_gemm_K) { - {{kernel_name}}_store(tmp_out + col, vec_zero, av_gemm_K - col); + store(tmp_out + col, vec_zero, av_gemm_K - col); } } } @@ -275,7 +275,7 @@ 2. quant */ template -inline void {{kernel_name}}_sub_exp_sum_div_quant_fusion_kernel( +inline void sub_exp_sum_div_quant_fusion_kernel( const float* in, const int64_t& M, const int64_t& N_step, @@ -318,14 +318,14 @@ auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum += tmp2; - {{kernel_name}}_store(tmp_out + col, tmp2); + store(tmp_out + col, tmp2); } if (col < kvBlockSize) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); auto tmp1 = tmp0 - vec_max; auto tmp2 = tmp1.exp_u20(); vec_tmp_sum = at::vec::Vectorized::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); - {{kernel_name}}_store(tmp_out + col, tmp2, kvBlockSize - col); + store(tmp_out + col, tmp2, kvBlockSize - col); } sfm_sum_ptr[row] += vec_tmp_sum.reduce_add(); } @@ -345,7 +345,7 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4); + store(tmp_out + col, tmp4); } if (col < kvBlockSize) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + col, kvBlockSize - col); @@ -353,15 +353,15 @@ auto tmp2 = tmp1.round(); auto tmp3 = tmp2 + vec_beta1; auto tmp4 = at::vec::clamp(tmp3, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp4, kvBlockSize - col); + store(tmp_out + col, tmp4, kvBlockSize - col); } // set zero col = kvBlockSize; for (; col < vec_size * (av_gemm_K / vec_size); col += vec_size) { - {{kernel_name}}_store(tmp_out + col, vec_zero); + store(tmp_out + col, vec_zero); } if (col < av_gemm_K) { - {{kernel_name}}_store(tmp_out + col, vec_zero, av_gemm_K - col); + store(tmp_out + col, vec_zero, av_gemm_K - col); } } } @@ -372,7 +372,7 @@ 2. quant */ template -inline void {{kernel_name}}_dequant_quant_fusion_kernel( +inline void dequant_quant_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int32_t* sum_b_ptr, @@ -410,7 +410,7 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8); + store(tmp_out + col, tmp8); } if (col < N) { auto vec_sum_b = at::vec::Vectorized::loadu(sum_b_ptr + col, N - col); @@ -423,7 +423,7 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8, N - col); + store(tmp_out + col, tmp8, N - col); } } } @@ -433,7 +433,7 @@ 2. quant */ template -inline void {{kernel_name}}_dequant_quant_fusion_kernel( +inline void dequant_quant_fusion_kernel( const int32_t* in, const int32_t* sum_a_ptr, const int& M, @@ -467,7 +467,7 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8); + store(tmp_out + col, tmp8); } if (col < N) { auto tmp1 = at::vec::Vectorized::loadu(tmp_in + col, N - col); @@ -477,13 +477,13 @@ auto tmp6 = tmp5.round(); auto tmp7 = tmp6 + vec_beta2; auto tmp8 = at::vec::clamp(tmp7, vec_min_val, vec_max_val); - {{kernel_name}}_store(tmp_out + col, tmp8, N - col); + store(tmp_out + col, tmp8, N - col); } } } template -inline void {{kernel_name}}_int_sum_b_contiguous_kernel_helper( +inline void int_sum_b_contiguous_kernel_helper( const scalar_t* in, int32_t* out, const int& N, @@ -507,7 +507,7 @@ // reduce along dim b for shape [a, b], with sum shape [a] template -inline void {{kernel_name}}_int_sum_b_contiguous_kernel( +inline void int_sum_b_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -515,13 +515,13 @@ const int& ld, const int32_t& scale) { for (long r = 0; r < M; r += 1) { - {{kernel_name}}_int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); + int_sum_b_contiguous_kernel_helper(in + r * ld, out + r, N, scale); } } // reduce along dim a for shape [a, b], with sum shape [b] template -inline void {{kernel_name}}_int_sum_a_contiguous_kernel( +inline void int_sum_a_contiguous_kernel( const scalar_t* in, int32_t* out, const int& M, @@ -535,10 +535,10 @@ auto vec_zero = at::vec::Vectorized(zero); long i = 0; for (; i < vec_size * (M / vec_size); i += vec_size) { - {{kernel_name}}_store(out + i, vec_zero); + store(out + i, vec_zero); } if (i < M) { - {{kernel_name}}_store(out + i, vec_zero, M - i); + store(out + i, vec_zero, M - i); } // sum for (long j = 0; j < N; j++) { @@ -549,14 +549,14 @@ auto tmp1 = at::vec::Vectorized::loadu(out + k); auto tmp2 = at::vec::convert(tmp0); auto tmp3 = tmp1 + tmp2; - {{kernel_name}}_store(out + k, tmp3); + store(out + k, tmp3); } if (k < M) { auto tmp0 = at::vec::Vectorized::loadu(tmp_in + k, M - k); auto tmp1 = at::vec::Vectorized::loadu(out + k, M - k); auto tmp2 = at::vec::convert(tmp0); auto tmp3 = tmp1 + tmp2; - {{kernel_name}}_store(out + k, tmp3, M - k); + store(out + k, tmp3, M - k); } } // scale @@ -564,18 +564,18 @@ for (; i < vec_size * (M / vec_size); i += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(out + i); auto tmp1 = tmp0 * vec_scale; - {{kernel_name}}_store(out + i, tmp1); + store(out + i, tmp1); } if (i < M) { auto tmp0 = at::vec::Vectorized::loadu(out + i, M - i); auto tmp1 = tmp0 * vec_scale; - {{kernel_name}}_store(out + i, tmp1, M - i); + store(out + i, tmp1, M - i); } } // do the transpose: [in_rows, in_cols] -> [in_cols, in_rows] template -inline void {{kernel_name}}_do_transpose( +inline void do_transpose( const scalar_t* src, scalar_t* dst, int64_t in_rows, @@ -591,7 +591,7 @@ // padding with pad_val: [rows, cols] -> [prows, pcols] template -inline void {{kernel_name}}_pad_remain_row_col( +inline void pad_remain_row_col( scalar_t* value_ptr, int rows, int cols, @@ -630,7 +630,7 @@ // copy value_ptr to dst_ptr with padding: [rows, cols] -> [prows, pcols] template -inline void {{kernel_name}}_copy_value_with_pad( +inline void copy_value_with_pad( const scalar_t* value_ptr, scalar_t* dst_ptr, int rows, @@ -694,6 +694,9 @@ INT8_SDPA_ONE_LOOP_TEMPLATE = r""" +#ifndef HEADER_DEFINED +#define HEADER_DEFINED + {{template.header().getvalue()}} #include #include @@ -721,6 +724,8 @@ {{template.codegen_useful_function(kernel.kernel_name)}} +#endif + {%- if has_attention_mask %} {%- set kernel_args = {"query": query, "key": key, "value": value, "attention_mask": attention_mask} %} @@ -746,7 +751,7 @@ int64_t num_head = {{kernel.size(query, 2)}}; int64_t headSize = {{kernel.size(query, 3)}}; float scaling_factor = - {{kernel.kernel_name}}_calculate_scale(headSize, {{scale}}); + calculate_scale(headSize, {{scale}}); // Strides int64_t qStrideB = {{kernel.stride(query, 0)}}; @@ -873,16 +878,16 @@ // sum k and v {%- if q_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(k_sum_ptr, static_cast(0), kvSize); + fill_stub(k_sum_ptr, static_cast(0), kvSize); {%- else %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, k_sum_ptr, kvSize, headSize, kStrideN, {{q_zp}}); {%- endif %} {%- if a_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(v_sum_ptr, static_cast(0), headSize); + fill_stub(v_sum_ptr, static_cast(0), headSize); {%- else %} - {{kernel.kernel_name}}_int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, v_sum_ptr, headSize, kvSize, vStrideN, {{a_zp}}); {%- endif %} @@ -893,7 +898,7 @@ for (int64_t b = 0; b < kvBlockSize; b += block_64) { bool istail = kvBlockSize - b < block_64; int64_t trans_rows = istail ? kvBlockSize - b : block_64; - {{kernel.kernel_name}}_do_transpose( + do_transpose( k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, B_blocked_xform_u8, trans_rows, @@ -901,7 +906,7 @@ kStrideN, block_64); if (!headSize_mul64 || istail) { - {{kernel.kernel_name}}_pad_remain_row_col( + pad_remain_row_col( B_blocked_xform_u8, headSize, trans_rows, @@ -942,14 +947,14 @@ int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); // Initialize sum and max - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( a_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); int64_t num_keys = kvSize; - {{kernel.kernel_name}}_copy_value_with_pad( + copy_value_with_pad( q_data + i * qStrideB + j * qStrideH + m * qStrideM, query_t_padding_ptr, qBlockSize, @@ -959,10 +964,10 @@ qStrideM); // sum q {%- if k_zp != 0 %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, q_sum_ptr, qBlockSize, headSize, qStrideM, {{k_zp}}); {%- else %} - {{kernel.kernel_name}}_fill_stub( + fill_stub( q_sum_ptr, static_cast(0), qSplitSize); {%- endif %} const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; @@ -986,7 +991,7 @@ accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; {%- if has_attention_mask %} const mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); - {{kernel.kernel_name}}_dequant_mask_max_fusion_kernel( + dequant_mask_max_fusion_kernel( qk_s32_data, //in mask_data_offset, //mask_ptr q_sum_ptr, //sum_a_ptr @@ -1002,7 +1007,7 @@ sfm_max_ptr //sfm_max_ptr ); {%- else %} - {{kernel.kernel_name}}_dequant_max_fusion_kernel( + dequant_max_fusion_kernel( qk_s32_data, //in q_sum_ptr, //sum_a_ptr k_sum_ptr + n, //sum_b_ptr @@ -1021,7 +1026,7 @@ // and quant // and sum for attention {%- if v_zp == 0 %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_fusion_kernel( + sub_exp_sum_div_quant_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1039,7 +1044,7 @@ sfm_sum_ptr //sfm_sum_ptr ); {%- else %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_sum_fusion_kernel( + sub_exp_sum_div_quant_sum_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1079,7 +1084,7 @@ // After the last gemm, // do dequant compensation, quant and convert from s32 to int8 {%- if a_zp == 0 %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr qBlockSize, //M @@ -1091,7 +1096,7 @@ out_data + i * oStrideB + j * oStrideH + m * oStrideM //out ); {%- else %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr v_sum_ptr, //sum_b_ptr @@ -1118,6 +1123,9 @@ INT8_SDPA_SEVERAL_LOOPS_TEMPLATE = r""" +#ifndef HEADER_DEFINED +#define HEADER_DEFINED + {{template.header().getvalue()}} #include #include @@ -1125,6 +1133,7 @@ #include #include #include +#include #include #include #include @@ -1145,6 +1154,8 @@ {{template.codegen_useful_function(kernel.kernel_name)}} +#endif + {%- if has_attention_mask %} {%- set kernel_args = {"query": query, "key": key, "value": value, "attention_mask": attention_mask} %} @@ -1160,8 +1171,6 @@ int64_t num_thread = {{num_thread}}; using accum_t = float; using scalar_t = {{kernel.dtype(query)}}; - int block_64 = 64; - auto u8_dt = at::ScalarType::Byte; // Sizes int64_t batchSize = {{kernel.size(query, 0)}}; @@ -1170,7 +1179,7 @@ int64_t num_head = {{kernel.size(query, 2)}}; int64_t headSize = {{kernel.size(query, 3)}}; float scaling_factor = - {{kernel.kernel_name}}_calculate_scale(headSize, {{scale}}); + calculate_scale(headSize, {{scale}}); // Strides int64_t qStrideB = {{kernel.stride(query, 0)}}; @@ -1192,15 +1201,11 @@ int64_t kvSlice = (kvSize - 1) / kvSplitSize + 1; int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; - int64_t rndHeadSize = (headSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvSplitSize = (kvSplitSize + block_64 - 1L) / block_64 * block_64; - int64_t rndkvTail = (kvTail + block_64 - 1L) / block_64 * block_64; + int64_t rndHeadSize = headSize % 4 == 0 ? headSize : headSize + 4 - headSize % 4; + int64_t rndkvSplitSize = kvSplitSize % 4 == 0 ? kvSplitSize : kvSplitSize + 4 - kvSplitSize % 4; + int64_t rndkvTail = kvTail % 4 == 0 ? kvTail : kvTail + 4 - kvTail % 4; int64_t rndkvSize = {{kv_split_size}} > kvSize ? rndkvTail : rndkvSplitSize * kvSlice + rndkvTail; - bool av_gemm_K_mul4 = kvSplitSize % 4 == 0; - int av_gemm_K_padding = av_gemm_K_mul4 ? 0 : 4 - kvSplitSize % 4; - int av_gemm_K = kvSplitSize + av_gemm_K_padding; - {%- if has_attention_mask %} // attention mask using mask_t = {{kernel.dtype(attention_mask)}}; @@ -1229,16 +1234,12 @@ const scalar_t* v_data = value; scalar_t* out_data = output; - bool headSize_mul64 = headSize % 64 == 0; - int qk_gemm_K_padding = headSize_mul64 ? 0 : 64 - headSize % 64; - int qk_gemm_K = headSize + qk_gemm_K_padding; - - int64_t qk_reduce_strideL = qSplitSize * av_gemm_K; - int64_t v_reorder_strideL = av_gemm_K * rndHeadSize; + int64_t qk_reduce_strideL = qSplitSize * rndkvSplitSize; + int64_t v_reorder_strideL = rndkvSplitSize * rndHeadSize; int64_t total_size_uint8_per_thread = /* qk */ kvSlice * qSplitSize * rndkvSplitSize * 4 + - /* qk_local */ kvSlice * av_gemm_K * 4 + + /* qk_local */ kvSlice * rndkvSplitSize * 4 + /* qk_reduce */ kvSlice * qk_reduce_strideL + /* qk_s32 */ qSplitSize * rndkvSplitSize * 4 + /* dst_s32 */ qSplitSize * rndHeadSize * 4 + @@ -1246,7 +1247,7 @@ /* query_sum */ qSplitSize * 4 + /* attention_sum */ qSplitSize * 4 + /* softmax max */ qSplitSize * 4 + - /* query_padding_data */ qSplitSize * qk_gemm_K; + /* query_padding_data */ qSplitSize * rndHeadSize; {{template.codegen_allocate_buffer("total_buf_data", "scalar_t", "num_thread * total_size_uint8_per_thread")}} int64_t kv_sum_size_per_BH = @@ -1255,11 +1256,11 @@ {{template.codegen_allocate_buffer("kv_sum_buf_data", "int32_t", "batchSize * num_head * kv_sum_size_per_BH")}} int64_t kv_reorder_size_per_BH = - /* key_t_reorder */ qk_gemm_K * rndkvSize + + /* key_t_reorder */ rndHeadSize * rndkvSize + /* value_t_reorder */ kvSlice * v_reorder_strideL; {{template.codegen_allocate_buffer("kv_reorder_buf_data", "scalar_t", "batchSize * num_head * kv_reorder_size_per_BH")}} scalar_t* key_reorder_ptr = kv_reorder_buf_data; - scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * qk_gemm_K * rndkvSize; + scalar_t* value_reorder_ptr = kv_reorder_buf_data + batchSize * num_head * rndHeadSize * rndkvSize; // sum k and v at::parallel_for( @@ -1275,16 +1276,16 @@ int32_t* k_sum_ptr = kv_sum_ptr; int32_t* v_sum_ptr = kv_sum_ptr + kvSize; {%- if q_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(k_sum_ptr, static_cast(0), kvSize); + fill_stub(k_sum_ptr, static_cast(0), kvSize); {%- else %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, + int_sum_b_contiguous_kernel(k_data + i * kStrideB + j * kStrideH, k_sum_ptr, kvSize, headSize, kStrideN, {{q_zp}}); {%- endif %} {%- if a_zp == 0 %} - {{kernel.kernel_name}}_fill_stub(v_sum_ptr, static_cast(0), headSize); + fill_stub(v_sum_ptr, static_cast(0), headSize); {%- else %} - {{kernel.kernel_name}}_int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, + int_sum_a_contiguous_kernel(v_data + i * vStrideB + j * vStrideH, v_sum_ptr, headSize, kvSize, vStrideN, {{a_zp}}); {%- endif %} @@ -1299,59 +1300,35 @@ int64_t i = 0, j = 0, l = 0, n = 0; at::native::data_index_init( begin, i, batchSize, j, num_head, l, kvSlice); - uint8_t* B_blocked_xform_u8 = new uint8_t[qk_gemm_K * block_64]; + uint8_t* B_blocked_xform_u8 = new uint8_t[rndHeadSize * kvSplitSize]; for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable n = l * kvSplitSize; - auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K; + auto k_reorder = key_reorder_ptr + i * num_head * rndHeadSize * rndkvSize + + j * rndHeadSize * rndkvSize + n * rndHeadSize; auto v_reorder = value_reorder_ptr + i * num_head * kvSlice * v_reorder_strideL + j * kvSlice * v_reorder_strideL + n * rndHeadSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - for (int64_t b = 0; b < kvBlockSize; b += block_64) { - bool istail = kvBlockSize - b < block_64; - int64_t trans_rows = istail ? kvBlockSize - b : block_64; - {{kernel.kernel_name}}_do_transpose( - k_data + i * kStrideB + j * kStrideH + n * kStrideN + b * kStrideN, - B_blocked_xform_u8, - trans_rows, + at::native::utils::transpose( + kvBlockSize, headSize, + k_data + i * kStrideB + j * kStrideH + n * kStrideN, kStrideN, - block_64); - if (!headSize_mul64 || istail) { - {{kernel.kernel_name}}_pad_remain_row_col( - B_blocked_xform_u8, - headSize, - trans_rows, - qk_gemm_K, - block_64, - block_64 - ); - } - at::native::cpublas::pack( - qk_gemm_K, // K - block_64, // N - block_64, // ld_in - block_64, // ld_out - u8_dt, // dt_in - u8_dt, // dt_out - B_blocked_xform_u8, - k_reorder + b * qk_gemm_K); - } - // split headSize to block_64, block_64, block_64 ... - // [av_gemm_K, headSize] -> [av_gemm_K, block_64 ...] - for (int64_t b = 0; b < rndHeadSize; b += block_64) { - at::native::cpublas::pack( - av_gemm_K, - block_64, - vStrideN, // block_64, - block_64, - u8_dt, - u8_dt, - v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, - v_reorder + av_gemm_K * b); - } + B_blocked_xform_u8, + kvBlockSize); + at::vec::pack_vnni4( + /* src */ B_blocked_xform_u8, + /* dst */ k_reorder, + /* ld_src */ kvBlockSize, + /* K */ rndHeadSize, + /* N */ kvBlockSize); + at::vec::pack_vnni4( + /* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN, + /* dst */ v_reorder, + /* ld_src */ vStrideN, + /* K */ rndkvSplitSize, + /* N */ rndHeadSize); // Move to the next query at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); } @@ -1368,7 +1345,7 @@ accum_t* qk_data = reinterpret_cast(total_buf_ptr); offset += kvSlice * qSplitSize * rndkvSplitSize * 4; accum_t* qk_local_data = reinterpret_cast(total_buf_ptr + offset); - offset += kvSlice * av_gemm_K * 4; + offset += kvSlice * rndkvSplitSize * 4; scalar_t* qk_reduced_data = reinterpret_cast(total_buf_ptr + offset); offset += kvSlice * qk_reduce_strideL; int32_t* qk_s32_data = reinterpret_cast(total_buf_ptr + offset); @@ -1382,8 +1359,8 @@ int32_t* a_sum_ptr = reinterpret_cast(total_buf_ptr + offset); offset += qSplitSize * 4; accum_t* sfm_max_ptr = reinterpret_cast(total_buf_ptr + offset); - offset += qSplitSize * 4; - scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); + //offset += qSplitSize * 4; + //scalar_t* query_t_padding_ptr = reinterpret_cast(total_buf_ptr + offset); for (const auto z : c10::irange(begin, end)) { (void)z; // Suppress unused variable @@ -1398,53 +1375,45 @@ int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); // Initialize sum and max - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( a_sum_ptr, static_cast(0), qSplitSize); - {{kernel.kernel_name}}_fill_stub( + fill_stub( sfm_max_ptr, static_cast(-std::numeric_limits::infinity()), qSplitSize); int64_t num_keys = kvSize; - {{kernel.kernel_name}}_copy_value_with_pad( - q_data + i * qStrideB + j * qStrideH + m * qStrideM, - query_t_padding_ptr, - qBlockSize, - headSize, - qBlockSize, - qk_gemm_K, - qStrideM); // sum q + const scalar_t* q_tmp = q_data + i * qStrideB + j * qStrideH + m * qStrideM; {%- if k_zp != 0 %} - {{kernel.kernel_name}}_int_sum_b_contiguous_kernel(q_data + i * qStrideB + j * qStrideH + m * qStrideM, + int_sum_b_contiguous_kernel(q_tmp, q_sum_ptr, qBlockSize, headSize, qStrideM, {{k_zp}}); {%- else %} - {{kernel.kernel_name}}_fill_stub( + fill_stub( q_sum_ptr, static_cast(0), qSplitSize); {%- endif %} const int64_t rkvSlice = (num_keys - 1) / kvSplitSize + 1; + for (int64_t l = 0; l < rkvSlice; l++) { int64_t n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); - auto k_reorder = key_reorder_ptr + i * num_head * qk_gemm_K * rndkvSize + - j * qk_gemm_K * rndkvSize + n * qk_gemm_K; + auto k_reorder = key_reorder_ptr + i * num_head * rndHeadSize * rndkvSize + + j * rndHeadSize * rndkvSize + n * rndHeadSize; // Calculate q @ k.T - for (int64_t b = 0; b < kvBlockSize; b += block_64) { - at::native::cpublas::brgemm( - qSplitSize, block_64, qk_gemm_K, - qk_gemm_K, // lda - block_64, //ldb + at::native::cpublas::brgemm( + qSplitSize, kvBlockSize, headSize, + qStrideM, // lda + kvBlockSize, //ldb rndkvSplitSize, //ldc, false, - query_t_padding_ptr, - k_reorder + b * qk_gemm_K, - qk_s32_data + b); - } + q_tmp, + k_reorder, + qk_s32_data); // do dequant compensation, add mask, max reduce for softmax, and convert qk from s32 to fp32 accum_t* qk_block_data = qk_data + l * qSplitSize * rndkvSplitSize; {%- if has_attention_mask %} const mask_t* mask_data_offset = mask_data + i * mStrideB + j * mStrideH + m * mStrideM + (mStrideN == 0 ? 0 : n); - {{kernel.kernel_name}}_dequant_mask_max_fusion_kernel( + dequant_mask_max_fusion_kernel( qk_s32_data, //in mask_data_offset, //mask_ptr q_sum_ptr, //sum_a_ptr @@ -1460,7 +1429,7 @@ sfm_max_ptr //sfm_max_ptr ); {%- else %} - {{kernel.kernel_name}}_dequant_max_fusion_kernel( + dequant_max_fusion_kernel( qk_s32_data, //in q_sum_ptr, //sum_a_ptr k_sum_ptr + n, //sum_b_ptr @@ -1479,7 +1448,7 @@ // and quant // and sum for attention {%- if v_zp == 0 %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_fusion_kernel( + sub_exp_sum_div_quant_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1488,7 +1457,7 @@ qk_reduce_strideL, //ldo kvSize, //kvSize rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K + rndkvSplitSize, //av_gemm_K {{a_zp}}, // zp_a=beta1 {{a_scale}}, // scale_a=alpha qk_local_data, //local @@ -1497,7 +1466,7 @@ sfm_sum_ptr //sfm_sum_ptr ); {%- else %} - {{kernel.kernel_name}}_sub_exp_sum_div_quant_sum_fusion_kernel( + sub_exp_sum_div_quant_sum_fusion_kernel( qk_data, //in qBlockSize, //M kvSplitSize, //N_step @@ -1506,7 +1475,7 @@ qk_reduce_strideL, //ldo kvSize, //kvSize rndkvSplitSize, //rndkvSplitSize - av_gemm_K, //av_gemm_K + rndkvSplitSize, //av_gemm_K {{a_zp}}, // zp_a=beta1 {{v_zp}}, // zp_b=beta2 {{a_scale}}, // scale_a=alpha @@ -1521,26 +1490,22 @@ auto v_reorder = value_reorder_ptr + i * num_head * kvSlice * v_reorder_strideL + j * kvSlice * v_reorder_strideL; - for (int64_t b = 0; b < headSize; b += block_64) { - auto value_reorder_b = v_reorder + b * av_gemm_K; - auto dst_s32_b = dst_s32_data + b; - for (int64_t s = 0; s < kvSlice; s++) { - at::native::cpublas::brgemm( - qSplitSize, block_64, av_gemm_K, - av_gemm_K, // lda - rndHeadSize, //ldb - rndHeadSize, //ldc - s != 0, - qk_reduced_data + s * qk_reduce_strideL, - value_reorder_b + s * v_reorder_strideL, - dst_s32_b); - } + for (int64_t s = 0; s < kvSlice; s++) { + at::native::cpublas::brgemm( + qSplitSize, headSize, rndkvSplitSize, + rndkvSplitSize, // lda + rndHeadSize, //ldb + rndHeadSize, //ldc + s != 0, + qk_reduced_data + s * qk_reduce_strideL, + v_reorder + s * v_reorder_strideL, + dst_s32_data); } // After the last gemm, // do dequant compensation, quant and convert from s32 to int8 {%- if a_zp == 0 %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr qBlockSize, //M @@ -1552,7 +1517,7 @@ out_data + i * oStrideB + j * oStrideH + m * oStrideM //out ); {%- else %} - {{kernel.kernel_name}}_dequant_quant_fusion_kernel( + dequant_quant_fusion_kernel( dst_s32_data, //in a_sum_ptr, //sum_a_ptr v_sum_ptr, //sum_b_ptr @@ -1704,8 +1669,7 @@ def get_options( if qSize >= 768: q_split_size = 256 elif qSize >= 192: - q_split_size = 64 - kv_split_size = 64 + q_split_size = 128 qSplitSize = min(qSize, q_split_size) l2_cache_size = torch._C._cpu._L2_cache_size() @@ -1717,8 +1681,9 @@ def get_options( ): # if not symbolic shape use_one_parallel_loop = (batchSize * num_head > num_threads) and ( - attn_size > 1.5 * l2_cache_size + attn_size > 3 * l2_cache_size ) + kv_split_size = 64 if use_one_parallel_loop else 512 options = dict( q_split_size=q_split_size, diff --git a/torchao/prototype/inductor/fx_passes/README.md b/torchao/prototype/inductor/fx_passes/README.md index 7007aba993..fe4939a314 100644 --- a/torchao/prototype/inductor/fx_passes/README.md +++ b/torchao/prototype/inductor/fx_passes/README.md @@ -11,7 +11,7 @@ In TorchAO, you can replace the following customized graph passes of Inductor: ## Directory Structure -- `int8_sdpa_fusion`: Pattern match for int8 sdpa fusion. +- `qsdpa_fusion`: Pattern match for qsdpa fusion. ## Getting Started diff --git a/torchao/prototype/inductor/fx_passes/__init__.py b/torchao/prototype/inductor/fx_passes/__init__.py index aae6d5348a..eff7ff1dc2 100644 --- a/torchao/prototype/inductor/fx_passes/__init__.py +++ b/torchao/prototype/inductor/fx_passes/__init__.py @@ -1,5 +1,7 @@ -from .int8_sdpa_fusion import _int8_sdpa_init +from .da8w4_concat_linear_fusion_cpu import register_da8w4_concat_linear_cpu_pass +from .qsdpa_fusion import _qsdpa_init __all__ = [ - "_int8_sdpa_init", + "_qsdpa_init", + "register_da8w4_concat_linear_cpu_pass", ] diff --git a/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py b/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py new file mode 100644 index 0000000000..8e39826f4c --- /dev/null +++ b/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import operator + +import torch +from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files + + +class DA8W4ConcatLinearCPUPass(CustomGraphPass): + def __call__(self, graph: torch.fx.Graph): + _concat_linear_dq8w4_cpu(graph) + + def uuid(self): + return get_hash_for_files((__file__,)) + + +# Inductor FX passes for concat linear for DA8W4 +def _is_valid_concat_linear_da8w4_fusion(computation_nodes): + if "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"): + # cpp kernels not built + return False + # OP schema: + # da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor + computation_op = torch.ops.torchao.da8w4_linear_cpu.default + act = computation_nodes[0].args[0] + act_scales = computation_nodes[0].args[1] + act_zp = computation_nodes[0].args[2] + wgt = computation_nodes[0].args[3] + in_feature_size = act.meta.get("val").size(1) # type: ignore[union-attr] + if len(wgt.meta.get("val").shape) != 4: + return False + block_k = wgt.meta.get("val").size(2) # type: ignore[union-attr] + with_bias = computation_nodes[0].args[7] is not None + output_dtype = computation_nodes[0].args[-1] + + def check_in_feature_of_wgt(wgt): + return ( + wgt.meta.get("val").size(1) * wgt.meta.get("val").size(2) == in_feature_size + ) # type: ignore[union-attr] + + def check_block_k_of_wgt(wgt): + return wgt.meta.get("val").size(2) == block_k + + def check_bias(b): + return (b is not None) if with_bias else (b is None) + + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act # share same activation + and node.args[1] == act_scales # same act scale + and node.args[2] == act_zp # same act zero point + and check_in_feature_of_wgt(node.args[3]) # same in-feature size + and (node.args[3] != wgt or gemm_idx == 0) + and node.args[3].op == "get_attr" # wgt are all constants + and check_block_k_of_wgt(node.args[3]) # same block_k + and check_bias(node.args[7]) # bias is either all None or all not None + and node.args[-1] == output_dtype # same output dtype + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + +def _concat_linear_dq8w4_cpu(graph: torch.fx.Graph): + """ + Concat Linear optimization pass for DA8W4 on CPU + This pass fuses the original pattern: + def ... + return (da8w4_linear_cpu(x, ..., w1, ...), da8w4_linear_cpu(x, ..., w2, ...), ...) + into a single operation: + def ... + concat_res = da8w4_linear_cpu(x, ..., concat_w, ...) + return split(concat_res, split_size_list) + """ + if "CPU" not in torch._C._dispatch_dump("torchao::da8w4_linear_cpu"): + # cpp kernels not built + return + from torch._inductor import config as inductor_config + + if not inductor_config.cpp.enable_concat_linear: + # only concat linear if the flag is set + return + gm = graph.owning_module + computation_op = torch.ops.torchao.da8w4_linear_cpu.default + # OP schema: + # da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + act_scales = node.args[1] + act_qzeros = node.args[2] + users = list(act.users) + if _is_valid_concat_linear_da8w4_fusion(users): + with graph.inserting_before(node): + computation_node_0 = users[0] + packed_wgts = [getattr(gm, user.args[3].target) for user in users] + out_feature_size_list = [ + (w.size(0) * w.size(-1) * 2) for w in packed_wgts + ] + wgt_scales = [getattr(gm, user.args[4].target) for user in users] + wgt_qzeros = [getattr(gm, user.args[5].target) for user in users] + compensations = [getattr(gm, user.args[6].target) for user in users] + bias = [] + with_bias = users[0].args[7] is not None + if with_bias: + bias = [getattr(gm, user.args[7].target) for user in users] + output_dtype = node.args[-1] + # Shape of packed weight: [N/block_n, K/block_k, block_k, block_n/2] + # Shape of weight scales/qzeros: [N/block_n, G, block_n] + # Shape of compensation: [N/block_n, K/block_k, block_n] + # Concat them along N/block_n + concat_wgt = torch.cat(packed_wgts, dim=0) + concat_w_node_name = computation_node_0.args[3].target + "_concat" + concat_wgt_scales = torch.cat(wgt_scales, dim=0) + concat_ws_node_name = computation_node_0.args[4].target + "_concat" + concat_wgt_qzeros = torch.cat(wgt_qzeros, dim=0) + concat_wz_node_name = computation_node_0.args[5].target + "_concat" + concat_compensation = torch.cat(compensations, dim=0) + concat_comp_node_name = ( + computation_node_0.args[6].target + "_concat" + ) + concat_bias = torch.cat(bias, dim=0) if with_bias else None + concat_bias_node_name = ( + computation_node_0.args[7].target + "_concat" + if with_bias + else None + ) + gm.register_buffer(concat_w_node_name, concat_wgt) + setattr(gm, concat_w_node_name, concat_wgt) + gm.register_buffer(concat_ws_node_name, concat_wgt_scales) + setattr(gm, concat_ws_node_name, concat_wgt_scales) + gm.register_buffer(concat_wz_node_name, concat_wgt_qzeros) + setattr(gm, concat_wz_node_name, concat_wgt_qzeros) + gm.register_buffer(concat_comp_node_name, concat_compensation) + setattr(gm, concat_comp_node_name, concat_compensation) + if with_bias: + gm.register_buffer(concat_bias_node_name, concat_bias) + setattr(gm, concat_bias_node_name, concat_bias) + + concat_w_node = graph.create_node( + "get_attr", concat_w_node_name, (), {} + ) + with graph.inserting_after(concat_w_node): + concat_wgt_scales_node = graph.create_node( + "get_attr", concat_ws_node_name, (), {} + ) + with graph.inserting_after(concat_wgt_scales_node): + concat_wgt_qzeros_node = graph.create_node( + "get_attr", concat_wz_node_name, (), {} + ) + with graph.inserting_after(concat_wgt_qzeros_node): + concat_compensation_node = graph.create_node( + "get_attr", concat_comp_node_name, (), {} + ) + node_before_linear = concat_compensation_node + if with_bias: + with graph.inserting_after(concat_compensation_node): + concat_bias_node = graph.create_node( + "get_attr", concat_bias_node_name, (), {} + ) + node_before_linear = concat_bias_node + else: + concat_bias_node = None + with graph.inserting_after(node_before_linear): + new_linear_node = graph.create_node( + "call_function", + computation_op, + ( + act, + act_scales, + act_qzeros, + concat_w_node, + concat_wgt_scales_node, + concat_wgt_qzeros_node, + concat_compensation_node, + concat_bias_node, + output_dtype, + ), + ) + with graph.inserting_after(new_linear_node): + split_node = graph.create_node( + "call_function", + torch.ops.aten.split_with_sizes.default, + ( + new_linear_node, + out_feature_size_list, + -1, # split along the out feature dimension + ), + ) + with graph.inserting_after(split_node): + for gemm_idx, user in enumerate(users): + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + split_node, + gemm_idx, + ), + ) + with graph.inserting_after(get_item): + clone_node = graph.create_node( + "call_function", + torch.ops.aten.clone.default, + (get_item,), + {"memory_format": torch.contiguous_format}, + ) + user.replace_all_uses_with(clone_node) + graph.erase_node(user) + + +# Define and register a custom pass for concat linear +# We always register the pass when calling this function +# but it only takes effect when config.cpp.enable_concat_linear is set to True +def register_da8w4_concat_linear_cpu_pass(): + from torch._inductor import config as inductor_config + + da8w4_concat_linear_cpu_pass = DA8W4ConcatLinearCPUPass() + inductor_config.post_grad_custom_post_pass = da8w4_concat_linear_cpu_pass diff --git a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py b/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py deleted file mode 100644 index 5e032f01c2..0000000000 --- a/torchao/prototype/inductor/fx_passes/int8_sdpa_fusion.py +++ /dev/null @@ -1,396 +0,0 @@ -import functools -import itertools - -import torch -from torch._dynamo.utils import counters -from torch._inductor import config -from torch._inductor.lowering import lowerings as L -from torch._inductor.lowering import make_fallback -from torch._inductor.pattern_matcher import ( - Arg, - CallFunction, - KeywordArg, - Match, - PatternMatcherPass, - register_lowering_pattern, -) - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 - -if TORCH_VERSION_AT_LEAST_2_7: - # TORCH_VERSION_AT_LEAST_2_7 is needed for functions in int8 sdpa lowering - from ..int8_sdpa_lowering import register_int8_sdpa # noqa: F401 -else: - make_fallback(torch.ops.torchao.qscaled_dot_product.default) - -__all__ = [ - "_int8_sdpa_init", -] - -aten = torch.ops.aten - - -def _is_valid_int8_sdpa_pattern(): - def fn(match): - assert all(k in match.kwargs for k in ("query", "key", "value")) - query = match.kwargs["query"].meta["val"] - key = match.kwargs["key"].meta["val"] - value = match.kwargs["value"].meta["val"] - return ( - query.dtype == torch.uint8 - and key.dtype == torch.uint8 - and value.dtype == torch.uint8 - and query.device.type == "cpu" - and key.device == query.device - and value.device == query.device - ) - - return fn - - -def _register_int8_sdpa_pattern(pattern, custom_pass_dict): - @register_lowering_pattern( - pattern, extra_check=_is_valid_int8_sdpa_pattern(), pass_dict=custom_pass_dict - ) - def int8_sdpa(match: Match, *args, **kwargs): - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - scale = 1.0 / kwargs["inv_scale"] if "inv_scale" in kwargs else None - attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None - q_scale = kwargs["q_scale"] - q_zp = kwargs["q_zp"] - k_scale = kwargs["k_scale"] - k_zp = kwargs["k_zp"] - v_scale = kwargs["v_scale"] - v_zp = kwargs["v_zp"] - a_scale = kwargs["a_scale"] - a_zp = kwargs["a_zp"] - o_scale = kwargs["o_scale"] - o_zp = kwargs["o_zp"] - counters["inductor"]["int8_fuse_attention"] += 1 - counters["inductor"]["int8_sdpa_nodes"] += len(match.nodes) - - trans_query = L[aten.permute.default](query, [0, 2, 1, 3]) - trans_key = L[aten.permute.default](key, [0, 2, 1, 3]) - trans_value = L[aten.permute.default](value, [0, 2, 1, 3]) - output = L[torch.ops.torchao.qscaled_dot_product.default]( - trans_query, - trans_key, - trans_value, - attn_mask, - 0.0, # dropout - False, # is_causal - scale, # scale - q_scale, - q_zp, - k_scale, - k_zp, - v_scale, - v_zp, - a_scale, - a_zp, - o_scale, - o_zp, - ) - trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) - return L[aten.clone.default]( - trans_output, memory_format=torch.contiguous_format - ) - - return int8_sdpa - - -def _get_int8_sdpa_qkv_pattern( - is_batch_size_1: bool, has_convert: bool, input_name: str -): - assert input_name in ["query", "key", "value"] - int8_sdpa_qkv_pattern_before_dequant = CallFunction( - aten.permute.default, - KeywordArg(input_name), - Arg(), - ) - if input_name == "key": - # do transpose - int8_sdpa_qkv_pattern_before_dequant = CallFunction( - aten.permute.default, - int8_sdpa_qkv_pattern_before_dequant, - Arg(), - ) - int8_sdpa_qkv_basic_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - int8_sdpa_qkv_pattern_before_dequant, - KeywordArg(input_name[0] + "_scale"), - KeywordArg(input_name[0] + "_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_qkv_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - int8_sdpa_qkv_basic_pattern = CallFunction( - aten.expand.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - if is_batch_size_1: - # pattern is different for bs=1 - return CallFunction( - aten.reshape.default, - int8_sdpa_qkv_basic_pattern, - Arg(), - ) - else: - return CallFunction( - aten.reshape.default, - CallFunction( - aten.clone.default, - int8_sdpa_qkv_basic_pattern, - memory_format=Arg(), - ), - Arg(), - ) - - -def _get_int8_sdpa_score_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_q_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "query" - ) - int8_sdpa_k_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "key" - ) - int8_sdpa_score_basic_pattern = CallFunction( - aten.reshape.default, - CallFunction( - aten.bmm.default, - int8_sdpa_q_pattern, - int8_sdpa_k_pattern, - ), - Arg(), - ) - if is_reduced_type and not has_mask: - int8_sdpa_score_basic_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_score_basic_pattern, - Arg(), - ) - if has_mask: - return CallFunction( - aten.add.Tensor, - CallFunction( - aten.div.Tensor, - int8_sdpa_score_basic_pattern, - KeywordArg("inv_scale"), - ), - KeywordArg("attn_mask"), - _users=2, - ) - else: - return CallFunction( - aten.mul.Tensor, - int8_sdpa_score_basic_pattern, - Arg(), - _users=2, - ) - - -def _get_int8_sdpa_exp_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_score_pattern = _get_int8_sdpa_score_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_exp_basic_pattern = CallFunction( - aten.sub.Tensor, - int8_sdpa_score_pattern, - CallFunction( - aten.amax.default, - int8_sdpa_score_pattern, - Arg(), - Arg(), - ), - ) - if has_mask: - return CallFunction( - aten.exp.default, - int8_sdpa_exp_basic_pattern, - _users=2, - ) - else: - return CallFunction( - aten.exp.default, - CallFunction( - aten.div.Tensor, - int8_sdpa_exp_basic_pattern, - KeywordArg("inv_scale"), - ), - _users=2, - ) - - -def _get_int8_sdpa_attn_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_exp_pattern = _get_int8_sdpa_exp_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - int8_sdpa_div_pattern = CallFunction( - aten.div.Tensor, - int8_sdpa_exp_pattern, - CallFunction( - aten.sum.dim_IntList, - int8_sdpa_exp_pattern, - Arg(), - Arg(), - ), - ) - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - int8_sdpa_div_pattern, - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if is_reduced_type: - if has_mask: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - else: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.quantized_decomposed.dequantize_per_tensor.default, - CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_div_pattern, - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ), - KeywordArg("a_scale"), - KeywordArg("a_zp"), - Arg(), - Arg(), - Arg(), - ) - if has_convert: - int8_sdpa_softmax_pattern = CallFunction( - torch.ops.prims.convert_element_type.default, - int8_sdpa_softmax_pattern, - Arg(), - ) - return CallFunction( - aten.reshape.default, - CallFunction( - aten.expand.default, - int8_sdpa_softmax_pattern, - Arg(), - ), - Arg(), - ) - - -# Parameters to generate various patterns: -# has_mask: if SDPA has attention mask -# is_batch_size_1: if the batch size is 1 -# is_reduced_type: if autocast is enabled -# has_convert: convert type if dequant out dtype is assigned -def _get_int8_sdpa_final_pattern( - has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool -): - int8_sdpa_v_pattern = _get_int8_sdpa_qkv_pattern( - is_batch_size_1, has_convert, "value" - ) - int8_sdpa_attn_pattern = _get_int8_sdpa_attn_pattern( - has_mask, is_batch_size_1, is_reduced_type, has_convert - ) - return CallFunction( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - CallFunction( - aten.clone.default, - CallFunction( - aten.permute.default, - CallFunction( - aten.reshape.default, - CallFunction( - aten.bmm.default, - int8_sdpa_attn_pattern, - int8_sdpa_v_pattern, - ), - Arg(), - ), - Arg(), - ), - memory_format=Arg(), - ), - KeywordArg("o_scale"), - KeywordArg("o_zp"), - Arg(), - Arg(), - Arg(), - ) - - -def _register_int8_sdpa_lowerings(custom_pass_dict): - for has_mask, is_batch_size_1, is_reduced_type, has_convert in itertools.product( - [True, False], [True, False], [True, False], [True, False] - ): - _register_int8_sdpa_pattern( - _get_int8_sdpa_final_pattern( - has_mask=has_mask, - is_batch_size_1=is_batch_size_1, - is_reduced_type=is_reduced_type, - has_convert=has_convert, - ), - custom_pass_dict, - ) - - -custom_pass = None -if TORCH_VERSION_AT_LEAST_2_7: - # TORCH_VERSION_AT_LEAST_2_7 is needed for custom graph pass - from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files - - # define the custom pass - class _CustomPass(PatternMatcherPass, CustomGraphPass): - def __init__(self) -> None: - super().__init__() - - def __call__(self, g: torch.fx.graph.Graph): - self.apply(g) - - def uuid(self) -> bytes: - return get_hash_for_files((__file__,)) - - custom_pass = _CustomPass() - - -@functools.lru_cache(None) -def _int8_sdpa_init(): - if TORCH_VERSION_AT_LEAST_2_7: - _register_int8_sdpa_lowerings(config.post_grad_custom_pre_pass) - else: - pass diff --git a/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py new file mode 100644 index 0000000000..5e495a0623 --- /dev/null +++ b/torchao/prototype/inductor/fx_passes/qsdpa_fusion.py @@ -0,0 +1,489 @@ +import functools +import itertools + +import torch +from torch._dynamo.utils import counters +from torch._inductor import config +from torch._inductor.lowering import lowerings as L +from torch._inductor.lowering import make_fallback +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + KeywordArg, + Match, + PatternMatcherPass, + register_lowering_pattern, +) + +from torchao.utils import torch_version_at_least + +if torch_version_at_least("2.7.0"): + # PyTorch 2.7+ is needed for functions in qsdpa lowering + from ..qsdpa_lowering import register_qsdpa # noqa: F401 +else: + make_fallback(torch.ops.torchao.qscaled_dot_product.default) + +__all__ = [ + "_qsdpa_init", +] + +aten = torch.ops.aten +quantize_dtypes = [torch.uint8] + + +def _is_valid_qsdpa_pattern(): + def fn(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + return ( + query.dtype in quantize_dtypes + and key.dtype in quantize_dtypes + and value.dtype in quantize_dtypes + and query.device.type == "cpu" + and key.device == query.device + and value.device == query.device + ) + + return fn + + +def _register_qsdpa_pattern(pattern, custom_pass_dict): + @register_lowering_pattern( + pattern, extra_check=_is_valid_qsdpa_pattern(), pass_dict=custom_pass_dict + ) + def qsdpa(match: Match, *args, **kwargs): + query = kwargs["query"] + key = kwargs["key"] + value = kwargs["value"] + scale = 1.0 / kwargs["inv_scale"] if "inv_scale" in kwargs else None + if scale is None: + scale = kwargs["scale"] if "scale" in kwargs else None + attn_mask = kwargs["attn_mask"] if "attn_mask" in kwargs else None + q_zp = 0 + k_zp = 0 + v_zp = 0 + a_zp = 0 + o_zp = 0 + if query.dtype == torch.uint8: + q_scale = kwargs["q_scale"] + q_zp = kwargs["q_zp"] + k_scale = kwargs["k_scale"] + k_zp = kwargs["k_zp"] + v_scale = kwargs["v_scale"] + v_zp = kwargs["v_zp"] + a_scale = kwargs["a_scale"] + a_zp = kwargs["a_zp"] + o_scale = kwargs["o_scale"] + o_zp = kwargs["o_zp"] + else: + assert match.kwargs["q_scale"].target == aten.full.default + q_scale = match.kwargs["q_scale"].args[1] + k_scale = match.kwargs["k_scale"].args[1] + v_scale = match.kwargs["v_scale"].args[1] + a_scale = match.kwargs["a_scale"].args[1] + o_scale = match.kwargs["o_scale"].args[1] + + counters["inductor"]["qsdpa_fuse_attention"] += 1 + counters["inductor"]["qsdpa_nodes"] += len(match.nodes) + + trans_query = L[aten.permute.default](query, [0, 2, 1, 3]) + trans_key = L[aten.permute.default](key, [0, 2, 1, 3]) + trans_value = L[aten.permute.default](value, [0, 2, 1, 3]) + output = L[torch.ops.torchao.qscaled_dot_product.default]( + trans_query, + trans_key, + trans_value, + attn_mask, + 0.0, # dropout + False, # is_causal + scale, + q_scale, + q_zp, + k_scale, + k_zp, + v_scale, + v_zp, + a_scale, + a_zp, + o_scale, + o_zp, + ) + trans_output = L[aten.permute.default](output, [0, 2, 1, 3]) + return L[aten.clone.default]( + trans_output, memory_format=torch.contiguous_format + ) + + return qsdpa + + +def _generate_dequant_pattern( + input_pattern, qtype, is_reduced_type, scale: str, zp: str = None +): + assert qtype is torch.uint8, "QSDPA expects type to be uint8" + assert zp is not None, "Zero point must be provided for uint8 dequantization" + return CallFunction( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + input_pattern, + KeywordArg(scale), + KeywordArg(zp), + Arg(), + Arg(), + Arg(), + ) + + +def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str = None): + assert qtype is torch.uint8, "QSDPA expects type to be uint8" + assert zp is not None, "Zero point must be provided for uint8 quantization" + return CallFunction( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + input_pattern, + KeywordArg(scale), + KeywordArg(zp), + Arg(), + Arg(), + Arg(), + ) + + +def _get_qsdpa_qkv_pattern( + qtype, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + input_name: str, +): + assert input_name in ["query", "key", "value"] + qsdpa_qkv_pattern_before_dequant = CallFunction( + aten.permute.default, + KeywordArg(input_name), + Arg(), + ) + if input_name == "key": + # do transpose + qsdpa_qkv_pattern_before_dequant = CallFunction( + aten.permute.default, + qsdpa_qkv_pattern_before_dequant, + Arg(), + ) + qsdpa_qkv_basic_pattern = _generate_dequant_pattern( + qsdpa_qkv_pattern_before_dequant, + qtype, + is_reduced_type, + input_name[0] + "_scale", + input_name[0] + "_zp" if qtype is torch.uint8 else None, + ) + if has_convert: + qsdpa_qkv_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_qkv_basic_pattern, + Arg(), + ) + qsdpa_qkv_basic_pattern = CallFunction( + aten.expand.default, + qsdpa_qkv_basic_pattern, + Arg(), + ) + if is_batch_size_1: + # pattern is different for bs=1 + return CallFunction( + aten.reshape.default, + qsdpa_qkv_basic_pattern, + Arg(), + ) + else: + return CallFunction( + aten.reshape.default, + CallFunction( + aten.clone.default, + qsdpa_qkv_basic_pattern, + memory_format=Arg(), + ), + Arg(), + ) + + +def _get_qsdpa_score_pattern( + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, +): + qsdpa_q_pattern = _get_qsdpa_qkv_pattern( + qtype, is_batch_size_1, is_reduced_type, has_convert, "query" + ) + qsdpa_k_pattern = _get_qsdpa_qkv_pattern( + qtype, is_batch_size_1, is_reduced_type, has_convert, "key" + ) + qsdpa_score_basic_pattern = CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + qsdpa_q_pattern, + qsdpa_k_pattern, + ), + Arg(), + ) + if is_reduced_type and not has_mask: + qsdpa_score_basic_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_score_basic_pattern, + Arg(), + ) + if not has_mask: + return CallFunction( + aten.mul.Tensor, + qsdpa_score_basic_pattern, + Arg(), + _users=2, + ) + elif is_inv_scale: + return CallFunction( + aten.add.Tensor, + CallFunction( + aten.div.Tensor, + qsdpa_score_basic_pattern, + KeywordArg("inv_scale"), + ), + KeywordArg("attn_mask"), + _users=2, + ) + else: + return CallFunction( + aten.add.Tensor, + CallFunction( + aten.mul.Tensor, + qsdpa_score_basic_pattern, + KeywordArg("scale"), + ), + KeywordArg("attn_mask"), + _users=2, + ) + + +def _get_qsdpa_exp_pattern( + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, +): + qsdpa_score_pattern = _get_qsdpa_score_pattern( + qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale + ) + qsdpa_exp_basic_pattern = CallFunction( + aten.sub.Tensor, + qsdpa_score_pattern, + CallFunction( + aten.amax.default, + qsdpa_score_pattern, + Arg(), + Arg(), + ), + ) + if has_mask: + return CallFunction( + aten.exp.default, + qsdpa_exp_basic_pattern, + _users=2, + ) + elif is_inv_scale: + return CallFunction( + aten.exp.default, + CallFunction( + aten.div.Tensor, + qsdpa_exp_basic_pattern, + KeywordArg("inv_scale"), + ), + _users=2, + ) + else: + return CallFunction( + aten.exp.default, + CallFunction( + aten.mul.Tensor, + qsdpa_exp_basic_pattern, + KeywordArg("scale"), + ), + _users=2, + ) + + +def _get_qsdpa_attn_pattern( + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, +): + qsdpa_exp_pattern = _get_qsdpa_exp_pattern( + qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale + ) + qsdpa_div_pattern = CallFunction( + aten.div.Tensor, + qsdpa_exp_pattern, + CallFunction( + aten.sum.dim_IntList, + qsdpa_exp_pattern, + Arg(), + Arg(), + ), + ) + qsdpa_softmax_pattern = _generate_dequant_pattern( + _generate_quant_pattern( + qsdpa_div_pattern, + qtype, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ), + qtype, + is_reduced_type, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ) + if is_reduced_type: + if has_mask: + qsdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_softmax_pattern, + Arg(), + ) + else: + qsdpa_softmax_pattern = _generate_dequant_pattern( + _generate_quant_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_div_pattern, + Arg(), + ), + qtype, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ), + qtype, + is_reduced_type, + "a_scale", + "a_zp" if qtype is torch.uint8 else None, + ) + if has_convert: + qsdpa_softmax_pattern = CallFunction( + torch.ops.prims.convert_element_type.default, + qsdpa_softmax_pattern, + Arg(), + ) + return CallFunction( + aten.reshape.default, + CallFunction( + aten.expand.default, + qsdpa_softmax_pattern, + Arg(), + ), + Arg(), + ) + + +# Parameters to generate various patterns: +# qdtype: quantized dtypes are uint8, float8_e4m3fn for now +# has_mask: if SDPA has attention mask +# is_batch_size_1: if the batch size is 1 +# is_reduced_type: if autocast is enabled +# has_convert: convert type if dequant out dtype is assigned +# is_inv_scale: if the scale in SDPA is inversed, in which case it is multiplied instead of divided +def _get_qsdpa_final_pattern( + qtype, + has_mask: bool, + is_batch_size_1: bool, + is_reduced_type: bool, + has_convert: bool, + is_inv_scale: bool, +): + qsdpa_v_pattern = _get_qsdpa_qkv_pattern( + qtype, is_batch_size_1, is_reduced_type, has_convert, "value" + ) + qsdpa_attn_pattern = _get_qsdpa_attn_pattern( + qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale + ) + return _generate_quant_pattern( + CallFunction( + aten.clone.default, + CallFunction( + aten.permute.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.bmm.default, + qsdpa_attn_pattern, + qsdpa_v_pattern, + ), + Arg(), + ), + Arg(), + ), + memory_format=Arg(), + ), + qtype, + "o_scale", + "o_zp" if qtype is torch.uint8 else None, + ) + + +def _register_qsdpa_lowerings(custom_pass_dict): + for ( + qtype, + has_mask, + is_batch_size_1, + is_reduced_type, + has_convert, + is_inv_scale, + ) in itertools.product( + quantize_dtypes, + [True, False], + [True, False], + [True, False], + [True, False], + [True, False], + ): + _register_qsdpa_pattern( + _get_qsdpa_final_pattern( + qtype=qtype, + has_mask=has_mask, + is_batch_size_1=is_batch_size_1, + is_reduced_type=is_reduced_type, + has_convert=has_convert, + is_inv_scale=is_inv_scale, + ), + custom_pass_dict, + ) + + +custom_pass = None +if torch_version_at_least("2.7.0"): + # PyTorch 2.7+ is needed for custom graph pass + from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files + + # define the custom pass + class _CustomPass(PatternMatcherPass, CustomGraphPass): + def __init__(self) -> None: + super().__init__() + + def __call__(self, g: torch.fx.graph.Graph): + self.apply(g) + + def uuid(self) -> bytes: + return get_hash_for_files((__file__,)) + + custom_pass = _CustomPass() + + +@functools.lru_cache(None) +def _qsdpa_init(): + if torch_version_at_least("2.7.0"): + _register_qsdpa_lowerings(config.post_grad_custom_pre_pass) + else: + pass diff --git a/torchao/prototype/inductor/int8_sdpa_lowering.py b/torchao/prototype/inductor/qsdpa_lowering.py similarity index 79% rename from torchao/prototype/inductor/int8_sdpa_lowering.py rename to torchao/prototype/inductor/qsdpa_lowering.py index 4fbff51c32..da6c1af0b4 100644 --- a/torchao/prototype/inductor/int8_sdpa_lowering.py +++ b/torchao/prototype/inductor/qsdpa_lowering.py @@ -3,7 +3,13 @@ import sympy import torch from torch._inductor.ir import ChoiceCaller, FixedLayout, TensorBox, get_fill_order -from torch._inductor.kernel.flex_attention import construct_strides, maybe_realize + +try: + # use the directory after refactor + from torch._inductor.kernel.flex.common import construct_strides, maybe_realize +except ImportError: + # use the old path for compatibility + from torch._inductor.kernel.flex_attention import construct_strides, maybe_realize from torch._inductor.lowering import register_lowering from torch._inductor.select_algorithm import ( ExternKernelChoice, @@ -12,20 +18,21 @@ from .codegen.cpp_int8_sdpa_template import CppInt8SdpaTemplate -op_int8_sdpa = ExternKernelChoice( +op_qsdpa = ExternKernelChoice( torch.ops.torchao.qscaled_dot_product.default, "torchao::qscaled_dot_product", has_out_variant=False, use_fallback_kernel=True, op_overload=torch.ops.torchao.qscaled_dot_product.default, ) +quantize_dtypes = [torch.uint8, torch.float8_e4m3fn] -def register_int8_sdpa(): +def register_qsdpa(): @register_lowering( torch.ops.torchao.qscaled_dot_product.default, type_promotion_kind=None ) - def int8_sdpa( + def qsdpa( query: TensorBox, key: TensorBox, value: TensorBox, @@ -61,12 +68,12 @@ def int8_sdpa( ) if ( - query.get_dtype() is not torch.uint8 - or key.get_dtype() is not torch.uint8 - or value.get_dtype() is not torch.uint8 + query.get_dtype() not in quantize_dtypes + or key.get_dtype() not in quantize_dtypes + or value.get_dtype() not in quantize_dtypes ): raise NotImplementedError( - "Only `torch.uint8` is supported in Int8 SDPA template for CPU device. " + "Only `torch.uint8` or `torch.float8_e4m3fn` is supported in Quantized SDPA template for CPU device. " f"Found input tensors are `{query.get_dtype()}`,`{key.get_dtype()}`,`{value.get_dtype()}`." ) @@ -85,8 +92,8 @@ def int8_sdpa( if attn_mask is not None: input_nodes.append(attn_mask) - # use template if machine has amx - if torch._C._cpu._is_amx_tile_supported(): + # use template if machine has amx, only support uint8 for now + if torch._C._cpu._is_amx_tile_supported() and query.get_dtype() is torch.uint8: CppInt8SdpaTemplate.add_choices( choices=choices, input_nodes=input_nodes, @@ -106,7 +113,7 @@ def int8_sdpa( if len(choices) == 0: choices.append( - op_int8_sdpa.bind( + op_qsdpa.bind( input_nodes=input_nodes, layout=layout, scale=scale, @@ -130,11 +137,11 @@ def int8_sdpa( ] return autotune_select_algorithm( - "int8_sdpa", + "qsdpa", choices, inputs_for_autotuning, layout, ) -register_int8_sdpa() +register_qsdpa() diff --git a/torchao/prototype/moe_quant/llama4_quant.py b/torchao/prototype/moe_quant/llama4_quant.py index 36e684d47d..ae6abccea5 100644 --- a/torchao/prototype/moe_quant/llama4_quant.py +++ b/torchao/prototype/moe_quant/llama4_quant.py @@ -58,7 +58,7 @@ def convert_fn(module): model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" -model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) +model = Llama4ForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(model_id) _replace_with_custom_fn_if_matches_filter( @@ -75,7 +75,12 @@ def convert_fn(module): ) from torchao.quantization import Int4WeightOnlyConfig, quantize_ -quantize_(model, MoEQuantConfig(Int4WeightOnlyConfig()), cond_ffn_filter, device="cuda") +quantize_( + model, + MoEQuantConfig(Int4WeightOnlyConfig(version=1)), + cond_ffn_filter, + device="cuda", +) model.cuda() diff --git a/torchao/prototype/moe_quant/utils.py b/torchao/prototype/moe_quant/utils.py index 0e75de2ee4..28291afdf4 100644 --- a/torchao/prototype/moe_quant/utils.py +++ b/torchao/prototype/moe_quant/utils.py @@ -20,18 +20,7 @@ dataclass, register_quantize_module_handler, ) -from torchao.utils import fill_defaults - - -class DummyModule(torch.nn.Module): - """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a - DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. - """ - - def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - super().__init__() - self.weight = weight - self.bias = bias +from torchao.utils import DummyModule, fill_defaults class FakeExtraDimTensor(torch.Tensor): diff --git a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py b/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py deleted file mode 100644 index c229eaeb71..0000000000 --- a/torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. -# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py -import argparse -import itertools -import time -from dataclasses import dataclass -from typing import List - -import torch -from tabulate import tabulate -from tqdm import tqdm - -from torchao.prototype.moe_training import _scaled_grouped_mm - -device = torch.device("cuda") - -# Needed since changing args to function causes recompiles -torch._dynamo.config.cache_size_limit = 1000 - - -@dataclass(frozen=True) -class ExperimentConfig: - high_precision_dtype: torch.dtype - A_shape: tuple[int] - B_shape: tuple[int] - - -@dataclass(frozen=True) -class ExperimentResult: - time_us: float - - -@dataclass(frozen=True) -class Experiment: - config: ExperimentConfig - result: ExperimentResult - - -def get_configs() -> List[ExperimentConfig]: - A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)] - B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)] - high_precision_dtypes = [torch.bfloat16] - configs = [] - for A_shape, B_shape, high_precision_dtype in itertools.product( - A_shapes, - B_shapes, - high_precision_dtypes, - ): - configs.append( - ExperimentConfig( - A_shape=A_shape, - B_shape=B_shape, - high_precision_dtype=high_precision_dtype, - ) - ) - return configs - - -def run_experiment( - config: ExperimentConfig, args: argparse.Namespace -) -> ExperimentResult: - # define test inputs - A = torch.randn( - *config.A_shape, - dtype=config.high_precision_dtype, - device=device, - requires_grad=True, - ) - B_t = torch.randn( - *config.B_shape, - dtype=config.high_precision_dtype, - device=device, - requires_grad=True, - ).transpose(-2, -1) - - # - configure input to be row-major with groups divided along the column dimension, - # representing the left operand of grad_weight = grad_output_t @ input - # that occurs in the backward pass of the differentiable scaled grouped mm. - # - the transposed tensor in col-major format with groups along the row dimension, - # which represents the right operand. - n_groups = config.B_shape[0] - group_size = A.shape[0] // n_groups - offs = torch.arange( - group_size, - group_size * n_groups + 1, - group_size, - device=device, - dtype=torch.int32, - ) - - def warmup(func, *args, **kwargs): - for _ in range(10): - func(*args, **kwargs) - - def forward_backward(A, B_t, offs): - out = _scaled_grouped_mm( - A, - B_t, - offs=offs, - out_dtype=torch.bfloat16, - ) - out.sum().backward() - torch.cuda.synchronize() - - # benchmark torch - torch_func = torch.compile(forward_backward) if args.compile else forward_backward - warmup(torch_func, A, B_t, offs) - start_time_ns = time.perf_counter_ns() - torch_func(A, B_t, offs) - torch_time_ns = time.perf_counter_ns() - start_time_ns - time_us = torch_time_ns / 1e3 - - return ExperimentResult( - time_us=round(time_us, 3), - ) - - -def print_results(experiments: List[Experiment]): - headers = [ - "A_shape", - "B_shape", - "time_us", - ] - rows = [] - for experiment in experiments: - A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})" - B_shape = f"({experiment.config.B_shape[0]}, {experiment.config.B_shape[1]}, {experiment.config.B_shape[2]})" - rows.append( - [ - A_shape, - B_shape, - experiment.result.time_us, - ] - ) - print(tabulate(rows, headers=headers)) - - -def main(args: argparse.Namespace): - torch.random.manual_seed(123) - configs = get_configs() - results = [] - for config in tqdm(configs): - result = run_experiment(config, args) - results.append(Experiment(config=config, result=result)) - - # Use Tabulate to print results - print_results(results) - - -if __name__ == "__main__": - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument("--compile", action="store_true") - args = arg_parser.parse_args() - main(args) diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py index 51af0fd956..c6492c9dbd 100644 --- a/torchao/prototype/moe_training/conversion_utils.py +++ b/torchao/prototype/moe_training/conversion_utils.py @@ -1,13 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging +from enum import Enum from typing import Callable, Optional from torch import nn from torchao.core.config import AOBaseConfig -from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor from torchao.quantization.transform_module import ( register_quantize_module_handler, ) +logger: logging.Logger = logging.getLogger(__name__) + + +class MoEScalingType(Enum): + FP8_ROWWISE = "fp8_rowwise" + MXFP8 = "mxfp8" + class MoETrainingConfig(AOBaseConfig): """ @@ -28,6 +41,10 @@ class MoETrainingConfig(AOBaseConfig): For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor. """ + def __init__(self, scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE): + super().__init__() + self.scaling_type = scaling_type + @register_quantize_module_handler(MoETrainingConfig) def _moe_training_transform( @@ -68,6 +85,8 @@ def _swap_params( Returns: nn.Module: The modified module with swapped linear layers. """ + from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor + if isinstance(module, nn.Parameter) and ( module_filter_fn is None or module_filter_fn(module, "") ): @@ -76,7 +95,7 @@ def _swap_params( f"Does not support a root nn.Parameter with children: {module}" ) if not isinstance(module.data, ScaledGroupedMMTensor): - new_data = ScaledGroupedMMTensor(module.data) + new_data = ScaledGroupedMMTensor(module.data, config.scaling_type) return nn.Parameter(new_data, requires_grad=module.requires_grad) return module @@ -102,10 +121,13 @@ def post_order_traversal( for param_name, param in module.named_parameters(recurse=False): if not isinstance(param.data, ScaledGroupedMMTensor): new_param = nn.Parameter( - ScaledGroupedMMTensor(param), requires_grad=param.requires_grad + ScaledGroupedMMTensor(param.data, config.scaling_type), + requires_grad=param.requires_grad, ) setattr(module, param_name, new_param) - print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor") + logger.info( + f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor" + ) post_order_traversal(root_module) return root_module diff --git a/torchao/prototype/moe_training/kernels/__init__.py b/torchao/prototype/moe_training/kernels/__init__.py index b5446849b6..0b88cc08a2 100644 --- a/torchao/prototype/moe_training/kernels/__init__.py +++ b/torchao/prototype/moe_training/kernels/__init__.py @@ -1,6 +1,9 @@ +from torchao.prototype.moe_training.kernels.float8_rowwise import ( + triton_fp8_rowwise_3d_transpose_rhs as triton_fp8_rowwise_3d_transpose_rhs, +) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales, + triton_fp8_per_group_colwise_scales as triton_fp8_per_group_colwise_scales, ) from torchao.prototype.moe_training.kernels.jagged_float8_scales import ( - triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales, ) diff --git a/torchao/prototype/moe_training/kernels/float8_rowwise.py b/torchao/prototype/moe_training/kernels/float8_rowwise.py new file mode 100644 index 0000000000..7d83090741 --- /dev/null +++ b/torchao/prototype/moe_training/kernels/float8_rowwise.py @@ -0,0 +1,469 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +EPS = 1e-12 + +FP8_DTYPE_MAP = { + torch.int8: tl.int8, + torch.int16: tl.int16, + torch.int32: tl.int32, + torch.int64: tl.int64, + torch.float8_e4m3fn: tl.float8e4nv, + torch.float8_e5m2: tl.float8e5, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32, + torch.float64: tl.float64, +} + +block_sizes_n = [128] # large dim (output_features) +block_sizes_k = [128] # small dim (input_features) +num_warps = [4] +num_stages = [4] +atomic_kernel_configs_2D = [ + triton.Config( + {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}, + num_warps=warps, + num_stages=stages, + ) + for block_size_n in block_sizes_n + for block_size_k in block_sizes_k + for warps in num_warps + for stages in num_stages +] + + +@torch.library.custom_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={}) +def triton_fp8_rowwise_3d_transpose_rhs( + hp_tensor: torch.Tensor, # (E, K, N) + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 3, "input tensor must be 3D" + + tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] + tl_output_dtype = FP8_DTYPE_MAP[output_dtype] + + fp8_dtype_min = torch.finfo(output_dtype).min + fp8_dtype_max = torch.finfo(output_dtype).max + + e, k, n = hp_tensor.shape + + # allocate on-device buffers for output and scales + # output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout + output_buffer = torch.empty( + (e, n, k), dtype=output_dtype, device=hp_tensor.device + ).as_strided((e, n, k), (n * k, 1, n)) + + scales_buffer = torch.full( + (e, k), float("inf"), dtype=torch.float32, device=hp_tensor.device + ) + + # parallelize across experts, and for each expert, parallelize across rows and cols + grid = lambda meta: ( + e, + triton.cdiv(k, meta["BLOCK_SIZE_K"]), + triton.cdiv(n, meta["BLOCK_SIZE_N"]), + ) + + # compute scales + _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel[grid]( + hp_tensor, + hp_tensor.stride(0), + hp_tensor.stride(1), + hp_tensor.stride(2), + scales_buffer, + scales_buffer.stride(0), + scales_buffer.stride(1), + e, + n, + k, + fp8_dtype_min, + fp8_dtype_max, + tl_input_dtype, + round_scales_to_power_of_2=round_scales_to_power_of_2, + EPS=EPS, + ) + + # perform casting + _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel[grid]( + hp_tensor, + hp_tensor.stride(0), + hp_tensor.stride(1), + hp_tensor.stride(2), + output_buffer, + output_buffer.stride(0), + output_buffer.stride(1), + output_buffer.stride(2), + scales_buffer, + scales_buffer.stride(0), + scales_buffer.stride(1), + e, + n, + k, + fp8_dtype_min, + fp8_dtype_max, + tl_input_dtype, + tl_output_dtype, + ) + return output_buffer, scales_buffer + + +@triton_fp8_rowwise_3d_transpose_rhs.register_fake +def _fake_triton_fp8_rowwise_3d_transpose_rhs( + hp_tensor: torch.Tensor, # (E, K, N) + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 3, "input tensor must be 3D" + e, k, n = hp_tensor.shape + output_buffer = torch.empty( + (e, n, k), dtype=output_dtype, device=hp_tensor.device + ).as_strided((e, n, k), (n * k, 1, n)) + + scales_buffer = torch.empty((e, k), dtype=torch.float32, device=hp_tensor.device) + return output_buffer, scales_buffer + + +@triton.autotune(configs=atomic_kernel_configs_2D, key=["K", "N"]) +@triton.jit +def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel( + input_ptr, + stride_input_dim0: tl.int64, + stride_input_dim1, + stride_input_dim2, + scales_ptr, + stride_scales_dim0: int, + stride_scales_dim1, + E: int, + N: int, + K: int, + fp8_dtype_min: tl.constexpr, + fp8_dtype_max: tl.constexpr, + input_dtype: tl.constexpr, + round_scales_to_power_of_2: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + EPS: tl.constexpr, +): + # parallelize across experts, rows, and cols + expert_idx = tl.program_id(0) + k_block_idx = tl.program_id(1) + n_block_idx = tl.program_id(2) + + # compute offsets for each dimension + k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # load block of input data, shape (K, N) + input_offs = ( + expert_idx * stride_input_dim0 + + k_offs[:, None] * stride_input_dim1 + + (n_offs[None, :] * stride_input_dim2) + ) + input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N) + input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0) + + # In a normal torch implementation, we should transpose the tensor then compute the amax + # along the dim1 (N), to compute colwise scales for a RHS operand of a scaled grouped gemm: + # input_data = input_data.transpose(-2,-1) # (E, K, N) -> (E, N, K) + # amaxes = input_data.abs().max(dim=1) # (E, N, K) -> (E, 1, K) + # + # Here, we are reading a (K, N) chunk for a given E, and computing the amax along the dim=1 (N) + # to compute an equivalent scale of shape (K,) for this chunk of the expert. + # We then use atomic min to compute the final scale for these logical columns of the transposed tensor. + # + # Later, we will use this scale to cast the same (K,N) input chunk to fp8 and transpose it to (N, K) before + # writing it to the output tensor. + # ((K, N) * (K, 1))^T = (N, K) + amaxes = tl.max(tl.abs(input_data), axis=1).to(tl.float64) # (K,) + scales = (fp8_dtype_max / tl.clamp(amaxes, min=EPS, max=float("inf"))).to( + tl.float32 + ) + if round_scales_to_power_of_2: + scales = tl.exp2(tl.floor(tl.log2(scales))) + + # compute global scales using atomics with local scales - shape (1, K) + scales_offs = ( + expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1 + ) + scales_mask = k_offs[None, :] < K + tl.atomic_min(scales_ptr + scales_offs, scales[None, :], mask=scales_mask) + + +@triton.autotune(configs=atomic_kernel_configs_2D, key=["num_elements"]) +@triton.jit +def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel( + input_ptr, + stride_input_dim0: tl.int64, + stride_input_dim1, + stride_input_dim2, + output_ptr, + stride_output_dim0: tl.int64, + stride_output_dim1, + stride_output_dim2, + scales_ptr, + stride_scales_dim0: int, + stride_scales_dim1, + E: int, + N: int, + K: int, + fp8_dtype_min: tl.constexpr, + fp8_dtype_max: tl.constexpr, + input_dtype: tl.constexpr, + output_dtype: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + # parallelize across experts, rows, and cols + expert_idx = tl.program_id(0) + k_block_idx = tl.program_id(1) + n_block_idx = tl.program_id(2) + + # compute offsets for each dimension + k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # load block of input data for this expert - shape (K, N) + input_offs = ( + expert_idx * stride_input_dim0 + + k_offs[:, None] * stride_input_dim1 + + (n_offs[None, :] * stride_input_dim2) + ) + input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N) + input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0) + input_data = input_data.trans(1, 0) # (K, N) -> (N, K) + + # load global scales for this block of the given expert - shape (1, K) + scales_offs = ( + expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1 + ) + scales_mask = k_offs[None, :] < K + scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0) + + # transpose data and apply scales - shape (N,K) * (1,K) = (N,K) + output_data = tl.clamp( + input_data * scales, min=fp8_dtype_min, max=fp8_dtype_max + ).to(output_dtype) + + # store transpose and store output data - shape (N, K) + output_offs = ( + expert_idx * stride_output_dim0 + + n_offs[:, None] * stride_output_dim1 + + (k_offs[None, :] * stride_output_dim2) + ) + output_mask = (n_offs[:, None] < N) & (k_offs[None, :] < K) + tl.store(output_ptr + output_offs, output_data, mask=output_mask) + + +block_sizes_n = [ + 64, +] # large dim (output_features) +block_sizes_k = [128] # small dim (input_features) +num_warps = [8] +num_stages = [6] +reduction_kernel_configs_2D = [ + triton.Config( + {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}, + num_warps=warps, + num_stages=stages, + ) + for block_size_n in block_sizes_n + for block_size_k in block_sizes_k + for warps in num_warps + for stages in num_stages +] + + +@triton.autotune(configs=reduction_kernel_configs_2D, key=["K", "N"]) +@triton.jit +def _triton_fp8_rowwise_3d_transpose_rhs_fused_reduction_kernel( + input_ptr, + stride_input_dim0: tl.int64, + stride_input_dim1, + stride_input_dim2, + output_ptr, + stride_output_dim0: tl.int64, + stride_output_dim1, + stride_output_dim2, + scales_ptr, + stride_scales_dim0: int, + stride_scales_dim1, + E: int, + N: int, + K: int, + fp8_dtype_min: tl.constexpr, + fp8_dtype_max: tl.constexpr, + input_dtype: tl.constexpr, + output_dtype: tl.constexpr, + round_scales_to_power_of_2: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + EPS: tl.constexpr, +): + # This kernel parallelizes across experts and K blocks + # Each program computes scales for one K block of one expert + expert_idx = tl.program_id(0) + k_block_idx = tl.program_id(1) + + # Compute K offsets for this block + k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + k_mask = k_offs < K + + # Initialize row maxes for this K block + row_maxes = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float64) - float("inf") + + # First pass: compute row-wise maximum absolute values across all N + for n_block_start in range(0, N, BLOCK_SIZE_N): + n_offs = n_block_start + tl.arange(0, BLOCK_SIZE_N) + n_mask = n_offs < N + + # Load block of input data - shape (K, N) + input_offs = ( + expert_idx * stride_input_dim0 + + k_offs[:, None] * stride_input_dim1 + + n_offs[None, :] * stride_input_dim2 + ) + input_mask = k_mask[:, None] & n_mask[None, :] + input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0) + + # Compute row-wise max for this N block + block_row_maxes = tl.max(tl.abs(input_data), axis=1) + + # Update running maxes + row_maxes = tl.maximum(row_maxes, block_row_maxes) + + # Convert row maxes to scales + clamped_maxes = tl.clamp(row_maxes, min=EPS, max=float("inf")) + scales = (fp8_dtype_max / clamped_maxes.to(tl.float64)).to(tl.float32) + + if round_scales_to_power_of_2: + scales = tl.exp2(tl.floor(tl.log2(scales))) + + # Store computed scales for this K block + scales_offs = expert_idx * stride_scales_dim0 + k_offs * stride_scales_dim1 + tl.store(scales_ptr + scales_offs, scales, mask=k_mask) + + # Second pass: apply scales and transpose data for output + for n_block_start in range(0, N, BLOCK_SIZE_N): + n_offs = n_block_start + tl.arange(0, BLOCK_SIZE_N) + n_mask = n_offs < N + + # Load block of input data - shape (K, N) + input_offs = ( + expert_idx * stride_input_dim0 + + k_offs[:, None] * stride_input_dim1 + + n_offs[None, :] * stride_input_dim2 + ) + input_mask = k_mask[:, None] & n_mask[None, :] + input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0) + + # Transpose data: (K, N) -> (N, K) + input_data_transposed = input_data.trans(1, 0) + + # Apply scales: (N, K) * (1, K) = (N, K) + scaled_data = input_data_transposed * scales[None, :] + + # Clamp and cast to output dtype + output_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to( + output_dtype + ) + + # Store transposed output - shape (N, K) + output_offs = ( + expert_idx * stride_output_dim0 + + n_offs[:, None] * stride_output_dim1 + + k_offs[None, :] * stride_output_dim2 + ) + output_mask = n_mask[:, None] & k_mask[None, :] + tl.store(output_ptr + output_offs, output_data, mask=output_mask) + + +@torch.library.custom_op( + "torchao::triton_fp8_rowwise_transpose_rhs_fused", mutates_args={} +) +def triton_fp8_rowwise_3d_transpose_rhs_fused_reduction( + hp_tensor: torch.Tensor, # (E, K, N) + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Equivalent fused Triton kernel to triton_fp8_rowwise_3d_transpose_rhs that uses + reduction to calculate rowwise scales instead of atomic operations. + + This kernel fuses the scale computation and casting into a single kernel, + avoiding the need for atomic operations by using reduction operations. + """ + assert hp_tensor.ndim == 3, "input tensor must be 3D" + + tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] + tl_output_dtype = FP8_DTYPE_MAP[output_dtype] + + fp8_dtype_min = torch.finfo(output_dtype).min + fp8_dtype_max = torch.finfo(output_dtype).max + + e, k, n = hp_tensor.shape + + # allocate on-device buffers for output and scales + # output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout + output_buffer = torch.empty( + (e, n, k), dtype=output_dtype, device=hp_tensor.device + ).as_strided((e, n, k), (n * k, 1, n)) + + scales_buffer = torch.empty((e, k), dtype=torch.float32, device=hp_tensor.device) + + # Use a grid that parallelizes across experts and K blocks + # Each program handles one K block of one expert + grid = lambda meta: (e, triton.cdiv(k, meta["BLOCK_SIZE_K"]), 1) + + # Single fused kernel that computes scales using reduction and performs casting + _triton_fp8_rowwise_3d_transpose_rhs_fused_reduction_kernel[grid]( + hp_tensor, + hp_tensor.stride(0), + hp_tensor.stride(1), + hp_tensor.stride(2), + output_buffer, + output_buffer.stride(0), + output_buffer.stride(1), + output_buffer.stride(2), + scales_buffer, + scales_buffer.stride(0), + scales_buffer.stride(1), + e, + n, + k, + fp8_dtype_min, + fp8_dtype_max, + tl_input_dtype, + tl_output_dtype, + round_scales_to_power_of_2=round_scales_to_power_of_2, + EPS=EPS, + ) + + return output_buffer, scales_buffer + + +@triton_fp8_rowwise_3d_transpose_rhs_fused_reduction.register_fake +def _fake_triton_fp8_rowwise_3d_transpose_rhs_fused_reduction( + hp_tensor: torch.Tensor, # (E, K, N) + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 3, "input tensor must be 3D" + e, k, n = hp_tensor.shape + output_buffer = torch.empty( + (e, n, k), dtype=output_dtype, device=hp_tensor.device + ).as_strided((e, n, k), (n * k, 1, n)) + + scales_buffer = torch.empty((e, k), dtype=torch.float32, device=hp_tensor.device) + return output_buffer, scales_buffer diff --git a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py index 3a497bf4a6..f3bda41b1e 100644 --- a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py +++ b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py @@ -16,8 +16,6 @@ import triton import triton.language as tl -from torchao.prototype.moe_training.utils import _is_column_major - EPS = 1e-12 FP8_DTYPE_MAP = { @@ -33,17 +31,27 @@ torch.float64: tl.float64, } -block_sizes = [128, 256] +block_sizes = [32] # [16, 32, 64] +block_sizes_iter = [128] # [64, 128, 256] +num_warps = [4] +num_stages = [3] kernel_configs_2D = [ triton.Config( - {"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols} + {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter}, + num_warps=warps, + num_stages=stages, ) - for block_size_rows in block_sizes - for block_size_cols in block_sizes + for block_size in block_sizes + for block_size_iter in block_sizes_iter + for warps in num_warps + for stages in num_stages ] -def triton_fp8_row_major_jagged_rowwise_scales( +@torch.library.custom_op( + "torchao::triton_fp8_per_group_rowwise_scales", mutates_args={} +) +def triton_fp8_per_group_rowwise_scales( hp_tensor: torch.Tensor, offsets: torch.Tensor, output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -65,7 +73,6 @@ def triton_fp8_row_major_jagged_rowwise_scales( - jagged rowwise scales (i.e., rowwise scales for each group) """ assert hp_tensor.ndim == 2, "input tensor must be 2D" - assert hp_tensor.is_contiguous(), "input tensor must be contiguous" num_elements = hp_tensor.numel() tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -78,19 +85,17 @@ def triton_fp8_row_major_jagged_rowwise_scales( n_groups = offsets.numel() # allocate on-device buffers for output and scales - output_buffer = torch.empty_like( - hp_tensor, dtype=output_dtype, device=hp_tensor.device - ) + output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device) scales_buffer = torch.empty( (m * n_groups), dtype=torch.float32, device=hp_tensor.device ) # parallelize across rows and groups (offsets) grid = lambda meta: ( - triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]), + triton.cdiv(m, meta["BLOCK_SIZE"]), offsets.numel(), ) - _triton_fp8_row_major_jagged_rowwise_scales[grid]( + _triton_fp8_per_group_rowwise_scales_kernel[grid]( hp_tensor, offsets, output_buffer, @@ -112,9 +117,33 @@ def triton_fp8_row_major_jagged_rowwise_scales( return output_buffer, scales_buffer -@triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) +@triton_fp8_per_group_rowwise_scales.register_fake +def _fake_triton_fp8_per_group_rowwise_scales_kernel( + hp_tensor: torch.Tensor, + offsets: torch.Tensor, + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 2, "input tensor must be 2D" + m, k = hp_tensor.shape + n_groups = offsets.numel() + output = torch.empty_like(hp_tensor, dtype=output_dtype).as_strided( + (m, k), # shape + (k, 1), # stride + ) + scales = torch.empty((m * n_groups), dtype=torch.float32, device=hp_tensor.device) + return output, scales + + +# This kernel is used on grad_output.t() which has shape (K, M), +# before the calculation `grad_B = grad_output_t @ input`. +# However, in this code, we use the conventional dim names (M, K) +# so the kernel is easily interpretable in a standalone fasion. +# The tokens per expert will vary per iteration, so don't want +# to recompile on `token` dim (K, in this case) changes. +@triton.autotune(configs=kernel_configs_2D, key=["M"]) @triton.jit -def _triton_fp8_row_major_jagged_rowwise_scales( +def _triton_fp8_per_group_rowwise_scales_kernel( input_ptr, offsets_ptr, out_ptr, @@ -131,8 +160,8 @@ def _triton_fp8_row_major_jagged_rowwise_scales( input_dtype: tl.constexpr, output_dtype: tl.constexpr, round_scales_to_power_of_2: tl.constexpr, - BLOCK_SIZE_ROWS: tl.constexpr, - BLOCK_SIZE_COLS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_ITER: tl.constexpr, EPS: tl.constexpr, ): # parallel across rows and groups (offsets) @@ -144,12 +173,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales( offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0 ) group_col_end_idx = tl.load(offsets_ptr + offset_idx) - block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS) + block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # compute rowwise amaxes for this group - amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=input_dtype) - for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS): - block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS) + amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype) + for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER): + block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER) block_offs = ( block_row_offs[:, None] * stride_input_row + block_col_offs[None, :] * stride_input_col @@ -177,12 +206,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales( # store rowwise scales for each group in contiguous memory: # [group0_row0, group_0_row1, ..., group2_row0, group2_row1] scales_offs = block_row_offs + (M * offset_idx) - scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M + scales_mask = tl.arange(0, BLOCK_SIZE) < M tl.store(scales_ptr + scales_offs, scales, mask=scales_mask) # perform float8 conversion for this group - for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS): - block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS) + for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER): + block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER) block_offs = ( block_row_offs[:, None] * stride_input_row + block_col_offs[None, :] * stride_input_col @@ -204,7 +233,10 @@ def _triton_fp8_row_major_jagged_rowwise_scales( tl.store(out_ptr + out_offs, fp8_data, mask=block_mask) -def triton_fp8_col_major_jagged_colwise_scales( +@torch.library.custom_op( + "torchao::triton_fp8_per_group_colwise_scales", mutates_args={} +) +def triton_fp8_per_group_colwise_scales( hp_tensor: torch.Tensor, offsets: torch.Tensor, output_dtype: torch.dtype = torch.float8_e4m3fn, @@ -226,7 +258,6 @@ def triton_fp8_col_major_jagged_colwise_scales( - jagged column-wise scales (i.e., column-wise scales for each group) """ assert hp_tensor.ndim == 2, "input tensor must be 2D" - assert _is_column_major(hp_tensor), "input tensor must be column-major" num_elements = hp_tensor.numel() tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] @@ -238,20 +269,21 @@ def triton_fp8_col_major_jagged_colwise_scales( k, n = hp_tensor.shape n_groups = offsets.numel() - # allocate on-device buffers for output and scales + # Output buffer in column major output_buffer = torch.empty_like( hp_tensor, dtype=output_dtype, device=hp_tensor.device - ) + ).as_strided(hp_tensor.size(), (1, k)) + scales_buffer = torch.empty( (n * n_groups), dtype=torch.float32, device=hp_tensor.device ) # parallelize across columns and groups (offsets) grid = lambda meta: ( - triton.cdiv(n, meta["BLOCK_SIZE_COLS"]), + triton.cdiv(n, meta["BLOCK_SIZE"]), offsets.numel(), ) - _triton_fp8_col_major_jagged_colwise_scales[grid]( + _triton_fp8_per_group_colwise_scales_kernel[grid]( hp_tensor, offsets, output_buffer, @@ -273,9 +305,33 @@ def triton_fp8_col_major_jagged_colwise_scales( return output_buffer, scales_buffer -@triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) +@triton_fp8_per_group_colwise_scales.register_fake +def _fake_triton_fp8_per_group_colwise_scales( + hp_tensor: torch.Tensor, + offsets: torch.Tensor, + output_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert hp_tensor.ndim == 2, "input tensor must be 2D" + k, n = hp_tensor.shape + n_groups = offsets.numel() + output_buffer = torch.empty_like( + hp_tensor, dtype=output_dtype, device=hp_tensor.device + ).as_strided(hp_tensor.size(), (1, k)) + + scales_buffer = torch.empty( + (n * n_groups), dtype=torch.float32, device=hp_tensor.device + ) + return output_buffer, scales_buffer + + +# This kernel is used on `input` which has shape (M, K), +# before the calculation `grad_B = grad_output_t @ input`. +# The tokens per expert will vary per iteration, so don't want +# to recompile on `token` dim (M) changes. +@triton.autotune(configs=kernel_configs_2D, key=["K"]) @triton.jit -def _triton_fp8_col_major_jagged_colwise_scales( +def _triton_fp8_per_group_colwise_scales_kernel( input_ptr, offsets_ptr, out_ptr, @@ -292,8 +348,8 @@ def _triton_fp8_col_major_jagged_colwise_scales( input_dtype: tl.constexpr, output_dtype: tl.constexpr, round_scales_to_power_of_2: tl.constexpr, - BLOCK_SIZE_ROWS: tl.constexpr, - BLOCK_SIZE_COLS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_ITER: tl.constexpr, EPS: tl.constexpr, ): # parallel across columns and groups (offsets) @@ -305,12 +361,12 @@ def _triton_fp8_col_major_jagged_colwise_scales( offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0 ) group_row_end_idx = tl.load(offsets_ptr + offset_idx) - block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS) + block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # compute colwise amaxes for this group - amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=input_dtype) - for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS): - block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS) + amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype) + for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER): + block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER) block_offs = ( block_row_offs[:, None] * stride_input_row + block_col_offs[None, :] * stride_input_col @@ -339,12 +395,12 @@ def _triton_fp8_col_major_jagged_colwise_scales( # [group0_col0, group_0_col1, ..., group2_col0, group2_col1] # note: input tensor is in col-major memory layout. scales_offs = block_col_offs + (N * offset_idx) - scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N + scales_mask = tl.arange(0, BLOCK_SIZE) < N tl.store(scales_ptr + scales_offs, scales, mask=scales_mask) # perform float8 conversion for this group - for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS): - block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS) + for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER): + block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER) block_offs = ( block_row_offs[:, None] * stride_input_row + block_col_offs[None, :] * stride_input_col diff --git a/torchao/prototype/moe_training/kernels/mxfp8.py b/torchao/prototype/moe_training/kernels/mxfp8.py new file mode 100644 index 0000000000..353688f185 --- /dev/null +++ b/torchao/prototype/moe_training/kernels/mxfp8.py @@ -0,0 +1,725 @@ +import logging +from typing import Tuple + +import torch +import triton +import triton.language as tl +from torch import Tensor +from torch.library import triton_op, wrap_triton + +from torchao.prototype.mx_formats.utils import to_blocked +from torchao.utils import ( + ceil_div, + is_sm_at_least_100, +) + + +def torch_to_blocked_2d_M_groups( + x_scales: Tensor, group_offs: Tensor, K: int, block_size: int = 32 +) -> Tuple[Tensor, Tensor]: + """ + Convert scales to blocked format for a 2D tensor (input activations / token groups), + where groups are along the total_M dimension (rows). + + Args: + x_scales: Tensor with per group scales in blocked format concatenated into one tensor. + group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the total_M dimension. + total_M: total size of all groups summed together + K: K dim size + + Returns: + blocked_scales: Tensor + start_row_after_padding: Tensor of shape (num_groups,) which contains the start row after padding for each group. + """ + + assert x_scales.ndim == 2, "x_scales must be 2D" + assert block_size == 32, "Only block_size=32 is supported for now" + total_M, _ = x_scales.shape + num_groups = group_offs.shape[0] + + # Each group will require a variable amount of padding, so to avoid d2h sync causing by iterating over each group, + # the Triton kernenl will use an upper bound of adding 128 padding rows to each group. + # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl). + total_M_padded = total_M + num_groups * 128 + blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size) + start_row_after_padding_list = [0] + group_start_idx = 0 + for i, group_end_idx in enumerate(group_offs.tolist()): + group_size = group_end_idx - group_start_idx + prev_start_row_after_padding = start_row_after_padding_list[i] + if group_size == 0: + start_row_after_padding_list.append(prev_start_row_after_padding) + continue + + # Convert group scales to blocked format + group_scales = x_scales[group_start_idx:group_end_idx] + group_scales_blocked = to_blocked(group_scales) + + # Calculate the start row after padding + scaling_groups_per_row = K // block_size + rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row + new_start_row = prev_start_row_after_padding + rows_for_group + start_row_after_padding_list.append(new_start_row) + + # Write output to subtensor + group_rows_padded = ceil_div(group_size, 128) * 128 + blocked_scales[ + prev_start_row_after_padding : prev_start_row_after_padding + + group_rows_padded, + :, + ] = group_scales_blocked.reshape(-1, K // block_size) + + # Update next group start index + group_start_idx = group_end_idx + + start_row_after_padding = torch.tensor( + start_row_after_padding_list, device=x_scales.device, dtype=torch.int64 + ) + return blocked_scales, start_row_after_padding + + +def torch_to_blocked_2d_K_groups( + x_scales: Tensor, group_offs: Tensor, block_size: int = 32 +) -> Tuple[Tensor, Tensor]: + """ + Convert scales to blocked format for a 2D tensor (input activations), + when groups are along the scaled (K) dimension. + + Args: + x_scales: Tensor with per group scales in blocked format concatenated into one tensor. + group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the total_k dimension. + total_K: total size of all groups summed together + + Returns: + blocked_scales: Tensor + start_row_after_padding: Tensor of shape (num_groups,) which contains the start row after padding for each group. + """ + assert x_scales.ndim == 2, "x_scales must be 2D" + assert block_size == 32, "Only block_size=32 is supported for now" + M, total_K = x_scales.shape + padded_M = ceil_div(M, 128) * 128 + num_groups = group_offs.shape[0] + + # Each group will require a variable amount of padding, so to avoid d2h sync causing by iterating over each group, + # Triton kernel will use an upper bound of adding 4 padding cols to each group. + # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl). + total_K_padded = total_K + num_groups * 4 + blocked_scales = x_scales.new_zeros(padded_M, total_K_padded) + + start_col_after_padding_list = [0] + group_start_idx = 0 + for i, group_end_idx in enumerate(group_offs.tolist()): + group_size = group_end_idx - group_start_idx + prev_start_col_after_padding = start_col_after_padding_list[i] + if group_size == 0: + start_col_after_padding_list.append(prev_start_col_after_padding) + continue + + # Convert group scales to blocked format + group_scales = x_scales[:, group_start_idx:group_end_idx] + group_scales_blocked = to_blocked(group_scales) + cols_after_padding = ceil_div(group_size, 4) * 4 + + # Write output to subtensor + blocked_scales[ + :, + prev_start_col_after_padding : prev_start_col_after_padding + + cols_after_padding, + ] = group_scales_blocked.reshape(-1, cols_after_padding) + + # Calculate the start row after padding + new_start_col = prev_start_col_after_padding + cols_after_padding + start_col_after_padding_list.append(new_start_col) + + # Update next group start index + group_start_idx = group_end_idx + + start_cols_after_padding = torch.tensor( + start_col_after_padding_list, device=x_scales.device, dtype=torch.int64 + ) + return blocked_scales, start_cols_after_padding + + +def torch_to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor: + """ + Convert scales to blocked format for each group for a 3D tensor (expert weights) + + Args: + scales: Tensor of shape (E, N, K//block_size) + group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the + """ + + blocked_scales_list = [] + num_groups = weight_scales.shape[0] + for i in range(num_groups): + group_scales = weight_scales[i] + group_scales_blocked = to_blocked(group_scales) + blocked_scales_list.append(group_scales_blocked) + weight_scales_blocked = torch.stack(blocked_scales_list, dim=0).contiguous() + weight_scales_blocked = weight_scales_blocked.reshape(num_groups, -1) + return weight_scales_blocked + + +def compute_blocked_scale_offsets_for_M_groups(offsets: torch.Tensor): + """ + Given a 1D tensor of input group offsets along the total_M dimension (rows), + compute the starting row offset of the scales for each group after padding to blocked format. + + In effect, this rrounds each integer in a 1D PyTorch tensor up to the nearest multiple of 128. + + Args: + - offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the total_M dimension. + + Returns: + - group_sizes: A 1D PyTorch tensor of integers representing the size of each group. + - starting_row_after_padding: 1D integer tensor representing the starting row after padding each to blocked format. + """ + # Calculate group sizes + zero = torch.tensor([0], dtype=offsets.dtype, device=offsets.device) + group_sizes = torch.diff(offsets, prepend=zero) + + # Round each group size up to the nearest multiple of 128 + rounded_group_sizes = ceil_div(group_sizes, 128) * 128 + + # Calculate the starting row after padding for each group + starting_row_after_padding = torch.cumsum(rounded_group_sizes, dim=0) + + # Must start with 0 + starting_row_after_padding = torch.cat([zero, starting_row_after_padding]) + return group_sizes, starting_row_after_padding + + +def compute_blocked_scale_offsets_for_K_groups( + scale_group_offsets: torch.Tensor, block_size: int = 32 +): + """ + Performs round_up(x, 4) on each element in a 1D offsets tensor, + to compute the starting offsets of each group after scaling along the contraction dimension. + + Args: + offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the total_M dimension. + + Returns: + - starting_col_after_padding: 1D integer tensor representing the starting row after padding each to blocked format. + """ + # Calculate group sizes + zero = torch.tensor( + [0], dtype=scale_group_offsets.dtype, device=scale_group_offsets.device + ) + group_sizes = torch.diff(scale_group_offsets, prepend=zero) + + # After scaling with block_size 32, each group size is rounded up to the nearest multiple of 4 + rounded_group_sizes = ceil_div(group_sizes, 4) * 4 + + # Calculate the starting row after padding for each group + starting_col_after_padding = torch.cumsum(rounded_group_sizes, dim=0) + + # Must start with 0 + starting_col_after_padding = torch.cat([zero, starting_col_after_padding]) + return group_sizes, starting_col_after_padding + + +@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={}) +def triton_mx_block_rearrange_2d_M_groups( + scales_tensor: torch.Tensor, + input_group_end_offsets: torch.Tensor, + output_group_start_offsets: torch.Tensor, +) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale to block-scaled swizzle format, + where groups are along the total_M dimension (rows). + + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor. + input_group_end_offsets: tensor of int32 values representing group end indexes for the input scales + output_group_start_offsets: tensor of int32 values representing pre-computed group start indexes after blocked format padding + Returns: + - Rearranged tensor in block-scaled swizzle format + """ + assert scales_tensor.ndim == 2, "scales tensor must be 2d" + assert scales_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + rows, cols = scales_tensor.shape + num_groups = input_group_end_offsets.shape[0] + + # Final offset is the total number of rows in the tensor. + # Padding needing per group is variable/data dependent, so we just pad each group by + # the upper bound of 128 rows to avoid a d2h sync caused by iterating over each group. + padded_rows = rows + num_groups * 128 + + num_col_blocks = ceil_div(cols, 4) + padded_cols = num_col_blocks * 4 + output = scales_tensor.new_zeros((padded_rows, padded_cols)) + + # Output block stride for the rearranged format + BLOCK_ROWS, BLOCK_COLS = 128, 4 + output_stride_per_block = BLOCK_ROWS * BLOCK_COLS + output_stride_per_row_of_blocks = ( + BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS) + ) + + # We parallelize per group and per col block. + # Rows per group is variable so we just loop through row blocks per group, per col block. + grid = lambda META: ( + num_groups, + num_col_blocks, + ) + wrap_triton(triton_scale_swizzle_M_groups)[grid]( + # Input scales + scales_tensor.view(torch.uint8), + scales_tensor.stride(0), + scales_tensor.stride(1), + rows, + cols, + num_groups, + # Original offsets (to read from) + input_group_end_offsets, + # Output scales tensor and group offsets after padding (to write to) + output.view(torch.uint8), + output.stride(0), + output_group_start_offsets, + output_stride_per_block, + output_stride_per_row_of_blocks, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + return output + + +@triton.jit +def triton_scale_swizzle_M_groups( + scales_ptr, # (M, K//block_size) + scales_stride_dim0, + scales_stride_dim1, + scale_rows, + scale_cols, + num_groups, + orig_offsets, # (num_groups,) + output_scales_ptr, + output_scales_stride_dim0, + output_scales_group_offsets, # (num_groups,) + output_stride_per_block, + output_stride_per_row_of_blocks, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, +): + group_pid = tl.program_id(0) + block_col_pid = tl.program_id(1) + # Input scales row range for this group + input_group_start_row = tl.load( + orig_offsets + group_pid - 1, mask=group_pid > 0, other=0 + ) + input_group_end_row = tl.load( + orig_offsets + group_pid, mask=group_pid < num_groups, other=0 + ) + # Output scales start row we will begin writing to + output_group_start_row = tl.load( + output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0 + ) + # Calculate destination indices for each row and col in block swizzled layout. + # We can reuse this swizzle transformation on each block of data we read. + row_offs = tl.arange(0, BLOCK_ROWS)[:, None] + col_offs = tl.arange(0, BLOCK_COLS)[None, :] + + # Compute desination indices for each elem in block swizzled layout + dest_indices_flat = _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + # For this group and col block, we iterate through row blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales. + # We track how many row blocks we have iterated through. + block_row_id = 0 + current_start_row = input_group_start_row + + # TODO: Investigate if it is possible and beneficial to parallelize along + # row blocks as well, and get rid of this loop. + while current_start_row < input_group_end_row: + # Read block of input scales + block_row_offs = current_start_row + row_offs + block_col_offs = block_col_pid * BLOCK_COLS + col_offs + block_offs = ( + block_row_offs * scales_stride_dim0 + block_col_offs * scales_stride_dim1 + ) + mask = (block_row_offs < input_group_end_row) & (block_col_offs < scale_cols) + input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + # Calculate block offset using provided output block stride + output_block_offsets = ( + output_group_start_row * output_scales_stride_dim0 + + (block_row_id * output_stride_per_row_of_blocks) + + (block_col_pid * output_stride_per_block) + ) + # Apply swizzling for write to gmem + tl.store( + output_scales_ptr + output_block_offsets + dest_indices_flat, + scales_flat, + ) + # Update row block id to next block + block_row_id += 1 + current_start_row += BLOCK_ROWS + + +@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={}) +def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale to block-scaled swizzle format. + + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scale_tensor: Input tensor in row-major format with 8-bit elements + + Returns: + Rearranged tensor in block-scaled swizzle format + """ + assert scale_tensor.ndim == 3, "scales tensor must be 3d" + assert scale_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + + num_groups, rows, cols = scale_tensor.shape + input_stride_dim0 = scale_tensor.stride(0) + input_stride_dim1 = scale_tensor.stride(1) + input_stride_dim2 = scale_tensor.stride(2) + + # Calculate blocks needed and allocate output tensor + num_row_blocks = triton.cdiv(rows, 128) + num_col_blocks = triton.cdiv(cols, 4) + padded_rows = num_row_blocks * 128 + padded_cols = num_col_blocks * 4 + output = scale_tensor.new_empty((num_groups, padded_rows * padded_cols)) + output_stride_dim0 = output.stride(0) + + # We probably want handle multiple blocks per tile but for now keep it simple + BLOCK_ROWS, BLOCK_COLS = 128, 4 + + # Output block stride for the rearranged format + output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS) + + grid = lambda META: ( + num_groups, + num_row_blocks, + num_col_blocks, + ) + + wrap_triton(triton_scale_swizzle_per_group_3d)[grid]( + scale_tensor.view(torch.uint8), + input_stride_dim0, + input_stride_dim1, + input_stride_dim2, + output.view(torch.uint8), + output_stride_dim0, + output_block_stride, + rows, + cols, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + return output + + +@triton.jit +def triton_scale_swizzle_per_group_3d( + input_ptr, + input_stride_dim0, + input_stride_dim1, + input_stride_dim2, + output_ptr, + output_stride_dim0, + output_block_stride, + scale_rows, + scale_cols, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, +): + pid_group = tl.program_id(0) + pid_row = tl.program_id(1) + pid_col = tl.program_id(2) + + # Update base pointers based on this group id + input_ptr += pid_group * input_stride_dim0 + output_ptr += pid_group * output_stride_dim0 + + row_offs = tl.arange(0, BLOCK_ROWS)[:, None] + col_offs = tl.arange(0, BLOCK_COLS)[None, :] + + # Compute desination offs for each elem in block swizzled layout + dest_indices_flat = _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + # Calculate starting row and column for this tile + start_row = pid_row * BLOCK_ROWS + start_col = pid_col * BLOCK_COLS + global_rows = start_row + row_offs + global_cols = start_col + col_offs + + mask = (global_rows < scale_rows) & (global_cols < scale_cols) + + input_scales = tl.load( + input_ptr + global_rows * input_stride_dim1 + global_cols * input_stride_dim2, + mask=mask, + other=0.0, + ) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + + # Calculate block offset using provided output block stride + LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS + block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride) + + tl.store( + output_ptr + block_offset + dest_indices_flat, + scales_flat, + ) + + +@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={}) +def triton_mx_block_rearrange_2d_K_groups( + scales_tensor: torch.Tensor, + input_group_end_offsets: torch.Tensor, + output_group_start_offsets: torch.Tensor, +) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis, + where the groups are along the contraction dimension of the GEMM. + + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor. + input_group_end_offsets: tensor of int32 values representing group end indexes for the input scales + output_group_start_offsets: tensor of int32 values representing pre-computed group start indexes after blocked format padding + Returns: + - Rearranged tensor in block-scaled swizzle format + """ + assert scales_tensor.ndim == 2, "scales tensor must be 2d" + assert scales_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + rows, cols = scales_tensor.shape + # Calculate blocks needed + num_groups = input_group_end_offsets.shape[0] + num_row_blocks = ceil_div(rows, 128) + padded_rows = num_row_blocks * 128 + + # Padding needing per group is variable/data dependent, so we just pad each group by + # the upper bound of 4 cols to avoid a d2h sync caused by iterating over each group. + padded_cols = cols + num_groups * 4 + output = scales_tensor.new_zeros((padded_rows, padded_cols)) + + # Output block stride for the rearranged format + BLOCK_ROWS, BLOCK_COLS = 128, 4 + output_stride_per_block = BLOCK_ROWS * BLOCK_COLS + + # We parallelize per group and per row block. + # Cols per group is variable, so we just loop through col blocks for each group. + grid = lambda META: ( + num_groups, + num_row_blocks, + ) + wrap_triton(triton_scale_swizzle_2d_K_groups)[grid]( + # Input scales + scales_tensor.view(torch.uint8), + scales_tensor.stride(0), + scales_tensor.stride(1), + rows, + cols, + padded_rows, + num_groups, + # Original offsets (to read from) + input_group_end_offsets, + # Output scales tensor and group offsets after padding (to write to) + output.view(torch.uint8), + output_group_start_offsets, + output_stride_per_block, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + DEBUG=False, + ) + return output + + +@triton.jit +def triton_scale_swizzle_2d_K_groups( + scales_ptr, # (M, total_K//block_size) + scales_stride_dim0, + scales_stride_dim1, + scale_rows, + scale_cols, + padded_rows, + num_groups, + orig_offsets, # (num_groups,) + output_scales_ptr, + output_scales_group_offsets, # (num_groups,) + output_stride_per_block, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, + DEBUG: tl.constexpr = False, +): + group_pid = tl.program_id(0) + block_row_pid = tl.program_id(1) + + # Input scales row range for this group + input_group_start_col = tl.load( + orig_offsets + group_pid - 1, mask=group_pid > 0, other=0 + ) + input_group_end_col = tl.load(orig_offsets + group_pid) + + # Output scales start row we will begin writing to + output_group_start_col = tl.load(output_scales_group_offsets + group_pid) + + row_offs = tl.arange(0, BLOCK_ROWS)[:, None] + col_offs = tl.arange(0, BLOCK_COLS)[None, :] + + # Compute desination offs for each elem in block swizzled layout + dest_indices_flat = _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + # For this group and row block, we iterate through col blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales. + # We track how many col blocks we have iterated through. + out_group_base_offset = output_group_start_col * padded_rows + curr_input_start_col = input_group_start_col + curr_out_start_col_block = 0 + while curr_input_start_col < input_group_end_col: + # Read block of input scales + block_row_offs = block_row_pid * BLOCK_ROWS + row_offs + block_col_offs = curr_input_start_col + col_offs + block_offs = ( + block_row_offs * scales_stride_dim0 + block_col_offs * scales_stride_dim1 + ) + mask = (block_row_offs < scale_rows) & (block_col_offs < input_group_end_col) + input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + + # Get offset within the group to add to the group's base offset + num_cols_in_group = input_group_end_col - input_group_start_col + num_col_blocks_in_group = tl.cdiv(num_cols_in_group, BLOCK_COLS) + stride_per_row_of_blocks_in_group = ( + num_col_blocks_in_group * output_stride_per_block + ) + offset_in_group = ( + block_row_pid * stride_per_row_of_blocks_in_group + + curr_out_start_col_block * output_stride_per_block + ) + final_offset = out_group_base_offset + offset_in_group + + # Apply swizzling for write to gmem + tl.store( + output_scales_ptr + final_offset + dest_indices_flat, + scales_flat, + ) + + # Advance to next col block + curr_input_start_col += BLOCK_COLS + curr_out_start_col_block += 1 + + +@triton.jit +def _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, +): + # Calculate destination indices for each row and col in block swizzled layout. + # We can reuse this swizzle transformation on each block of data we read. + r_div_32 = row_offs // 32 + r_mod_32 = row_offs % 32 + + # Rearrange to (32, 4, 4) then to final (32, 16) coordinates + dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs + + # Flatten + dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + return dest_indices_flat + + +mxfp8_cuda_extension_available = False +if is_sm_at_least_100(): + try: + # MXFP8 CUDA kernel is only built on SM100+. Furthermore, + # currently our CI runners are not SM100+, so the user needs to build + # from source. + # TODO(#2932): improve this + from torchao.prototype import mxfp8_cuda + + mxfp8_cuda_extension_available = True + except ImportError: + logging.debug("Skipping import of torchao.prototype.mxfp8_cuda") + +if mxfp8_cuda_extension_available: + # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string. + # Currently we have to use an arbitrary string because custom ops don't support enum + # params. + @torch.library.custom_op("torchao::mxfp8_quantize_cuda_3d", mutates_args=()) + def mxfp8_quantize_cuda_3d( + x: torch.Tensor, + block_size: int = 32, + scaling_mode: str = "floor", + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes a 3D tensor of shape (E,N,K) to MXFP8 format, scaling along N. + + Args: + x (torch.Tensor): Input tensor to be quantized. + block_size (int, optional): Block size for quantization. Defaults to 32. + scaling_mode (str, optional): Scaling mode for quantization. Defaults to "floor". + + Returns: + torch.Tensor: quantized tensor + torch.Tensor: scales tensor + """ + assert x.ndim == 3, "Input tensor must be 3D" + assert x.dtype in (torch.float32, torch.bfloat16), ( + "Input tensor must be float32 or bfloat16" + ) + q_data, scales = mxfp8_cuda.quantize_3d( + x, scale_dim_n=block_size, scaling_mode=scaling_mode + ) + return q_data, scales + + @mxfp8_quantize_cuda_3d.register_fake + def _fake_mxfp8_quantize_cuda_3d( + x: torch.Tensor, + block_size: int = 32, + scaling_mode: str = "floor", + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 3, "Input tensor must be 3D" + assert x.dtype in (torch.float32, torch.bfloat16), ( + "Input tensor must be float32 or bfloat16" + ) + E, N, K = x.shape + # Quantized tensor is in column major layouts + q_data = x.new_empty(x.shape, dtype=torch.float8_e4m3fn).as_strided( + x.shape, (N * K, 1, N) + ) + scales = x.new_empty((E, N // block_size, K), dtype=torch.float8_e8m0fnu) + return q_data, scales + +else: + + def mxfp8_quantize_cuda_3d( + x: torch.Tensor, + block_size: int = 32, + scaling_mode: str = "floor", + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError( + "mxfp8_quantize_cuda_3d is not implemented on this device" + ) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 29adffd831..2d28ade35d 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -4,26 +4,46 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import logging from typing import Optional import torch from torchao.float8.config import ScalingGranularity from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.prototype.moe_training.conversion_utils import MoEScalingType from torchao.prototype.moe_training.kernels import ( - triton_fp8_col_major_jagged_colwise_scales, - triton_fp8_row_major_jagged_rowwise_scales, + triton_fp8_per_group_colwise_scales, + triton_fp8_rowwise_3d_transpose_rhs, +) +from torchao.prototype.moe_training.kernels.mxfp8 import ( + compute_blocked_scale_offsets_for_K_groups, + compute_blocked_scale_offsets_for_M_groups, + mxfp8_quantize_cuda_3d, + triton_mx_block_rearrange_2d_K_groups, + triton_mx_block_rearrange_2d_M_groups, + triton_mx_block_rearrange_per_group_3d, ) from torchao.prototype.moe_training.utils import ( _is_column_major, ) +from torchao.prototype.mx_formats.config import ( + MXFP8Dim1CastKernelChoice, + MXGemmKernelChoice, + ScaleCalculationMode, +) +from torchao.prototype.mx_formats.mx_tensor import to_mx +from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper + +logger: logging.Logger = logging.getLogger(__name__) def _scaled_grouped_mm( A: torch.Tensor, B_t: torch.Tensor, - offs: torch.Tensor, + offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, + scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE, ) -> torch.Tensor: """ This function performs dynamic float8 quantization with row-wise scaling @@ -32,18 +52,30 @@ def _scaled_grouped_mm( Args: A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K) and in row-major memory layout. - B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N) + B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (E, K, N) and in column-major memory layout. offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. - use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True. """ - return _Float8GroupedMM.apply( - A, - B_t, - offs, - out_dtype, - ) + # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging. + if scaling_type == MoEScalingType.FP8_ROWWISE: + return _Float8GroupedMM.apply( + A, + B_t, + offs, + out_dtype, + ) + elif scaling_type == MoEScalingType.MXFP8: + block_size = 32 # TODO: should we make this configurable? plumb it through in a config somehow? + return _MXFP8GroupedMM.apply( + A, + B_t, + offs, + block_size, + out_dtype, + ) + else: + raise ValueError(f"Unsupported scaling type {scaling_type}") class _Float8GroupedMM(torch.autograd.Function): @@ -54,12 +86,11 @@ def forward( ctx, A: torch.Tensor, B_t: torch.Tensor, - offs: torch.Tensor, + offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, - use_triton_for_per_group_scales: bool = True, ) -> torch.Tensor: - # torchao _scaled_grouped_mm only supports A=2D, B=3D. - assert A.ndim == 2, "A must be 2D" + # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D. + assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D" assert B_t.ndim == 3, "B must be 3D" assert A.size(-1) % 16 == 0, ( @@ -76,7 +107,9 @@ def forward( assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, ( "B must be float32 or bfloat16" ) - assert offs.dtype == torch.int32, "offs must be int32" + assert offs is None or offs.dtype == torch.int32, ( + "offs must be int32 tensor or None" + ) # Assert A and B dims are compatible for a scaled grouped GEMM. assert A.size(-1) == B_t.size(-2), ( @@ -87,14 +120,11 @@ def forward( assert not _is_column_major(A), "A must be row-major" # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major. - if not _is_column_major(B_t): - # FSDP will complain if B_t (weights) is not contiguous, we can't require B_t to be column-major. - # TODO: figure out better solution than transposing for each forward pass. - B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1) + assert _is_column_major(B_t), "B must be column-major" # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM. - # A shape: (M, K) - # A_scales shape: (M,1) + # A shape: (M, K) or (B, M, K) + # A_scales shape: (M,1) or (B, M, 1) A_scales = tensor_to_scale( A, torch.float8_e4m3fn, @@ -103,12 +133,12 @@ def forward( round_scales_to_power_of_2=True, ) A_scaled = A.to(torch.float32) * A_scales - A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) + A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn) # Convert B to float8, column-major for right operand of grouped GEMM. - # B shape: (B, K, N) - # B scales must be computed rowwise keeping the outer/final dim, so: - # B_scales shape: (B, 1, N) + # B_t shape: (E, K, N) + # B_t scales must be computed rowwise keeping the outer/final dim, so: + # B_t_scales shape: (E, 1, N) B_t_scales = tensor_to_scale( B_t, torch.float8_e4m3fn, @@ -117,43 +147,31 @@ def forward( round_scales_to_power_of_2=True, ) B_t_scaled = B_t.to(torch.float32) * B_t_scales - B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) - - # Precompute non-transposed B column-major for backward, to save memory by storing the - # low precision B tensor instead of the high precision B tensor. - # In the backward this is needed for grad_A: grad_output @ B. - B = B_t.contiguous().transpose(-2, -1) - - # - B shape: (B, K, N) - # - B scales must be computed rowwise keeping the outer/final dim, so: - # - B_scale shape: (B, 1, N) - B_scales = tensor_to_scale( - B, - torch.float8_e4m3fn, - scaling_granularity=ScalingGranularity.AXISWISE, - axiswise_dim=-2, - round_scales_to_power_of_2=True, - ) - B_scaled = B.to(torch.float32) * B_scales - B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn) + B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) # Store what we need for backward. - ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) + ctx.save_for_backward(A, B_t, offs) ctx.out_dtype = out_dtype # Perform scaled grouped GEMM and return result. # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N) - assert not _is_column_major(A_fp8_row_major), ( + assert not _is_column_major(A_data_row_major), ( "A must be row-major for output = A @ B" ) - assert _is_column_major(B_t_fp8_col_major), ( + assert _is_column_major(B_t_data_col_major), ( "B must be column-major for output = A @ B" ) + + # Squeeze empty dims out of scales, to comply with grouped mm API. + # A_scales shape: (M,1) or (B, M, 1) + # B_t_scales shape: (E, 1, N) + A_scales = A_scales.squeeze(-1) + B_t_scales = B_t_scales.squeeze(1) return torch._scaled_grouped_mm( - A_fp8_row_major, - B_t_fp8_col_major, - A_scales.squeeze().reciprocal(), - B_t_scales.squeeze().reciprocal(), + A_data_row_major, + B_t_data_col_major, + A_scales.reciprocal(), # Reciprocals are needed for rescaling the output. + B_t_scales.reciprocal(), offs, out_dtype=out_dtype, use_fast_accum=True, @@ -161,14 +179,14 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): - A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors + A, B_t, offs = ctx.saved_tensors out_dtype = ctx.out_dtype # Convert grad_output to float8, row-major for left operand of grouped GEMM # needed for grad_A: grad_output @ B # - # grad_output shape: (M, N) - # grad_output_scale shape: (M, 1) + # grad_output shape: (Mg, N) + # grad_output_scale shape: (Mg, 1) grad_output_scales = tensor_to_scale( grad_output, torch.float8_e4m3fn, @@ -177,51 +195,64 @@ def backward(ctx, grad_output: torch.Tensor): round_scales_to_power_of_2=True, ) grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales - grad_output_fp8_row_major = to_fp8_saturated( + grad_output_data_row_major = to_fp8_saturated( grad_output_scaled, torch.float8_e4m3fn ) + # Compute B fp8 column-major for right operand of grouped GEMM: + # grad_A = grad_output @ B. + B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( + B_t._data if hasattr(B_t, "_data") else B_t, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + # Compute grad_A. - # # grad_A = grad_output @ B # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) - assert not _is_column_major(grad_output_fp8_row_major), ( + assert not _is_column_major(grad_output_data_row_major), ( "grad_output must be row-major for grad_A = grad_output @ B" ) - assert _is_column_major(B_fp8_col_major), ( + assert _is_column_major(B_data_col_major), ( "B must be column-major for grad_A = grad_output @ B" ) + + # Squeeze empty dims out of scales, to comply with grouped mm API. + # grad_output_scales shape: (M,1) or (B, M, 1) + # B_scales shape: (E, 1, N) + grad_output_scales = grad_output_scales.squeeze(-1) + B_scales = B_scales.squeeze(1) grad_A = torch._scaled_grouped_mm( - grad_output_fp8_row_major, - B_fp8_col_major, - grad_output_scales.squeeze().reciprocal(), - B_scales.squeeze().reciprocal(), + grad_output_data_row_major, + B_data_col_major, + grad_output_scales.reciprocal(), + B_scales.reciprocal(), offs, out_dtype=out_dtype, use_fast_accum=True, ) - # Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM - # needed for grad_B: grad_output_t @ A - grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous() - - # Convert A to float8, column-major for right operand of grouped GEMM: - # needed for grad_B: grad_output @ A - A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1) - # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups." # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups. - grad_output_t_fp8_row_major, grad_output_t_scales = ( - triton_fp8_row_major_jagged_rowwise_scales( - grad_output_t_row_major, - offs, - torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) + + # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM + # needed for grad_B: grad_output_t @ A + # Use transpose method to avoid uncoalesced memory accesses. + grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales( + grad_output.t() + .contiguous() + .t(), # Quantization is over 2x faster when input is col major, even with this transformation + offs, + torch.float8_e4m3fn, + round_scales_to_power_of_2=True, ) + grad_output_t_data_row_major = grad_out_data_colwise.t() + grad_output_t_scales = grad_out_scales.t() - A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales( - A_col_major, + A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales( + A.t() + .contiguous() + .t(), # Quantization is over 2x faster when input is col major, even with this transformation offs, torch.float8_e4m3fn, round_scales_to_power_of_2=True, @@ -229,16 +260,19 @@ def backward(ctx, grad_output: torch.Tensor): # Compute grad_B = grad_output_t @ A. # grad_B = grad_output_t @ A - # grad_B = (N,M) @ (M,K) = (N,K) - assert not _is_column_major(grad_output_t_fp8_row_major), ( + assert not _is_column_major(grad_output_t_data_row_major), ( "grad_output_t must be row-major for grad_B = grad_output_t @ A" ) - assert _is_column_major(A_fp8_col_major), ( + assert _is_column_major(A_data_col_major), ( "A must be column-major for grad_B = grad_output_t @ A" ) + + # Per-token group scales computed via triton kernels above do not have + # the empty dim like the scales computed via tensor_to_scale, so we need + # don't need to squeeze here. grad_B = torch._scaled_grouped_mm( - grad_output_t_fp8_row_major, - A_fp8_col_major, + grad_output_t_data_row_major, + A_data_col_major, grad_output_t_scales.reciprocal(), A_scales.reciprocal(), offs, @@ -246,3 +280,372 @@ def backward(ctx, grad_output: torch.Tensor): use_fast_accum=True, ) return grad_A, grad_B.transpose(-2, -1), None, None, None, None + + +class _MXFP8GroupedMM(torch.autograd.Function): + """Differentiable implementation of grouped GEMM with dynamic mxpf8 quantization.""" + + @staticmethod + def forward( + ctx, + A: torch.Tensor, + B_t: torch.Tensor, + offs: Optional[torch.Tensor] = None, + block_size: int = 32, + out_dtype: Optional[torch.dtype] = torch.bfloat16, + emulated: bool = False, + ) -> torch.Tensor: + # torchao _scaled_grouped_mm only supports A=2D and B=3D. + assert A.ndim == 2, "A must be 2D" + assert B_t.ndim == 3, "B must be 3D" + assert block_size == 32, "Only block_size=32 is supported" + assert offs is not None, "offs must be provided for 2d-2d and 2d-3d grouped mm" + + # A_data shape: (M, K) + # A_scale shape: (M, K//block_size) + A_scale, A_data = to_mx( + A, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # B_data shape: (E, N, K) + # B_scale shape: (E, N, K//block_size) + B_scales, B_data = to_mx( + B_t.transpose(-2, -1), + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + # Convert scales to blocked format for 2d-3d grouped mm + _, blocked_scales_group_offsets_2d3d = ( + compute_blocked_scale_offsets_for_M_groups(offs) + ) + A_scales_blocked = triton_mx_block_rearrange_2d_M_groups( + A_scale, + offs, + blocked_scales_group_offsets_2d3d, + ) + B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales) + + # output = input @ weight.T + # output shape: (M, N) + out = torch._scaled_grouped_mm( + A_data, + B_data.transpose(-2, -1), + A_scales_blocked, + B_scales_blocked, + offs=offs, + out_dtype=out_dtype, + ) + + ctx.save_for_backward(A, B_t, offs, blocked_scales_group_offsets_2d3d) + ctx.block_size = block_size + ctx.out_dtype = out_dtype + ctx.emulated = emulated + return out + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors + block_size = ctx.block_size + out_dtype = ctx.out_dtype + + # grad_out_data shape: (M, N) + # grad_out_scale shape: (M, N//block_size) + grad_out_scale, grad_out_data = to_mx( + grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # Quantize 3d expert weights along N (contraction dimension for next grouped gemm) + # (E, K, N) -> (E, N, K) + B = B_t.transpose(-2, -1) + B_data, B_scales = mxfp8_quantize_cuda_3d( + B._data if hasattr(B, "_data") else B, block_size=block_size + ) + # (E, N//block_size, K) -> (E, K, N//block_size) + B_scales = B_scales.transpose(-2, -1) + + # Convert scales to blocked format for 2d-3d grouped mm + grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups( + grad_out_scale, + offs, + blocked_scales_group_offsets_2d3d, + ) + B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales) + + # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) + grad_A = torch._scaled_grouped_mm( + grad_out_data, + B_data, + grad_out_scales_blocked, + B_scales_blocked, + offs=offs, + out_dtype=out_dtype, + ) + + # grad_out_t_data shape: (M, N) + # grad_out_t_scales shape: (N, M//block_size) + grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper( + grad_out, + block_size, + elem_dtype=torch.float8_e4m3fn, + hp_dtype=grad_out.dtype, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, + scale_calculation_mode=ScaleCalculationMode.FLOOR, + ) + grad_out_t_data = grad_out_t_mx.qdata + grad_out_t_scales = grad_out_t_mx._scale_e8m0 + + # Transpose A so we can scale along the M dimension, then un-transpose. + # A shape: (M, K) + # A_t_data shape: (K, M) + # A_t_scales shape: (K, M//block_size) + A_t_mx = _to_mxfp8_dim1_kernel_wrapper( + A, + block_size, + elem_dtype=torch.float8_e4m3fn, + hp_dtype=A.dtype, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, + scale_calculation_mode=ScaleCalculationMode.FLOOR, + ) + A_t_data = A_t_mx.qdata + A_t_scales = A_t_mx._scale_e8m0 + + # Convert scales to blocked format for 2d-2d grouped mm + scale_group_offsets = offs // block_size + _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups( + scale_group_offsets + ) + grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups( + grad_out_t_scales, + scale_group_offsets, + blocked_scale_group_offsets, + ) + A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups( + A_t_scales, + scale_group_offsets, + blocked_scale_group_offsets, + ) + + # grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K) + grad_B = torch._scaled_grouped_mm( + grad_out_t_data, + A_t_data.transpose(-2, -1), + grad_out_t_scales_blocked, + A_t_scales_blocked, + offs=offs, + out_dtype=out_dtype, + ) + # grad_B_t shape = (E,K,N) + grad_B_t = grad_B.transpose(-2, -1) + return grad_A, grad_B_t, None, None, None + + +def _to_mxfp8_dim1_3d( + B: torch.Tensor, + block_size: int = 32, + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity. + """ + E, N, K = B.shape + B_reshaped = B.reshape(E * N, K) + B_t_mx = _to_mxfp8_dim1_kernel_wrapper( + B_reshaped, + block_size, + elem_dtype=torch.float8_e4m3fn, + hp_dtype=B_reshaped.dtype, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, + scale_calculation_mode=scaling_mode, + ) + B_data = B_t_mx.qdata.t() # (K, E*N) -> (E*N, K) + B_data = B_data.reshape(E, N, K) # (E*N, K) -> (E, N, K) + B_scales = B_t_mx._scale_e8m0.view(torch.uint8) # (K, E*N//block_size) + B_scales = B_scales.reshape( + K, E, N // block_size + ) # (K, E*N//block_size) -> (K, E, N//block_size) + B_scales = B_scales.permute( + 1, 0, 2 + ) # (K, E, N//block_size) -> (E, K, N//block_size) + B_scales = B_scales.view(torch.float8_e8m0fnu) + + # TODO: Update cutlass grouped gemm to accept NT/TN/NN/TT layouts so we can avoid this conversion to column major + B_data = B_data.transpose(-2, -1).contiguous().transpose(-2, -1) + return B_scales, B_data + + +def _emulated_mxfp8_scaled_grouped_mm_2d_3d( + A_data: torch.Tensor, + A_scale: torch.Tensor, + B_data: torch.Tensor, + B_scale: torch.Tensor, + offs: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = torch.bfloat16, + block_size: int = 32, +) -> torch.Tensor: + assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}" + assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}" + assert A_scale.shape[0] == A_data.shape[0], ( + f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}" + ) + assert A_scale.shape[1] == A_data.shape[1] // block_size, ( + f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}" + ) + assert B_scale.shape[0] == B_data.shape[0], ( + f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}" + ) + assert B_scale.shape[1] == B_data.shape[1], ( + f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}" + ) + assert B_scale.shape[2] == B_data.shape[2] // block_size, ( + f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}" + ) + + # Dequantize input + # A_data shape: (M, K) + # A_scale shape: (M, K//block_size) + A_orig_shape = A_data.shape + + # Reshape to be able to do per-scaling group multiplication + # A_data shape: (M, K//block_size, block_size) + # A_scale shape: (M, K//block_size, 1) + A_data = A_data.reshape( + *A_data.shape[:-1], A_data.shape[-1] // block_size, block_size + ) + A_scale = A_scale.unsqueeze(-1) + + # Rescale and cast to bfloat16 + A = A_data.to(torch.bfloat16) * A_scale.to(torch.bfloat16) + + # Reshape back to original shape + # A shape: (M, K) + A = A.reshape(A_orig_shape) + + # Dequantize weights + # Tranpose to get block_size on rightmost dim + # B_data shape: (E, N, K) + # B_scale shape: (E, N, K//block_size) + E, N, K = B_data.shape + + # Reshape to be able to do per-scaling group multiplication + # B_data shape: (E, N, K//block_size, block_size) + # B_scale shape: (E, N, K//block_size, 1) + B_data = B_data.reshape( + *B_data.shape[:-1], B_data.shape[-1] // block_size, block_size + ) + B_scale = B_scale.unsqueeze(-1) + + # Rescale and cast to bfloat16 + B = B_data.to(torch.bfloat16) * B_scale.to(torch.bfloat16) + + # Reshape back to original shape + # B shape: (E, K, N) + B_t = B.reshape(E, N, K).transpose(-2, -1) + + # Perform bf16 grouped GEMM. + out = torch._grouped_mm(A, B_t, offs=offs, out_dtype=out_dtype) + return out + + +def _emulated_mxfp8_scaled_grouped_mm_2d_2d( + A_data: torch.Tensor, # (M, K) + A_scale: torch.Tensor, # (M, K//block_size) + B_data: torch.Tensor, # (K, N) + B_scale: torch.Tensor, # (K//block_size, N) + offs: torch.Tensor, + out_dtype: Optional[torch.dtype] = torch.bfloat16, + block_size: int = 32, +) -> torch.Tensor: + assert A_data.ndim == 2, "A must be 2D" + assert B_data.ndim == 2, "B must be 2D" + A = torch.zeros( + A_data.shape, + dtype=torch.bfloat16, + device=A_data.device, + requires_grad=A_data.requires_grad, + ) + B = torch.zeros( + B_data.shape, + dtype=torch.bfloat16, + device=B_data.device, + requires_grad=B_data.requires_grad, + ) + + # Dequantize input per each scaling group + scales_start_idx = 0 + group_start_idx = 0 + for group_end_idx in offs.tolist(): + group_size = group_end_idx - group_start_idx + scale_group_size = group_size // block_size + if group_size == 0: + group_start_idx = group_end_idx + continue + + # -- Dequantize A tensor + # A_group shape: (M, group_size) + # A_scale shape: (M, group_size//block_size) + A_group = A_data[:, group_start_idx:group_end_idx] + A_group_shape = A_group.shape + + # Get scales for this group. + # scales shape: (M, group_size//block_size) + scales = A_scale[:, scales_start_idx : scales_start_idx + scale_group_size] + + # Reshape to be able to do per-scaling group multiplication + # A_group shape: (M, group_size//block_size, block_size) + # A_scale shape: (M, group_size//block_size, 1) + A_group = A_group.reshape( + *A_group.shape[:-1], A_group.shape[-1] // block_size, block_size + ) + scales = scales.unsqueeze(-1) + + # Rescale and cast to bfloat16 + A_group = A_group.to(torch.bfloat16) * scales.to(torch.bfloat16) + + # Reshape back to original shape and store in dequantized A buffer + # A shape: (M, group_size) + A_group = A_group.reshape(A_group_shape) + A[:, group_start_idx:group_end_idx] = A_group + + # -- Dequantize B tensor + # B_group shape is (group_size, N) + B_group = B_data[group_start_idx:group_end_idx, :] + B_group_shape = B_group.shape + + # Scales shape is (group_size//block_size, N) + scales = B_scale[scales_start_idx : scales_start_idx + scale_group_size, :] + + # Transpose B to get scaling group on rightmost dim, to make things easier + # B_group_shape = (N, group_size) + # scales shape = N, group_size//block_size) + B_group, scales = B_group.transpose(-2, -1), scales.transpose(-2, -1) + + # Reshape B to be able to do per-scaling group multiplication + # B_group shape: (N, group_size//block_size, block_size) + # scales shape: (N, group_size//block_size, 1) + B_group = B_group.reshape( + *B_group.shape[:-1], B_group.shape[-1] // block_size, block_size + ) + scales = scales.unsqueeze(-1) + + # Cast to bf16 and perform scaling + B_group = B_group.to(torch.bfloat16) * scales.to(torch.bfloat16) + + # Reshape B_group back to original shape and store in dequantized B buffer + B_group = B_group.reshape(B_group_shape[1], B_group_shape[0]).transpose(-2, -1) + B[group_start_idx:group_end_idx, :] = B_group + + # Increment group start and scale start indices + group_start_idx = group_end_idx + scales_start_idx += scale_group_size + + # Perform bf16 grouped GEMM using dequantized A and B. + out = torch._grouped_mm(A, B, offs=offs, out_dtype=out_dtype) + return out + + +def round_up(x, y): + return ((x + y - 1) // y) * y diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 3ea9529237..0bbbda850e 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -1,10 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import logging from typing import Any, Optional, Tuple import torch import torch.utils._pytree as pytree +from torch import nn from torch._prims_common import suggest_memory_format +from torch.distributed._tensor import DTensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy from torchao.prototype.moe_training import _scaled_grouped_mm +from torchao.prototype.moe_training.conversion_utils import MoEScalingType + +logger: logging.Logger = logging.getLogger(__name__) _ops_to_preserve_subclass = { torch.ops.aten.empty_like.default, @@ -13,10 +27,11 @@ torch.ops.aten.copy_.default, torch.ops.aten.view.default, torch.ops.aten.as_strided.default, - torch.ops.aten._to_copy.default, + torch.ops.aten._to_copy.default, # for *.to(dtype) torch.ops.aten._pin_memory.default, torch.ops.aten.split.Tensor, torch.ops.aten.clone.default, + torch.ops.aten.transpose.int, } @@ -27,6 +42,7 @@ class ScaledGroupedMMTensor(torch.Tensor): differentiable _scaled_grouped_mm autograd function. """ + scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE grouped_mm_func_name = "_grouped_mm" offs_arg_name = "offs" @@ -34,8 +50,9 @@ class ScaledGroupedMMTensor(torch.Tensor): def __new__( cls, tensor: torch.Tensor, + scaling_type: MoEScalingType, ): - return torch.Tensor._make_wrapper_subclass( + self = torch.Tensor._make_wrapper_subclass( cls, tensor.size(), strides=tensor.stride(), @@ -47,12 +64,16 @@ def __new__( pin_memory=tensor.is_pinned(), requires_grad=tensor.requires_grad, ) + self.scaling_type = scaling_type + return self def __init__( self, tensor: torch.Tensor, + scaling_type: MoEScalingType, ): self._data = tensor + self.scaling_type = scaling_type @classmethod def __torch_function__(cls, func, types, args, kwargs={}): @@ -66,12 +87,23 @@ def __torch_function__(cls, func, types, args, kwargs={}): # used for shared experts. This is basically the grouped_mm # kernel handling a bmm. A, B = args[0], args[1] - A_is_2d = A.dim() == 2 - B_is_3d = B.dim() == 3 + assert not isinstance(A, ScaledGroupedMMTensor), ( + "A should not be a ScaledGroupedMMTensor" + ) + assert isinstance(B, ScaledGroupedMMTensor), ( + "B should be a ScaledGroupedMMTensor" + ) + scaling_type = B.scaling_type + A_is_2d = A.ndim == 2 + B_is_2d_or_3d = B.ndim == 2 or B.ndim == 3 has_offs = kwargs.get(cls.offs_arg_name) is not None - if A_is_2d and B_is_3d and has_offs: + other_args = args[2:] + if A_is_2d and B_is_2d_or_3d and has_offs: return _scaled_grouped_mm( - *args, + A, + B, + *other_args, + scaling_type=scaling_type, **kwargs, ) @@ -82,18 +114,30 @@ def __torch_function__(cls, func, types, args, kwargs={}): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs={}): - # detach is special case - if func == torch.ops.aten.detach.default: - return ScaledGroupedMMTensor(args[0]._data) - - # unwrap args and kwargs - unwrap = lambda tensor: tensor._data - args, kwargs = pytree.tree_map_only( + # unwrap args/kwargs and extract scaling_type + scaling_type = None + + def unwrap(t): + nonlocal scaling_type + if scaling_type is None: + scaling_type = t.scaling_type + else: + assert t.scaling_type == scaling_type + return t._data + + args_unwrapped, kwargs_unwrapped = pytree.tree_map_only( ScaledGroupedMMTensor, unwrap, (args, kwargs or {}) ) + assert scaling_type is not None, ( + f"__torch_dispatch__ called on {func.__name__} without any ScaledGroupedMMTensor arguments" + ) + + # detach is special case + if func == torch.ops.aten.detach.default: + return ScaledGroupedMMTensor(args_unwrapped[0], scaling_type) # perform op - out = func(*args, **kwargs) + out = func(*args_unwrapped, **kwargs_unwrapped) # return regular tensors for ops that don't preserve subclass if func not in _ops_to_preserve_subclass: @@ -102,12 +146,37 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}): # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass return pytree.tree_map_only( torch.Tensor, - lambda x: ScaledGroupedMMTensor(x), + lambda x: ScaledGroupedMMTensor(x, scaling_type), out, ) - def fsdp_pre_all_gather(self, mesh): - return (self._data,), () + def __repr__(self): + return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})" + + def __tensor_flatten__(self): + metadata = {"scaling_type": self.scaling_type} + return ["_data"], metadata + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + return ScaledGroupedMMTensor( + inner_tensors["_data"], + flatten_spec["scaling_type"], + ) + + # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81 + def fsdp_pre_all_gather( + self, + mesh: DeviceMesh, + outer_size: torch.Size, + outer_stride: tuple[int, ...], + module: nn.Module, + mp_policy: MixedPrecisionPolicy, + ): + # cast to mixed precision dtype prior to all-gather + all_gather_inputs = (self._data.to(mp_policy.param_dtype),) + all_gather_metadata = () + return all_gather_inputs, all_gather_metadata def fsdp_post_all_gather( self, @@ -118,6 +187,39 @@ def fsdp_post_all_gather( out: Optional[torch.Tensor] = None, ): (data,) = all_gather_outputs - return ScaledGroupedMMTensor( - data, - ), (data,) + + # For training step 1+, out=unsharded param. + if out is not None: + if isinstance(out, ScaledGroupedMMTensor): + out_data = out._data + out.scaling_type = self.scaling_type + elif isinstance(out, DTensor) and isinstance( + out._local_tensor, ScaledGroupedMMTensor + ): + out_data = out._local_tensor._data + out._local_tensor.scaling_type = self.scaling_type + else: + raise RuntimeError( + f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMM, but got {type(out)}" + ) + + # If `data` (all gather outputs) is already in the mixed precision policy param_dtype, + # verify it has underlying storage as `out` (pre-allocated unsharded param), + # and then we can just return directly. + if data.dtype == param_dtype: + assert ( + data.untyped_storage().data_ptr() + == out_data.untyped_storage().data_ptr() + ) + else: + # Otherwise, verify that `out` (pre-allocated unsharded param) has the + # mixed precision policy param_dtype, then copy `data` to `out`. + assert out_data.dtype == param_dtype, f"{out_data.dtype} {param_dtype}" + out_data.copy_(data) + + return + + # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor. + output = ScaledGroupedMMTensor(data, self.scaling_type) + inner_tensors = (data,) + return output, inner_tensors diff --git a/torchao/prototype/moe_training/utils.py b/torchao/prototype/moe_training/utils.py index 038c379d62..5bcbd21d70 100644 --- a/torchao/prototype/moe_training/utils.py +++ b/torchao/prototype/moe_training/utils.py @@ -1,12 +1,15 @@ +import random from typing import Tuple import torch from torchao.float8.config import ScalingGranularity from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.prototype.mx_formats.mx_tensor import to_mx -def _to_2d_jagged_float8_tensor_colwise( +# --- float8 rowwise scaling --- +def torch_to_float8_per_group_colwise( A_col_major: torch.Tensor, offs: torch.Tensor, target_dtype: torch.dtype = torch.float8_e4m3fn, @@ -75,7 +78,7 @@ def _to_2d_jagged_float8_tensor_colwise( return A_fp8_col_major, A_scales -def _to_2d_jagged_float8_tensor_rowwise( +def torch_to_float8_per_group_rowwise( x: torch.Tensor, offs: torch.Tensor, target_dtype: torch.dtype, @@ -142,6 +145,140 @@ def _to_2d_jagged_float8_tensor_rowwise( return x_fp8, x_scales +def torch_to_3d_rowwise_float8_transpose_rhs( + input_hp_t: torch.Tensor, # (E, K, N) + target_dtype: torch.dtype = torch.float8_e4m3fn, + round_scales_to_power_of_2: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function converts the 3D input tensor to a float8 tensor, with scales computed along logical columns + on a per-expert basis. Output will be in column-major memory layout. + + Args: + x (torch.Tensor): The input tensor to be converted to a float8 tensor. Shape (E, K, N). + + Returns: + A tuple containing the float8 tensor and the scales used for the conversion. + Output shape: (E, N, K) + Scales shape: (E, 1, K + """ + assert _is_column_major(input_hp_t), "input tensor must be column-major" + scales = tensor_to_scale( + input_hp_t, + target_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=-1, + round_scales_to_power_of_2=round_scales_to_power_of_2, + ) # (E, K, 1) + + # Apply scales to tensor and convert to float8. + tensor_scaled = input_hp_t.to(torch.float32) * scales + float8_tensor = to_fp8_saturated(tensor_scaled, target_dtype) + + # To column major + float8_tensor = float8_tensor.contiguous().transpose(-2, -1) + scales = scales.transpose(-2, -1) + return float8_tensor, scales + + +# --- mxfp8 scaling --- +def _to_mxfp8_per_group_rowwise( + x: torch.Tensor, + offs: torch.Tensor, + block_size: int = 32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This is a reference implementation used for testing correctness, it is not performant. + + This function converts the 2D input tensor a mxpf8 tensor along dim 0 with per-token-group scaling, + where groups are determined based on the offsets. + + Args: + A (torch.Tensor): The input tensor to be converted to a jagged mxfp8 tensor. + + Returns: + A tuple containing the jagged mxpf8 tensor and the scales used for the conversion. + """ + assert x.ndim == 2, "input tensor must be 2D" + assert offs.numel() > 0, "offs must be non-empty" + + x_mx = torch.empty_like(x, dtype=torch.float8_e4m3fn) + x_scales = None + + start_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching all rows with the next group of rows. + subtensor = x[:, start_idx:end_idx] # (M, local_group_size) + + # Perform mxfp8 conversion on logically distinct subtensor. + scales, mx_subtensor = to_mx( + subtensor.contiguous(), + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + # Store this portion of the resulting mxfp8 tensor and scales. + x_mx[:, start_idx:end_idx] = mx_subtensor + if x_scales is None: + x_scales = scales.view(torch.uint8) # Needed to support cat op below + else: + x_scales = torch.cat((x_scales, scales.view(torch.uint8)), dim=1) + + # Update start index for next group. + start_idx = end_idx + + return x_mx, x_scales.view(torch.float8_e8m0fnu) + + +def _to_mxfp8_per_group_colwise( + A_col_major: torch.Tensor, # (K, N) + offs: torch.Tensor, + block_size: int = 32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This is a reference implementation used for testing correctness, it is not performant. + + This function converts the 2D input tensor a mxpf8 tensor along dim 1 with per-token-group scaling, + where groups are determined based on the offsets. + + Args: + A (torch.Tensor): The input tensor to be converted to a mxfp8 tensor. + + Returns: + A tuple containing the mxpf8 tensor and the scales used for the conversion. + """ + assert A_col_major.ndim == 2, "A must be 2D" + assert offs.numel() > 0, "offs must be non-empty" + + A_mx = torch.empty_like(A_col_major, dtype=torch.float8_e4m3fn) + A_scales = None + + start_idx = 0 + for end_idx in offs.tolist(): + # Get the subtensor of A for this group, fetching the next group of rows, with all columns for each. + subtensor = A_col_major[start_idx:end_idx, :] # (local_group_size, N) + + # Convert to mxfp8 along dim1, by transposing, converting, and transposing back. + scales, mx_subtensor = to_mx( + subtensor.transpose(-2, -1).contiguous(), + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + scales, mx_subtensor = scales.transpose(-2, -1), mx_subtensor.transpose(-2, -1) + + # Store this portion of the resulting mxfp8 tensor and scales. + A_mx[start_idx:end_idx, :] = mx_subtensor + if A_scales is None: + A_scales = scales.view(torch.uint8) # Needed to support cat op below + else: + A_scales = torch.cat((A_scales, scales.view(torch.uint8)), dim=0) + + # Update start index for next group. + start_idx = end_idx + + return A_mx, A_scales.view(torch.float8_e8m0fnu) + + def _is_column_major(x: torch.Tensor) -> bool: """ This function checks if the input tensor is column-major. @@ -153,4 +290,53 @@ def _is_column_major(x: torch.Tensor) -> bool: A boolean indicating whether the input tensor is column-major. """ assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" - return x.stride(-2) == 1 and x.stride(-1) > 1 + return x.stride(-2) == 1 + + +def _is_row_major(x: torch.Tensor) -> bool: + """ + This function checks if the input tensor is row-major. + + Args: + x (torch.Tensor): The input tensor to be checked. + + Returns: + A boolean indicating whether the input tensor is row-major. + """ + assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" + return x.stride(-1) == 1 + + +def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"): + """ + Utility function for tests and benchmarks. + + Generates a tensor of length E, containing random values divisible by 16, + from 0 to M, in sorted order, and where the final value in the tensor is always M. + Args: + E (int): The length of the tensor. + M (int): The maximum value in the tensor. + Returns: + torch.Tensor: A tensor of length E with the specified properties. + """ + # Ensure M is divisible by 16 + if M % multiple_of != 0: + raise ValueError(f"M must be divisible by {multiple_of}") + + # Generate a list of possible values + possible_values = [i for i in range(multiple_of, M + 1, multiple_of)] + + # If E is larger than the number of possible values, raise an error + if E > len(possible_values): + raise ValueError("E cannot be larger than the number of possible values") + + # Randomly select E - 1 values from the possible values (excluding M) + selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1)) + + # Append M to the selected values + selected_values = torch.cat((selected_values, torch.tensor([M]))) + + # Sort the selected values + selected_values, _ = torch.sort(selected_values) + + return selected_values.to(dtype).to(device) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 587d81f6a6..ba3d152c90 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -7,15 +7,37 @@ in native PyTorch. We are currently in prototype and are actively working on op | workflow | emulation | performance | accuracy | | --- | --- | --- | --- | -| training with mxfp8 | ✅ | 🚧 [active development](https://github.com/pytorch/ao/issues/1768) | ✅ | -| inference (weight-only) with mxfp8, mxfp6, mxfp4 | ✅ | 🔲 | 🔲 | - -We plan to add the following features in the near future: -* other inference workflows such as dynamic quantization -* a unified training to inference workflow +| training with mxfp8 | ✅ | ✅ | ✅ | +| inference with mxfp8, mxfp6, mxfp4 | ✅ | 🔲 | 🔲 | ℹ️ See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features. +## Training e2e benchmarks on NVIDIA B200 + +- Single-node training on 8x power limited B200 GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC +- pytorch version: `2.9.0.dev20250815+cu128`, torchao version: `0.13.0+gite4e681be6`, torchtitan commit: `6fc499f6f5b32151a799188be2208cfb09faed30` + +| Model | Scaling | Peak Memory (GB) | Median tokens/second | Speedup over baseline +| ------------- | ---------------------------------- | ------------------| -------------------- | --------------------- +| Llama3-8b | none (bfloat16) | 33.71 | 8307.5 | - +| Llama3-8b | float8 tensorwise (f8 all-gather) | 33.38 | 10417.0 | 25.4% +| Llama3-8b | mxfp8_cublas | 33.88 | 9969.0 | 20.0% +| Llama3-8b | mxfp8_cublas_rceil | 33.88 | 9642.0 | 16.1% +| Llama3-8b | float8 rowwise | 33.72 | 8640.5 | 4.0% + +**Reproducing training benchmarks** +To reproduce these benchmarks, you can follow these steps: + +1. On a machine with compatible GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation), +including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer). +2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation). +3. From the `torchao/` directory, you can run the following commands to reproduce the benchmarks above: + - bf16 + compile: `TORCHTITAN_ROOT= ./benchmarks/float8/training/llama3.sh` + - mxfp8_cublas: `TORCHTITAN_ROOT= MX_RECIPE="mxfp8_cublas" ./benchmarks/float8/training/llama3.sh` + - mxfp8_cublas_rceil: `TORCHTITAN_ROOT= MX_RECIPE="mxfp8_cublas_rceil" ./benchmarks/float8/training/llama3.sh` + +> :warning: For now you need to build `torchao` from source for optimal training performance. See https://github.com/pytorch/ao/issues/2932 for details. + # User API ## MX training @@ -23,20 +45,23 @@ We plan to add the following features in the near future: ```python import torch from torchao.quantization import quantize_ -from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice +from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice, ScaleCalculationMode # on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels gemm_kernel_choice = MXGemmKernelChoice.CUBLAS # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS - # on older NVIDIA gpus, you can run training with emulated MX gemm # gemm_kernel_choice = MXGemmKernelChoice.EMULATED +scale_calculation_mode = ScaleCalculationMode.FLOOR +# other supported modes: RCEIL, CEIL, EVEN + m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() config = MXLinearConfig( elem_dtype=torch.float8_e4m3fn, block_size=32, gemm_kernel_choice=gemm_kernel_choice, + scale_calculation_mode=scale_calculation_mode, ) quantize_(m, config) @@ -45,24 +70,8 @@ quantize_(m, config) ## MX inference -Note: currently only weight-only quantization is supported. - -```python -import torch -from torchao.quantization import quantize_ -from torchao.prototype.mx_formats import MXInferenceLinearConfig, MXGemmKernelChoice - -m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -gemm_kernel_choice = MXGemmKernelChoice.CUBLAS -config = MXInferenceLinearConfig( - elem_dtype=torch.float8_e4m3fn, - block_size=32, - gemm_kernel_choice=gemm_kernel_choice, -) -quantize_(m, config=config) +Coming soon! -# do inference (not shown) -``` ## MXTensor This is casts between high precision and MX formats implemented in native PyTorch. Currently diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py index 5947d616be..c7a4c47f9d 100644 --- a/torchao/prototype/mx_formats/__init__.py +++ b/torchao/prototype/mx_formats/__init__.py @@ -1,12 +1,11 @@ from torchao.prototype.mx_formats.config import ( MXGemmKernelChoice, - MXInferenceLinearConfig, MXLinearConfig, MXLinearRecipeName, ) # Note: Prototype and subject to change -from torchao.prototype.mx_formats.mx_subclass import ( +from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, NVFP4MMConfig, @@ -18,7 +17,6 @@ __all__ = [ "MXGemmKernelChoice", - "MXInferenceLinearConfig", "MXLinearConfig", "MXLinearRecipeName", "MXFPInferenceConfig", diff --git a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py deleted file mode 100644 index ca0b926ce5..0000000000 --- a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Benchmarking mx quantize/dequantize -""" - -from typing import Optional - -import fire -import tabulate -import torch -from torch.profiler import ProfilerActivity, profile - -from torchao.prototype.mx_formats import config -from torchao.prototype.mx_formats.constants import ( # noqa: E501 - SUPPORTED_ELEM_DTYPES, -) -from torchao.prototype.mx_formats.mx_tensor import MXTensor -from torchao.utils import benchmark_torch_function_in_microseconds - - -def run(profile_folder: Optional[str] = None): - headers = [ - "elem_dtype", - "use_fp4_custom_triton_dequant_kernel", - "q_time_us", - "q_mem_bw_tb_s", - "dq_time_us", - "dq_mem_bw_tb_s", - ] - results = [] - - data_hp = torch.randn(1, 4096, 11008, dtype=torch.bfloat16, device="cuda") - - for elem_dtype in SUPPORTED_ELEM_DTYPES: - for use_fp4_custom_triton_dequant_kernel in (False, True): - config.use_fp4_custom_triton_dequant_kernel = ( - use_fp4_custom_triton_dequant_kernel - ) - - if ( - elem_dtype != torch.float4_e2m1fn_x2 - and use_fp4_custom_triton_dequant_kernel # noqa: E501 - ): - # custom_triton_kernels only works for fp4 - continue - - print( - "elem_dtype", - elem_dtype, - "use_fp4_custom_triton_dequant_kernel", - use_fp4_custom_triton_dequant_kernel, - ) - - data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32) - - if not use_fp4_custom_triton_dequant_kernel: - quant = torch.compile(MXTensor.to_mx, fullgraph=True) - dequant = torch.compile(data_lp.to_dtype, fullgraph=True) - else: - # As of 2024-04, torch.compile didn't work with the - # handwritten triton kernel, - # crashed on tl.interleave: - # https://github.com/pytorch/pytorch/issues/123967 - # As of 2024-05-24, now there is message asking to convert to - # an opaque custom op: - # https://gist.github.com/vkuzo/0b0b90dca03bdb8e0446e4135644238a # noqa: E501 - # TODO(future): make this better - quant = MXTensor.to_mx - dequant = data_lp.to_dtype - - # warm up - quant(data_hp, elem_dtype, block_size=32) - res = dequant(torch.bfloat16) - - if profile_folder is not None: - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, - ) as prof: - for _ in range(5): - quant(data_hp, elem_dtype, block_size=32) - dequant(torch.bfloat16) - prof.export_chrome_trace( - profile_folder - + f"/mx_qdq_{elem_dtype}_{use_fp4_custom_triton_dequant_kernel}.json" # noqa: E501 - ) - - q_execution_time_us = benchmark_torch_function_in_microseconds( - quant, data_hp, elem_dtype, block_size=32 - ) - dq_execution_time_us = benchmark_torch_function_in_microseconds( - dequant, torch.bfloat16 - ) - print(f"q time: {q_execution_time_us} us") - print(f"dq time: {dq_execution_time_us} us") - - # memory reads per element: - byte_per_stored_element = 1.0 # fp8 or 2xfp4 - byte_per_stored_exp_element = 1.0 # e8m0 - byte_per_dequantized_element = 2.0 # bfloat16 - mem_reads_writes_bytes = ( - # read raw data - (data_lp._data.numel() * byte_per_stored_element) - + - # read exponent - (data_lp._scale_e8m0.numel() * byte_per_stored_exp_element) - + - # write dequant - (res.numel() * byte_per_dequantized_element) - ) - # note: the above also works for quant, with reads/writes in - # reverse - - q_mem_bw_tb_s = (mem_reads_writes_bytes / 1e12) / ( - q_execution_time_us / 1e6 - ) - dq_mem_bw_tb_s = (mem_reads_writes_bytes / 1e12) / ( - dq_execution_time_us / 1e6 - ) - print(f"q mem bw: {q_mem_bw_tb_s} TB/s") - print(f"dq mem bw: {dq_mem_bw_tb_s} TB/s") - - results.append( - ( - elem_dtype, - use_fp4_custom_triton_dequant_kernel, - q_execution_time_us, - q_mem_bw_tb_s, - dq_execution_time_us, - dq_mem_bw_tb_s, - ) - ) - config.use_fp4_custom_triton_dequant_kernel = False - - torch._dynamo.reset() - - print(tabulate.tabulate(results, headers=headers, floatfmt=".2f")) - - -if __name__ == "__main__": - fire.Fire(run) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 525bf21fc6..388af07874 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -12,8 +12,6 @@ from torchao.core.config import AOBaseConfig from torchao.prototype.mx_formats.constants import ( - DTYPE_FP6_E2M3, - DTYPE_FP6_E3M2, DTYPE_TO_SHORT_STR, SUPPORTED_ELEM_DTYPES, ) @@ -33,14 +31,54 @@ class MXGemmKernelChoice(Enum): CUBLAS = "cublas" +class MXFP8Dim1CastKernelChoice(Enum): + """ + Defines which kernel to use for mxfp8 casting. Currently custom casting kernels are + only for scaling along dim1, and torch native code is always used for scaling along dim0. + """ + + TRITON = "triton" + CUDA = "cuda" + TORCH = "torch" + + # Pre-made recipes for common configurations class MXLinearRecipeName(Enum): MXFP8_EMULATED = "mxfp8_emulated" MXFP8_CUBLAS = "mxfp8_cublas" + MXFP8_CUBLAS_RCEIL = "mxfp8_cublas_rceil" MXFP4_EMULATED = "mxfp4_emulated" MXFP4_CUTLASS = "mxfp4_cutlass" +class ScaleCalculationMode(Enum): + """ + Enum representing the different methods for calculating MX block scaling. + There are four methods available: + + FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp). + It result in overflow issues for large values and bad for gradient quantization. + + RCEIL: The method is to apply ceil to the ratio of max_abs(v) and max_pos. + This method's detail is described in https://docs.nvidia.com/cuda/cublas/index.html#d-block-quantization + Section "Computing scaling and conversion factors for FP8 with UE8M0 scales" + + CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor. + It uses X = 2^ceil(log2(max_abs(v))-max_exp). + + EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)). + It provides better accuracy for MX4 training compared to FLOOR and CEIL. + Note: EVEN does not work with torch.compile yet: + https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4 + + """ + + FLOOR = "floor" + RCEIL = "rceil" + CEIL = "ceil" + EVEN = "even" + + def _validate_elem_dtype(elem_dtype): assert elem_dtype in SUPPORTED_ELEM_DTYPES, ( f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}" @@ -66,6 +104,22 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): ) +def _validate_mxfp8_cast_kernel_choice( + mxfp8_cast_kernel_choice, scale_calculation_mode +): + if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON: + assert scale_calculation_mode == ScaleCalculationMode.FLOOR, ( + f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 triton cast" + ) + elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA: + assert scale_calculation_mode in ( + ScaleCalculationMode.FLOOR, + ScaleCalculationMode.RCEIL, + ), ( + f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 cuda cast" + ) + + @dataclass class MXLinearConfig(AOBaseConfig): # block size for scaling, default is 32 to match @@ -85,13 +139,14 @@ class MXLinearConfig(AOBaseConfig): # on the given hardware an exception will be thrown gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED - # If True, uses a custom triton kernel for cast to mxfp8 across dim1 + # define which kernel to use for mxfp8 casting # TODO(1945): remove this config option once torch.compile gives us # a fast kernel - use_fp8_dim1_cast_triton_kernel: bool = False + mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice = ( + MXFP8Dim1CastKernelChoice.TORCH + ) - # If True, uses a custom triton kernel for fp4 dequantize - use_fp4_custom_triton_dequant_kernel: bool = False + scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR def __post_init__(self): _validate_elem_dtype(self.elem_dtype) @@ -104,6 +159,9 @@ def __post_init__(self): if self.elem_dtype_grad_output_override is not None: _validate_elem_dtype(self.elem_dtype_grad_output_override) assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported" + _validate_mxfp8_cast_kernel_choice( + self.mxfp8_cast_kernel_choice, self.scale_calculation_mode + ) @staticmethod def from_recipe_name( @@ -123,7 +181,16 @@ def from_recipe_name( if recipe_name is MXLinearRecipeName.MXFP8_EMULATED: return MXLinearConfig() elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS: - return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS) + return MXLinearConfig( + gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, + ) + elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL: + return MXLinearConfig( + gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, + scale_calculation_mode=ScaleCalculationMode.RCEIL, + ) elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED: return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2) elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS: @@ -146,51 +213,7 @@ def short_str(self) -> str: if self.elem_dtype_grad_output_override is not None: s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}" s += f", kernel={self.gemm_kernel_choice.value}" - if self.use_fp8_dim1_cast_triton_kernel: - s += ", use_fp8_dim1_cast_triton_kernel=True" - if self.use_fp4_custom_triton_dequant_kernel: - s += ", use_fp4_custom_triton_dequant_kernel=True" - return s - - -@dataclass -class MXInferenceLinearConfig(AOBaseConfig): - # block size for scaling, default is 32 to match - # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, - # section 5.2 - block_size: int = 32 - - # element dtype, used for activations, weights and gradients - elem_dtype: Any = torch.float8_e4m3fn - # TODO(future PR): support different elem_dtype for activations vs weights - - # defines the gemm kernel choice, if the chosen kernel is not supported - # on the given hardware an exception will be thrown - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED - - # If True, uses a custom triton kernel for fp4 dequantize - use_fp4_custom_triton_dequant_kernel: bool = False - - # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton - # kernels (fused unpack/dequantize). - pack_fp6: bool = True - - def __post_init__(self): - _validate_elem_dtype(self.elem_dtype) - _validate_gemm_kernel_choice( - self.gemm_kernel_choice, self.block_size, self.elem_dtype - ) - - def short_str(self) -> str: - """ - Returns a concise representation of the current config. - """ - s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}" - s += f", kernel={self.gemm_kernel_choice.value}" - if self.use_fp4_custom_triton_dequant_kernel: - s += ", use_fp4_custom_triton_dequant_kernel=True" - if self.elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2) and self.pack_fp6: - s += ", pack_fp6=True" + s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}" + if self.scale_calculation_mode != ScaleCalculationMode.FLOOR: + s += f", scale_calculation_mode={self.scale_calculation_mode}" return s - - # TODO(future PR): add a recipe to config API for inference diff --git a/torchao/prototype/mx_formats/constants.py b/torchao/prototype/mx_formats/constants.py index ffac3b1d5f..3111bc771b 100644 --- a/torchao/prototype/mx_formats/constants.py +++ b/torchao/prototype/mx_formats/constants.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 +from torchao.utils import torch_version_at_least # This is conceptually an enum of non-core dtypes # TODO(future PR): change to a cleaner way to represent this without @@ -23,7 +23,7 @@ ] SUPPORTED_ELEM_DTYPES = ( SUPPORTED_ELEM_DTYPES + [torch.float4_e2m1fn_x2] - if TORCH_VERSION_AT_LEAST_2_8 + if torch_version_at_least("2.8.0") else SUPPORTED_ELEM_DTYPES ) @@ -33,7 +33,7 @@ DTYPE_FP6_E2M3: "f6e2m3", DTYPE_FP6_E3M2: "f6e3m2", } -if TORCH_VERSION_AT_LEAST_2_8: +if torch_version_at_least("2.8.0"): DTYPE_TO_SHORT_STR[torch.float4_e2m1fn_x2] = "f4e2m1" F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 diff --git a/torchao/prototype/mx_formats/mx_subclass.py b/torchao/prototype/mx_formats/inference_workflow.py similarity index 66% rename from torchao/prototype/mx_formats/mx_subclass.py rename to torchao/prototype/mx_formats/inference_workflow.py index d1be8a04f4..34cf9e9506 100644 --- a/torchao/prototype/mx_formats/mx_subclass.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -6,7 +6,6 @@ import types from dataclasses import dataclass -from typing import Optional import torch @@ -18,16 +17,18 @@ _validate_elem_dtype, _validate_gemm_kernel_choice, ) -from torchao.prototype.mx_formats.mx_tensor import MXTensor -from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor -from torchao.quantization.quant_api import to_linear_activation_quantized +from torchao.prototype.mx_formats.mx_tensor import MXTensor, QuantizeTensorToMXKwargs +from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4MMConfig, + NVFP4Tensor, + QuantizeTensorToNVFP4Kwargs, +) from torchao.quantization.transform_module import ( register_quantize_module_handler, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100, + torch_version_at_least, ) @@ -90,26 +91,6 @@ def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}" -def _input_activation_quant_func_mxfp( - x: torch.Tensor, - activation_dtype: torch.dtype, - block_size: int, - scale: Optional[torch.Tensor] = None, -): - """ """ - - # TODO scale for static quant - - activation = MXTensor.to_mx( - x, - activation_dtype, - block_size=block_size, - gemm_kernel_choice=None, # Get from weight - pack_fp6=False, # TODO - ) - return activation - - @register_quantize_module_handler(MXFPInferenceConfig) def _mx_inference_linear_transform( module: torch.nn.Module, config: MXFPInferenceConfig @@ -118,32 +99,26 @@ def _mx_inference_linear_transform( # TODO handle AMD assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now" - activation_dtype = config.activation_dtype - weight_dtype = config.weight_dtype weight = module.weight assert weight.dtype == torch.bfloat16, ( f"Only supporting bf16 out dtype for now, got {weight.dtype}" ) + act_quant_kwargs = QuantizeTensorToMXKwargs( + elem_dtype=config.activation_dtype, + block_size=config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, + pack_fp6=False, + ) # Convert weight to MX Tensor quantized_weight = MXTensor.to_mx( weight, - weight_dtype, + config.weight_dtype, block_size=config.block_size, gemm_kernel_choice=config.gemm_kernel_choice, pack_fp6=False, # TODO - ) - - input_quant_func = _input_activation_quant_func_mxfp - input_quant_kwargs = { - "block_size": config.block_size, - "activation_dtype": activation_dtype, - "scale": None, - } - - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + act_quant_kwargs=act_quant_kwargs, ) module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) @@ -157,18 +132,23 @@ class NVFP4InferenceConfig(AOBaseConfig): NVIDIA FP4 (NVFP4) Inference Quantization Configuration This is a specialized configuration for NVIDIA's FP4 format. - All parameters are fixed in the NVFP4 implementation except mm_config: + Configuration parameters: - mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision) + - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: False) - Data: float4_e2m1fn_x2 - Scales: float8_e4m3fn - Block size: 16 along the reduction dim + + Note: Triton kernel only works with DYNAMIC mode and has constraints that input dimensions + must satisfy M % 128 == 0 and K % 64 == 0. Will automatically fallback when constraints aren't met. """ mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC + use_triton_kernel: bool = True def __post_init__(self): # Validate PyTorch version - if not TORCH_VERSION_AT_LEAST_2_8: + if not torch_version_at_least("2.8.0"): raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later") @@ -184,29 +164,73 @@ def _nvfp4_inference_linear_transform( weight = module.weight + if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0: + raise RuntimeError( + f"NVFP4 only supports weight shape divisible by 16, got {weight.shape}" + ) + if module.bias is not None and weight.dtype == torch.float32: raise RuntimeError( "Bias is not supported when module weight is in fp32 (out_dtype=Float32). " "Please use bfloat16 or float16 weights, or remove the bias from the linear layer." ) + act_quant_kwargs = None + if config.mm_config == NVFP4MMConfig.DYNAMIC: + act_quant_kwargs = QuantizeTensorToNVFP4Kwargs() + quantized_weight = NVFP4Tensor.to_nvfp4( weight, - mm_config=config.mm_config, + is_swizzled_scales=True, + use_triton_kernel=False, # Always use traditional construction for weights + act_quant_kwargs=act_quant_kwargs, ) - + # Set triton preference after construction + quantized_weight.use_triton_kernel = config.use_triton_kernel module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals( - [ - MXTensor, - NVFP4Tensor, - NVFP4MMConfig, - MXGemmKernelChoice, - _input_activation_quant_func_mxfp, - ] - ) +torch.serialization.add_safe_globals( + [ + MXTensor, + NVFP4Tensor, + NVFP4MMConfig, + MXGemmKernelChoice, + ] +) + + +import torch.nn as nn + + +def _auto_filter_for_nfp4(mod: nn.Module, fqn: str) -> bool: + """Generic Filter fn for NVFP4 that is best practice for most models.""" + # Define any FQNs you want to exclude directly in the function + filter_fqns = ["embedder", "embed", "embedding", "time_text_embed"] + + # Only support Linear modules + if not isinstance(mod, nn.Linear): + return False + + # If the fqn matches any filtered fqn, then we should not convert this module + is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns) + if is_filtered_fqn: + return False + + # All dims must be divisible by 16 due to float8 hardware requirements. + N, K = mod.weight.shape + dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0 + if not dims_multiples_of_16: + return False + if N <= 64: + print("skiping small linear layer") + # TODO cublas doesn't like this one + return False + + # Dims below these thresholds may result in worse performance + if K <= 1024 and N <= 1024: + print("skiping small linear layer") + return False + return True diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 72cbba1802..5811dd9d21 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -4,33 +4,40 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +import logging +from typing import Optional, Tuple import numpy as np import torch +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.experimental import register_sharding from torch.utils._triton import has_triton from torchao.prototype.custom_fp_utils import ( _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import ( + is_sm_at_least_100, + torch_version_at_least, +) # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert # at the callsite prevents usage of this on unsupported versions. -if TORCH_VERSION_AT_LEAST_2_4 and has_triton(): +if has_triton(): from torch._inductor.runtime.triton_helpers import libdevice from torchao.prototype.mx_formats.constants import ( E8M0_EXPONENT_BIAS, E8M0_EXPONENT_NAN_VAL, - F4_E2M1_EXP_BIAS, F6_E2M3_EXP_BIAS, F6_E3M2_EXP_BIAS, F32_EXP_BIAS, ) +logger = logging.getLogger(__name__) + def get_bits(x: torch.Tensor) -> str: bits_per_byte = 8 @@ -191,138 +198,6 @@ def _fp4_packed_to_bf16( output = output.to(tl.bfloat16) return output - @triton.jit - def triton_f4_to_bf16_kernel( - x_ptr, - output_ptr, - n_elements_in, - sign_mask_f4: tl.constexpr, - mantissa_mask_f4: tl.constexpr, - mbits_f4_e2m1: tl.constexpr, - ebits_f4_e2m1: tl.constexpr, - f4_e2m1_exp_bias: tl.constexpr, - mbits_f32: tl.constexpr, - ebits_f32: tl.constexpr, - f32_exp_bias: tl.constexpr, - zero_bits_f32: tl.constexpr, - zero_point_five_bits_f32: tl.constexpr, - BLOCK_SIZE_IN: tl.constexpr, - ): - pid = tl.program_id(axis=0) - n_elements_out = n_elements_in * 2 - BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2 - - block_start_in = pid * BLOCK_SIZE_IN - offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN) - - mask_in = offsets_in < n_elements_in - - # packed uint8 - x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) - output = _fp4_packed_to_bf16( - x_packed, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - ) - - # set up output offsets - block_start_out = pid * BLOCK_SIZE_OUT - offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT) - mask_out = offsets_out < n_elements_out - - tl.store(output_ptr + offsets_out, output, mask=mask_out) - - @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_IN": 128}), - triton.Config({"BLOCK_SIZE_IN": 256}), - triton.Config({"BLOCK_SIZE_IN": 512}), - triton.Config({"BLOCK_SIZE_IN": 1024}), - triton.Config({"BLOCK_SIZE_IN": 2048}), - ], - key=["n_elements_in"], - ) - @triton.jit - def triton_f4_to_scaled_bf16_kernel( - x_ptr, - s_ptr, - output_ptr, - n_elements_in, - mx_block_size: tl.constexpr, - sign_mask_f4: tl.constexpr, - mantissa_mask_f4: tl.constexpr, - mbits_f4_e2m1: tl.constexpr, - ebits_f4_e2m1: tl.constexpr, - f4_e2m1_exp_bias: tl.constexpr, - mbits_f32: tl.constexpr, - ebits_f32: tl.constexpr, - f32_exp_bias: tl.constexpr, - zero_bits_f32: tl.constexpr, - zero_point_five_bits_f32: tl.constexpr, - e8m0_exponent_bias: tl.constexpr, - e8m0_exponent_nan_val: tl.constexpr, - BLOCK_SIZE_IN: tl.constexpr, - ): - pid = tl.program_id(axis=0) - n_elements_out = n_elements_in * 2 - n_elements_s = n_elements_out // 32 - - BLOCK_SIZE_S: tl.constexpr = BLOCK_SIZE_IN // 16 - BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2 - - block_start_in = pid * BLOCK_SIZE_IN - offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN) - mask_in = offsets_in < n_elements_in - # packed uint8 - x_packed = tl.load(x_ptr + offsets_in, mask=mask_in) - output = _fp4_packed_to_bf16( - x_packed, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - ) - - # load scale - block_start_s = pid * BLOCK_SIZE_S - offsets_s = block_start_s + tl.arange(0, BLOCK_SIZE_S) - mask_s = offsets_s < n_elements_s - s = tl.load(s_ptr + offsets_s, mask=mask_s) - - # create the scale in bf16 - s_offset = s.to(tl.int16) - e8m0_exponent_bias - s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16) - s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan")) - - # multiply output by scale - # TODO(later): see if manipulating the exponent instead of fp - # multiplication is going to give a significant speedup - output = tl.reshape(output, (BLOCK_SIZE_OUT // mx_block_size, mx_block_size)) # noqa: E501 - s_fp = tl.reshape(s_fp, (BLOCK_SIZE_S // 1, 1)) - output = output * s_fp - output = tl.reshape(output, (BLOCK_SIZE_OUT,)) - - # set up output offsets - block_start_out = pid * BLOCK_SIZE_OUT - offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT) - mask_out = offsets_out < n_elements_out - - tl.store(output_ptr + offsets_out, output, mask=mask_out) - @triton.jit def _fp6_packed_to_bf16( packed_4bits_a, @@ -619,46 +494,6 @@ def triton_pack_uint6_kernel( else: - def triton_f4_to_bf16_kernel( - x_ptr, - output_ptr, - n_elements_in, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - BLOCK_SIZE_IN, - ): - raise AssertionError("unsupported without triton") - - def triton_f4_to_scaled_bf16_kernel( - x_ptr, - s_ptr, - output_ptr, - n_elements_in, - mx_block_size, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - e8m0_exponent_bias, - e8m0_exponent_nan_val, - BLOCK_SIZE_IN, - ): - raise AssertionError("unsupported without triton") - def triton_f6_to_bf16_kernel( x_ptr, output_ptr, @@ -700,83 +535,6 @@ def triton_pack_uint6_kernel( raise AssertionError("unsupported without triton") -def triton_f4_to_bf16(x: torch.Tensor): - """ - Input: a tensor of packed fp4 values - Output: a tensor of bfloat16 values - - Note: this function is only used in testing, so we can test - the numerical correctness of the cast without the scaling. - """ - new_shape = (*x.shape[:-1], x.shape[-1] * 2) - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - n_elements_in = x.numel() - grid = lambda meta: ( # noqa: E731 - triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), - ) # noqa: E731,E501 - triton_f4_to_bf16_kernel[grid]( - x, - output, - n_elements_in, - sign_mask_f4=SIGN_MASK_F4, - mantissa_mask_f4=MANTISSA_MASK_F4, - mbits_f4_e2m1=MBITS_F4_E2M1, - ebits_f4_e2m1=EBITS_F4_E2M1, - f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS, - mbits_f32=MBITS_F32, - ebits_f32=EBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - zero_bits_f32=ZERO_BITS_F32, - zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32, - BLOCK_SIZE_IN=512, - ) - return output - - -def triton_f4_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, -): - """ - Input: a tensor of packed fp4 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. - Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) - assert TORCH_VERSION_AT_LEAST_2_4, "unsupported" - new_shape = (*x.shape[:-1], x.shape[-1] * 2) - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - n_elements_in = x.numel() - grid = lambda meta: ( # noqa: E731 - triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]), - ) - triton_f4_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_elements_in, - mx_block_size, - sign_mask_f4=SIGN_MASK_F4, - mantissa_mask_f4=MANTISSA_MASK_F4, - mbits_f4_e2m1=MBITS_F4_E2M1, - ebits_f4_e2m1=EBITS_F4_E2M1, - f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS, - mbits_f32=MBITS_F32, - ebits_f32=EBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - zero_bits_f32=ZERO_BITS_F32, - zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output - - def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor: """ Input: a tensor of packed fp6 values @@ -849,119 +607,104 @@ def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor: return output -if TORCH_VERSION_AT_LEAST_2_4: - - @torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) - def triton_f6_e2m3_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. - Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) +@torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) +def triton_f6_e2m3_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + s_e8m0 = s_e8m0.view(torch.uint8) - packed_mx_block_size = 3 * mx_block_size // 4 + packed_mx_block_size = 3 * mx_block_size // 4 - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda - n_mx_blocks = x.shape[0] - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E2M3, - mbits_f6=MBITS_F6_E2M3, - f6_exp_bias=F6_E2M3_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output + n_mx_blocks = x.shape[0] + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E2M3, + mbits_f6=MBITS_F6_E2M3, + f6_exp_bias=F6_E2M3_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output - @torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) - def triton_f6_e3m2_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. - Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) - packed_mx_block_size = 3 * mx_block_size // 4 +@torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) +def triton_f6_e3m2_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + s_e8m0 = s_e8m0.view(torch.uint8) - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + packed_mx_block_size = 3 * mx_block_size // 4 - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - n_mx_blocks = x.numel() // packed_mx_block_size - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E3M2, - mbits_f6=MBITS_F6_E3M2, - f6_exp_bias=F6_E3M2_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda - @triton_f6_e3m2_to_scaled_bf16.register_fake - def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) + n_mx_blocks = x.numel() // packed_mx_block_size + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E3M2, + mbits_f6=MBITS_F6_E3M2, + f6_exp_bias=F6_E3M2_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output - @triton_f6_e2m3_to_scaled_bf16.register_fake - def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) -else: +@triton_f6_e3m2_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) - def triton_f6_e2m3_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - raise AssertionError("unsupported without torch >= 2.4") - def triton_f6_e3m2_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, - ) -> torch.Tensor: - raise AssertionError("unsupported without torch >= 2.4") +@triton_f6_e2m3_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) # pack/unpack code copy-pasted from @@ -1043,51 +786,45 @@ def pack_uint6_pytorch(uint8_data: torch.Tensor) -> torch.Tensor: ).view(packed_shape) -if TORCH_VERSION_AT_LEAST_2_4: - - @torch.library.custom_op("ao::pack_uint6", mutates_args=()) - def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - # ensure input data is contiguous before passing to kernel - assert uint8_data.is_contiguous() +@torch.library.custom_op("ao::pack_uint6", mutates_args=()) +def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: + # ensure input data is contiguous before passing to kernel + assert uint8_data.is_contiguous() - # tensor should already be of shape [..., mx_block_size] - mx_block_size = uint8_data.shape[-1] - assert mx_block_size % 4 == 0 + # tensor should already be of shape [..., mx_block_size] + mx_block_size = uint8_data.shape[-1] + assert mx_block_size % 4 == 0 - # effective mx block size since we're packing 2 fp4 into 1 uint8 - packed_mx_block_size = 3 * mx_block_size // 4 - packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] - n_mx_blocks = uint8_data.numel() // mx_block_size + # effective mx block size since we're packing 2 fp4 into 1 uint8 + packed_mx_block_size = 3 * mx_block_size // 4 + packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] + n_mx_blocks = uint8_data.numel() // mx_block_size - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - # contiguous uint8 container in which we can store the unpacked tensor - packed_uint8_data = torch.empty( - packed_shape, dtype=torch.uint8, device=uint8_data.device - ) + # contiguous uint8 container in which we can store the unpacked tensor + packed_uint8_data = torch.empty( + packed_shape, dtype=torch.uint8, device=uint8_data.device + ) - triton_pack_uint6_kernel[grid]( - uint8_data, - packed_uint8_data, - n_mx_blocks, - MX_BLOCK_SIZE=mx_block_size, - PACKED_MX_BLOCK_SIZE=packed_mx_block_size, - ) + triton_pack_uint6_kernel[grid]( + uint8_data, + packed_uint8_data, + n_mx_blocks, + MX_BLOCK_SIZE=mx_block_size, + PACKED_MX_BLOCK_SIZE=packed_mx_block_size, + ) - return packed_uint8_data + return packed_uint8_data - @pack_uint6.register_fake - def _(uint8_data): - out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) - return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) -else: - def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - # Dummy placeholder op for torch < 2.4 - raise AssertionError("fp6 packing unsupported without torch >= 2.4") +@pack_uint6.register_fake +def _(uint8_data): + out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) + return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) -if TORCH_VERSION_AT_LEAST_2_7 and has_triton(): +if torch_version_at_least("2.7.0") and has_triton(): import triton import triton.language as tl from torch.library import triton_op, wrap_triton @@ -1315,7 +1052,6 @@ def triton_to_mxfp8_dim1( * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1 """ assert x.is_contiguous(), "`x` must be contiguous" - assert x.dtype == torch.bfloat16 assert inner_block_size <= 32 # Get tensor shape @@ -1363,6 +1099,16 @@ def triton_to_mxfp8_dim1( col_scale.view(torch.float8_e8m0fnu), ) + @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default) + def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32): + replicate = ([Replicate(), Replicate()], [Replicate(), None]) + # Note that the data is returned transposed, which is why + # we flip the sharding dim below + shard_dim0 = ([Shard(1), Shard(1)], [Shard(0), None]) + shard_dim1 = ([Shard(0), Shard(0)], [Shard(1), None]) + acceptable_shardings = [replicate, shard_dim0, shard_dim1] + return acceptable_shardings + def triton_to_mxfp8_dim1_reference( x_hp: torch.Tensor, block_size ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -1379,7 +1125,7 @@ def triton_to_mxfp8_dim1_reference( scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu) return ( x_hp_d1_normalized.t(), - scale_e8m0_dim1, + scale_e8m0_dim1.unsqueeze(-1), ) @triton.jit @@ -1389,6 +1135,7 @@ def triton_scale_swizzle( scale_cols, output_ptr, input_row_stride, + input_col_stride, output_block_stride, BLOCK_ROWS: tl.constexpr, BLOCK_COLS: tl.constexpr, @@ -1408,7 +1155,7 @@ def triton_scale_swizzle( mask = (global_rows < scale_rows) & (global_cols < scale_cols) input_scales = tl.load( - scale_ptr + global_rows * input_row_stride + global_cols, + scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride, mask=mask, other=0.0, ) @@ -1432,9 +1179,10 @@ def triton_scale_swizzle( scales_flat, ) + @torch.library.custom_op("torchao::triton_mx_block_rearrange", mutates_args=()) def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: """ - Rearranges an E8M0 tensor scale from row-major format to block-scaled swizzle format. + Rearranges an E8M0 tensor scale to block-scaled swizzle format. This format is suitable for Tmem as described in NVIDIA documentation: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout @@ -1448,7 +1196,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: assert scale_tensor.element_size() == 1, ( "Expected element size to be 1 byte (8 bits)" ) - assert scale_tensor.is_contiguous(), "Input tensor must be contiguous" rows, cols = scale_tensor.shape @@ -1461,7 +1208,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: out = scale_tensor.new_empty((padded_rows, padded_cols)) # Input stride (for row-major format) - input_row_stride = cols + input_row_stride = scale_tensor.stride()[0] + input_col_stride = scale_tensor.stride()[1] # We probably want handle multiple blocks per tile but for now keep it simple BLOCK_ROWS, BLOCK_COLS = 128, 4 @@ -1480,6 +1228,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: cols, out.view(torch.uint8), input_row_stride, + input_col_stride, output_block_stride, BLOCK_ROWS=BLOCK_ROWS, BLOCK_COLS=BLOCK_COLS, @@ -1487,6 +1236,227 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: return out + @triton.jit + def convert_fp32_to_fp4_packed(x_pairs): + """Convert FP32 pairs to packed FP4 format. + + This function takes tensor where consecutive values along the last dimension + are packed together into single bytes. + + Args: + x_pairs: [Tensor, Tensor] both w/ shapes [..., 1] where zipped last dimension contains + interleaved pairs of FP32 values to be packed together. + + Returns: + Packed tensor with shape [...] (last dimension removed) where each + element is an int8 containing 2 FP4 values: + - First value of pair → high nibble (bits 4-7) + - Second value of pair → low nibble (bits 0-3) + + Example: + Input: [128, 32, 2] containing FP32 pairs + Output: [128, 32] containing packed FP4 bytes + + """ + + x_fp4x2 = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 byte0, byte1, byte2, byte3; + cvt.rn.satfinite.e2m1x2.f32 byte0, $1, $5; + cvt.rn.satfinite.e2m1x2.f32 byte1, $2, $6; + cvt.rn.satfinite.e2m1x2.f32 byte2, $3, $7; + cvt.rn.satfinite.e2m1x2.f32 byte3, $4, $8; + mov.b32 $0, {byte0, byte1, byte2, byte3}; + } + """, + constraints=("=r,r,r,r,r,r,r,r,r"), + args=x_pairs, + dtype=tl.uint8, + is_pure=True, + pack=4, + ) + + return x_fp4x2 + + # Sauce: https://github.com/gau-nernst/quantized-training + @triton.jit + def quantize_nvfp4_triton_kernel( + x_ptr, + tensor_scale_ptr, + q_ptr, + s_ptr, + stride_xm, + stride_xn, + M, + N, + USE_TENSOR_SCALE: tl.constexpr, + MASK_SCALES: tl.constexpr, + ): + F4_E2M1_MAX = 6.0 + F8E4M3_MAX = 448.0 + E4M3_EPS = 1.5258789e-05 + + pid_m = tl.program_id(1) + pid_n = tl.program_id(0) + + offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] + offs_n = pid_n * 64 + tl.arange(0, 64)[None, :] + if MASK_SCALES: + mask = (offs_m < M) & (offs_n < N) + other = 0.0 + else: + mask = None + other = None + x = tl.load( + x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other + ) # [128, 64] + x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16] + + # Compute block-wise scales + block_amax = tl.max(x_blocks.abs(), axis=2) # [128, 4] + + if USE_TENSOR_SCALE: + # Two-level scaling: quantize block scales with per-tensor scale + tensor_scale = tl.load(tensor_scale_ptr) + + # First compute block scales + block_scale_f32 = (block_amax / F4_E2M1_MAX).to(tl.float32) + + # Quantize the block scales with per-tensor scale + scaled_block_scales = block_scale_f32 / tensor_scale + scaled_block_scales = tl.clamp(scaled_block_scales, E4M3_EPS, F8E4M3_MAX) + scales = scaled_block_scales.to(tl.float8e4nv) + + # Apply combined scale to data: per_tensor_scale * quantized_block_scale + total_scale = tensor_scale * scales.to(tl.float32)[:, :, None] + x_blocks = tl.div_rn(x_blocks, total_scale) + else: + # Single-level scaling: use block scales directly + scales_f32 = block_amax / F4_E2M1_MAX + scales_f32 = tl.clamp(scales_f32, E4M3_EPS, F8E4M3_MAX) + scales = scales_f32.to(tl.float8e4nv) + + # Apply block scale to data + total_scale = scales.to(tl.float32)[:, :, None] + x_blocks = tl.div_rn(x_blocks, total_scale) + + # NVIDIA layout for scales + if MASK_SCALES: + # Create offsets for the scale dimensions (4 blocks per row) + scale_offs_n = pid_n * 4 + tl.arange(0, 4)[None, :] + + # Mask out scales to 0 if we are not aligned to 128 x 64 + scales = tl.where( + (offs_m < M) & (scale_offs_n < N // 16), + scales, + 0.0, + ) + packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16) + offs_m = tl.arange(0, 32)[:, None] + offs_n = tl.arange(0, 16)[None, :] + tl.store( + s_ptr + + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16) + + offs_m * 16 + + offs_n, + packed_scales, + ) + + # Convert to FP4 + x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split()) + offs_m = pid_m * 128 + tl.arange(0, 128)[:, None] + offs_n = pid_n * 32 + tl.arange(0, 32)[None, :] + if MASK_SCALES: + mask = (offs_m < M) & (offs_n < N // 2) + else: + mask = None + tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask) + + @torch.library.custom_op("ao::triton_quantize_nvfp4", mutates_args=()) + def triton_quantize_nvfp4( + x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize a tensor to NVFP4 format. + + Args: + x (torch.Tensor): Input tensor to be quantized. + tensor_scale (Optional[torch.Tensor]): Per-tensor scale for two-level quantization. + If None, uses single-level block-wise quantization only. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout. + + Note: + Since VLLM does not use dyanmo guards we need to make this a custom op + to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES` + """ + M, N = x.shape + # assert M % 128 == 0 and N % 64 == 0 + assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization" + + # Calculate blocks needed + num_scales = N // 16 + n_row_blocks = triton.cdiv(M, 128) + n_col_blocks = triton.cdiv(num_scales, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + # mask out scales to 0 if we are not aligned to 128 x 64 + MASK_SCALES = M % 128 != 0 or N % 64 != 0 + + xq = x.new_empty(M, N // 2, dtype=torch.uint8) + scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn) + + grid = (triton.cdiv(N, 64), triton.cdiv(M, 128)) + + if per_tensor_scale is None: + # Don't allocate tensor, we just steal this since it won't be used in kernel + tensor_scale_ptr = x + use_tensor_scale = False + else: + tensor_scale_ptr = per_tensor_scale + use_tensor_scale = True + + quantize_nvfp4_triton_kernel[grid]( + x, + tensor_scale_ptr, + xq, + scales, + x.stride(0), + x.stride(1), + M, + N, + USE_TENSOR_SCALE=use_tensor_scale, + MASK_SCALES=MASK_SCALES, + ) + + return scales, xq.view(torch.uint8) + + @triton_quantize_nvfp4.register_fake + def _(x, per_tensor_scale=None): + M, N = x.shape + num_scales = N // 16 + n_row_blocks = triton.cdiv(M, 128) + n_col_blocks = triton.cdiv(num_scales, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + scales = torch.empty( + padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn + ) + xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8) + return scales, xq + + @triton_mx_block_rearrange.register_fake + def _(scale_tensor): + rows, cols = scale_tensor.shape + n_row_blocks = triton.cdiv(rows, 128) + n_col_blocks = triton.cdiv(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + return scale_tensor.new_empty((padded_rows, padded_cols)) else: def triton_to_mxfp8_dim1( @@ -1495,9 +1465,153 @@ def triton_to_mxfp8_dim1( raise AssertionError("needs torch version 2.8+ and triton") def triton_to_mxfp8_dim1_reference( - x_hp: torch.Tensor, block_size + x_hp: torch.Tensor, + block_size, ) -> Tuple[torch.Tensor, torch.Tensor]: raise AssertionError("needs torch version 2.8+ and triton") def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: raise AssertionError("needs torch version 2.8+ and triton") + + def triton_quantize_nvfp4( + x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise AssertionError("needs torch version 2.8+ and triton") + + +mxfp8_cuda_extension_available = False +if is_sm_at_least_100(): + try: + # MXFP8 CUDA kernel is only built on SM100+. Furthermore, + # currently our CI runners are not SM100+, so the user needs to build + # from source. + # TODO(#2932): improve this + from torchao.prototype import mxfp8_cuda + + mxfp8_cuda_extension_available = True + except ImportError: + logging.debug("Skipping import of torchao.prototype.mxfp8_cuda") + +if mxfp8_cuda_extension_available: + # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string. + # Currently we have to use an arbitrary string because custom ops don't support enum + # params. + @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=()) + def mxfp8_quantize_cuda( + x: torch.Tensor, + rowwise: bool = False, + colwise: bool = True, + scaling_mode: str = "floor", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Input shape must be 2D. + assert x.ndim == 2 + rows, cols = x.shape + + # Block size must be a multiple of 32. + block_size = 32 + assert rows % block_size == 0, "rows must be a multiple of 32" + assert cols % block_size == 0, "cols must be a multiple of 32" + + # Convert scaling mode to expected string format and call into kernel. + output_rowwise, output_colwise, scales_rowwise, scales_colwise = ( + mxfp8_cuda.quantize( + x, + rowwise=rowwise, + colwise=colwise, + scaling_mode=scaling_mode, + ) + ) + return output_rowwise, output_colwise, scales_rowwise, scales_colwise + + @mxfp8_quantize_cuda.register_fake + def _( + x: torch.Tensor, + rowwise: bool = False, + colwise: bool = True, + scaling_mode: str = "floor", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert x.ndim == 2 + rows, cols = x.shape + block_size = 32 + assert rows % block_size == 0, "rows must be a multiple of 32" + assert cols % block_size == 0, "cols must be a multiple of 32" + num_row_blocks = rows // 32 + num_col_blocks = cols // 32 + + # rowwise + if rowwise: + output_rowwise = x.new_empty(rows, cols, dtype=torch.float8_e4m3fn) + scales_rowwise = x.new_empty( + rows, num_col_blocks, 1, dtype=torch.float8_e8m0fnu + ) + else: + output_rowwise = x.new_empty(0, dtype=torch.float8_e4m3fn) + scales_rowwise = x.new_empty(0, dtype=torch.float8_e8m0fnu) + + # colwise + if colwise: + # column major + output_colwise = torch.empty_strided( + (rows, cols), (1, rows), dtype=torch.float8_e4m3fn, device=x.device + ) + + # colwise scales are written in column-major format to avoid uncoalesced global memory accesses + scales_colwise = torch.empty_strided( + (cols, num_row_blocks), + (1, cols), + dtype=torch.float8_e8m0fnu, + device=x.device, + ) + else: + output_colwise = x.new_empty(0, dtype=torch.float8_e4m3fn) + scales_colwise = x.new_empty(0, dtype=torch.float8_e8m0fnu) + + return output_rowwise, output_colwise, scales_rowwise, scales_colwise + + @register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default) + def custom_mxfp8_quantize_cuda_dim1_sharding( + x: torch.Tensor, + rowwise: bool = False, + colwise: bool = True, + scaling_mode: str = "floor", + ): + # This function signature can be used to understand the shardings: + # _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise=False, colwise=True) + + # When inputs and scale are replicated, we return a quantized output tensor (replicated). + inputs_replicated = [None, Replicate(), None, Replicate()] + outputs_replicated = [None, Replicate(), None, None] + rule_for_input_replicated = ( + inputs_replicated, + outputs_replicated, + ) + + # When inputs and scale are sharded along dim 0, + # we return a quantized output tensor (sharded along dim1 due to transpose). + inputs_sharded_dim0 = [None, Shard(0), None, Shard(0)] + outputs_sharded_dim1 = [None, Shard(1), None, None] + rule_for_input_sharded_dim0 = (inputs_sharded_dim0, outputs_sharded_dim1) + + # When inputs and scale are sharded along dim 1, + # we return a quantized output tensor (sharded along dim0 due to transpose). + inputs_sharded_dim1 = [None, Shard(1), None, Shard(1)] + outputs_sharded_dim0 = [None, Shard(0), None, None] + rule_for_input_sharded_dim1 = (inputs_sharded_dim1, outputs_sharded_dim0) + + acceptable_shardings = [ + rule_for_input_replicated, + rule_for_input_sharded_dim0, + rule_for_input_sharded_dim1, + ] + return acceptable_shardings +else: + + def mxfp8_quantize_cuda( + x: torch.Tensor, + rowwise: bool = False, + colwise: bool = True, + scaling_mode: str = "floor", + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError( + "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details." + ) diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 4db029480f..19d658a6fc 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -11,15 +11,15 @@ from typing import Any, Optional import torch -import torch.nn.functional as F from torchao.prototype.mx_formats.config import ( + MXFP8Dim1CastKernelChoice, MXGemmKernelChoice, - MXInferenceLinearConfig, MXLinearConfig, + ScaleCalculationMode, ) -from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1 from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -45,7 +45,8 @@ def forward( grad_elem_dtype: Any, block_size: int, gemm_kernel_choice: MXGemmKernelChoice, - use_fp8_dim1_cast_triton_kernel: bool, + mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice, + scale_calculation_mode: ScaleCalculationMode, ): ctx.save_for_backward(input_hp, weight_hp) ctx.in_elem_dtype = in_elem_dtype @@ -53,17 +54,26 @@ def forward( ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size ctx.gemm_kernel_choice = gemm_kernel_choice - ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel + ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice + ctx.scale_calculation_mode = scale_calculation_mode # input @ weight_t = output input_orig_shape = input_hp.shape input_hp_r = input_hp.reshape(-1, input_orig_shape[-1]) input_mx_r_dim0 = MXTensor.to_mx( - input_hp_r, in_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice + input_hp_r, + in_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, + scaling_mode=scale_calculation_mode, ) weight_mx_dim0 = MXTensor.to_mx( - weight_hp, w_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice + weight_hp, + w_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, + scaling_mode=scale_calculation_mode, ) output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) output = output.reshape(*input_orig_shape[:-1], output.shape[-1]) @@ -78,7 +88,8 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size gemm_kernel_choice = ctx.gemm_kernel_choice - use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel + mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice + scale_calculation_mode = ctx.scale_calculation_mode grad_output_orig_shape = grad_output_hp.shape grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1]) @@ -92,23 +103,19 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice, + scaling_mode=scale_calculation_mode, ) - if use_fp8_dim1_cast_triton_kernel: - weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1( - weight_hp, block_size - ) - weight_mx_dim1 = MXTensor( - weight_mx_dim1_scale.reshape(-1), - weight_mx_dim1_data.t(), - w_elem_dtype, + if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: + weight_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper( + weight_hp, block_size, + w_elem_dtype, weight_hp.dtype, - False, gemm_kernel_choice, - False, + mxfp8_cast_kernel_choice, + scale_calculation_mode, ) - else: weight_hp_t_c = weight_hp.t().contiguous() weight_mx_dim1 = MXTensor.to_mx( @@ -116,6 +123,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): w_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice, + scaling_mode=scale_calculation_mode, ) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) grad_input = grad_input.reshape( @@ -123,19 +131,15 @@ def backward(ctx, grad_output_hp: torch.Tensor): ) # input_t @ grad_output = grad_weight - if use_fp8_dim1_cast_triton_kernel: - grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1( - grad_output_hp_r, block_size - ) - grad_output_mx_dim1 = MXTensor( - grad_output_mx_dim1_scale.reshape(-1), - grad_output_mx_dim1_data.t(), - grad_elem_dtype, + if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: + grad_output_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper( + grad_output_hp_r, block_size, + grad_elem_dtype, grad_output_hp_r.dtype, - False, gemm_kernel_choice, - False, + mxfp8_cast_kernel_choice, + scale_calculation_mode, ) else: grad_output_mx_dim1 = MXTensor.to_mx( @@ -143,21 +147,18 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice, + scaling_mode=scale_calculation_mode, ) - if use_fp8_dim1_cast_triton_kernel: - input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1( - input_hp_r, block_size - ) - input_t_mx_dim0_tmp = MXTensor( - input_t_mx_dim0_tmp_scale.reshape(-1), - input_t_mx_dim0_tmp_data.t(), - in_elem_dtype, + if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: + input_t_mx_dim0_tmp = _to_mxfp8_dim1_kernel_wrapper( + input_hp_r, block_size, + in_elem_dtype, input_hp_r.dtype, - False, gemm_kernel_choice, - False, + mxfp8_cast_kernel_choice, + scale_calculation_mode, ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() else: @@ -166,11 +167,12 @@ def backward(ctx, grad_output_hp: torch.Tensor): in_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice, + scaling_mode=scale_calculation_mode, ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) - return grad_input, grad_weight, None, None, None, None, None, None + return grad_input, grad_weight, None, None, None, None, None, None, None class MXLinear(torch.nn.Linear): @@ -214,7 +216,8 @@ def forward(self, x): config.elem_dtype_grad_output_override or config.elem_dtype, config.block_size, config.gemm_kernel_choice, - config.use_fp8_dim1_cast_triton_kernel, + config.mxfp8_cast_kernel_choice, + config.scale_calculation_mode, ) if self.bias is not None: y = y + self.bias @@ -225,59 +228,6 @@ def extra_repr(self): return s -class MXInferenceLinear(torch.nn.Linear): - """ - Inference version of MXLinear, with the weight pre-quantized to MX. - - Note: this is weight-only quantization, with the gemm being executed - in high precision. - """ - - @classmethod - @torch.no_grad() - def from_float( - cls, - mod, - config: Optional[MXInferenceLinearConfig] = MXInferenceLinearConfig(), - ): - with torch.device("meta"): - super_kwargs = { - "in_features": mod.in_features, - "out_features": mod.out_features, - "bias": False, - } - new_mod = cls(**super_kwargs) - # TODO(future PR): set to new_mod.weight directly, will need to work - # through some errors - new_mod.weight_mx = MXTensor.to_mx( - mod.weight, - config.elem_dtype, - block_size=config.block_size, - gemm_kernel_choice=config.gemm_kernel_choice, - pack_fp6=config.pack_fp6, - ) - new_mod.bias = mod.bias - new_mod.config = config - return new_mod - - @torch.no_grad() - def forward(self, x): - w_hp = self.weight_mx.to_dtype(x.dtype) - y = F.linear(x, w_hp, self.bias) - return y - - def extra_repr(self): - s = f"{super().extra_repr()}, {self.config.short_str()}" - return s - - @register_quantize_module_handler(MXLinearConfig) def _mx_linear_transform(module: torch.nn.Module, config: MXLinearConfig): return MXLinear.from_float(module, config=config) - - -@register_quantize_module_handler(MXInferenceLinearConfig) -def _mx_inference_linear_transform( - module: torch.nn.Module, config: MXInferenceLinearConfig -): - return MXInferenceLinear.from_float(module, config=config) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index c7e673dc37..07e47eed66 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -80,19 +80,32 @@ def _get_gemm_choice( def _addmm_mx_dispatch( - a: MXTensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None + a: torch.Tensor, b: MXTensor, aten_op, bias: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Core implementation shared between mx_mm and mx_addmm. The only difference is whether bias is None or not. """ + + if not isinstance(a, MXTensor): + assert b.act_quant_kwargs is not None, "weight-only quant not yet supported" + k = b.act_quant_kwargs + a = MXTensor.to_mx( + a, + k.elem_dtype, + k.block_size, + k.scaling_mode, + k.gemm_kernel_choice, + k.pack_fp6, + ) + gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): # real MX gemm backed by torchao's CUTLASS kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] - assert a._data.is_contiguous() - assert b._data.t().is_contiguous() + assert a.qdata.is_contiguous() + assert b.qdata.t().is_contiguous() assert a._block_size == 32, f"Invalid block size {a._block_size}" assert b._block_size == 32, f"Invalid block size {b._block_size}" @@ -108,8 +121,8 @@ def _addmm_mx_dispatch( ) res = torch._scaled_mm( - a._data, - b._data, + a.qdata, + b.qdata, a_scale_block.view(torch.float8_e8m0fnu), b_scale_block.view(torch.float8_e8m0fnu), bias=bias, @@ -121,7 +134,7 @@ def _addmm_mx_dispatch( assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported" # FP4 operations res = torchao.ops.mx_fp4_bf16( - a._data, b._data, a_scale_block, b_scale_block + a.qdata, b.qdata, a_scale_block, b_scale_block ) # TODO add optional bias to kernel if bias is not None: @@ -148,18 +161,14 @@ def _addmm_mx_dispatch( def mx_mm(func, types, args, kwargs): a = args[0] b = args[1] - assert isinstance(a, MXTensor) and isinstance(b, MXTensor) + assert isinstance(b, MXTensor) return _addmm_mx_dispatch(a, b, func) @implements([aten.addmm.default]) def mx_addmm(func, types, args, kwargs): - assert ( - isinstance(args[0], torch.Tensor) - and isinstance(args[1], MXTensor) - and isinstance(args[2], MXTensor) - ) + assert isinstance(args[0], torch.Tensor) and isinstance(args[2], MXTensor) bias = args[0] a = args[1] b = args[2] @@ -171,14 +180,14 @@ def mx_t(func, types, args, kwargs): # For now, only transpose(input, 0, 1) is supported. old = args[0] new = MXTensor( + old.qdata.t(), old._scale_e8m0, - old._data.t(), old._elem_dtype, old._block_size, old._orig_dtype, - old._use_fp4_custom_triton_dequant_kernel, old._gemm_kernel_choice, old._pack_fp6, + old.act_quant_kwargs, ) return new @@ -205,7 +214,7 @@ def unwrap(x): @implements([aten.view.default]) def mx_view_op(func, types, args, kwargs): - data = args[0]._data + data = args[0].qdata new_size = args[1] if args[0]._elem_dtype == torch.float4_e2m1fn_x2: # special case fp4 as we pack two elements per byte @@ -215,14 +224,14 @@ def mx_view_op(func, types, args, kwargs): new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) new_data = func(data, new_size, *args[2:], **kwargs) return MXTensor( - args[0]._scale_e8m0, new_data, + args[0]._scale_e8m0, args[0]._elem_dtype, args[0]._block_size, args[0]._orig_dtype, - args[0]._use_fp4_custom_triton_dequant_kernel, args[0]._gemm_kernel_choice, args[0]._pack_fp6, + args[0].act_quant_kwargs, ) @@ -240,8 +249,8 @@ def mx_slice(func, types, args, kwargs): if dim == 0: # Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now - sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step).flatten() - sliced_data = aten.slice.Tensor(x._data, dim, start, end, step) + sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) + sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step).unsqueeze(-1) elif dim == 1: # Slicing along reduciton dim if start is not None: @@ -256,7 +265,7 @@ def mx_slice(func, types, args, kwargs): f"End index {end} must be a multiple of block_size {x._block_size}" ) - sliced_data = aten.slice.Tensor(x._data, dim, start, end, step) + sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) # Calculate which scale elements to keep start_block = 0 if start is None else start // x._block_size @@ -265,7 +274,7 @@ def mx_slice(func, types, args, kwargs): # Slice the scale tensor accordingly sliced_scale = aten.slice.Tensor( scale_shaped, 1, start_block, end_block, step - ).flatten() + ).unsqueeze(-1) else: raise ValueError( f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}" @@ -276,14 +285,14 @@ def mx_slice(func, types, args, kwargs): args, kwargs, MXTensor( - sliced_scale, sliced_data, + sliced_scale, x._elem_dtype, x._block_size, x._orig_dtype, - x._use_fp4_custom_triton_dequant_kernel, x._gemm_kernel_choice, x._pack_fp6, + x.act_quant_kwargs, ), ) @@ -330,14 +339,14 @@ def autocast_to_copy(func, types, args, kwargs): # If dtype is specified, create a new MXTensor with the requested dtype if dtype is not None: res = MXTensor( + tensor.qdata, tensor._scale_e8m0, - tensor._data, tensor._elem_dtype, tensor._block_size, dtype, - tensor._use_fp4_custom_triton_dequant_kernel, tensor._gemm_kernel_choice, tensor._pack_fp6, + tensor.act_quant_kwargs, ) return res diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index e98878af77..b717462b4d 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -17,13 +17,13 @@ * Zeros: N/A """ -from enum import Enum, auto -from typing import Callable, Dict, Union +from dataclasses import dataclass +from typing import Optional, Union import torch from torch.distributed._tensor import DTensor -from torchao.prototype.mx_formats.config import MXGemmKernelChoice +from torchao.prototype.mx_formats.config import MXGemmKernelChoice, ScaleCalculationMode from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP6_E2M3, @@ -53,11 +53,14 @@ f32_to_f6_e3m2_unpacked, pack_uint4, pack_uint6, - triton_f4_to_scaled_bf16, triton_f6_e2m3_to_scaled_bf16, triton_f6_e3m2_to_scaled_bf16, unpack_uint4, ) +from torchao.quantization.quantize_.common import ( + QuantizeTensorKwargs, +) +from torchao.utils import TorchAOBaseTensor # TODO(later): read from somewhere else? SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23 @@ -68,27 +71,13 @@ EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2 -class ScaleCalculationMode(Enum): - """ - Enum representing the different methods for calculating MX block scaling. - There are three methods available: - FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp). - It result in overflow issues for large values and bad for gradient quantization. - CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor. - It uses X = 2^ceil(log2(max_abs(v))-max_exp). - EVEN: This method is a trade-off between Option 1 and Option 2. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)). - It provides better accuracy for MX4 training compared to FLOOR and CEIL. - RCEIL: The method is to apply ceil to the ratio of max_abs(v) and max_pos. - This method's detail is described in https://docs.nvidia.com/cuda/cublas/index.html#d-block-quantization - Section "Computing scaling and conversion factors for FP8 with UE8M0 scales" - - By default, we use the EVEN method for better accuracy. - """ - - FLOOR = auto() - CEIL = auto() - EVEN = auto() - RCEIL = auto() +@dataclass +class QuantizeTensorToMXKwargs(QuantizeTensorKwargs): + elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn + block_size: int = 32 + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR + gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED + pack_fp6: bool = False def _to_mx_rceil( @@ -150,7 +139,6 @@ def to_mx( Takes a high precision tensor and converts to MX scale and raw data, in naive layout (scale and raw data are separate tensors). """ - assert data_hp.dtype in ( torch.bfloat16, torch.float, @@ -331,6 +319,7 @@ def to_mx( raise AssertionError("unsupported") scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + scale_e8m0_biased = scale_e8m0_biased.squeeze(-1) return scale_e8m0_biased, data_lp @@ -357,7 +346,6 @@ def to_dtype( elem_dtype, block_size, target_dtype, - use_fp4_custom_triton_dequant_kernel, pack_fp6, ): orig_shape = data_lp.shape @@ -400,25 +388,15 @@ def to_dtype( data_hp = f6_e3m2_unpacked_to_f32(data_lp) data_hp = data_hp.to(target_dtype).reshape(orig_shape) elif elem_dtype == torch.float4_e2m1fn_x2: - if use_fp4_custom_triton_dequant_kernel: - data_hp_rescaled = triton_f4_to_scaled_bf16( - data_lp, - scale_e8m0, - block_size, - ) - if is_transposed: - data_hp_rescaled = data_hp_rescaled.t() - return data_hp_rescaled.to(target_dtype) - else: - # fp4 - f4_unpacked = unpack_uint4(data_lp) - # for now we only have a cast to f32 - # TODO(future PR): add cast directly to bf16 - f32 = f4_unpacked_to_f32(f4_unpacked) - data_hp = f32.to(target_dtype) - # manually adjust shape to account for the unpacking - # TODO(future PR): clean up the shape code and remove the hack - # below + # fp4 + f4_unpacked = unpack_uint4(data_lp) + # for now we only have a cast to f32 + # TODO(future PR): add cast directly to bf16 + f32 = f4_unpacked_to_f32(f4_unpacked) + data_hp = f32.to(target_dtype) + # manually adjust shape to account for the unpacking + # TODO(future PR): clean up the shape code and remove the hack + # below orig_shape = (*orig_shape[:-1], orig_shape[-1] * 2) else: raise AssertionError("unsupported") @@ -471,19 +449,29 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous): return new_size -class MXTensor(torch.Tensor): +class MXTensor(TorchAOBaseTensor): + tensor_data_names = ["qdata", "_scale_e8m0"] + tensor_attribute_names = [ + "_elem_dtype", + "_block_size", + "_orig_dtype", + "_gemm_kernel_choice", + "_pack_fp6", + "act_quant_kwargs", + ] + def __new__( cls, + qdata, scale_e8m0_bits, - data_bits, elem_dtype, block_size, orig_dtype, - use_fp4_custom_triton_dequant_kernel, gemm_kernel_choice, pack_fp6, + act_quant_kwargs, ): - new_size = data_bits.size() + new_size = qdata.size() if elem_dtype == torch.float4_e2m1fn_x2: # set the tensor size to what it would be without 2x4 packing # Note: `is_contiguous` is going to return True for a tensor of size @@ -492,27 +480,27 @@ def __new__( # a time when fixing this becomes important. new_size = tensor_size_fp4x2_to_hp( new_size, - data_bits.is_contiguous(), + qdata.is_contiguous(), ) elif pack_fp6 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: # set the tensor size to what it would be without fp6 packing new_size = tensor_size_fp6x4_to_hpx3( new_size, - data_bits.is_contiguous(), + qdata.is_contiguous(), ) self = torch.Tensor._make_wrapper_subclass( cls, new_size, - strides=data_bits.stride(), - storage_offset=data_bits.storage_offset(), - layout=data_bits.layout, + strides=qdata.stride(), + storage_offset=qdata.storage_offset(), + layout=qdata.layout, dtype=orig_dtype, - device=data_bits.device, + device=qdata.device, ) assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, ( f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}" ) - assert data_bits.dtype in ( + assert qdata.dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8, @@ -523,10 +511,10 @@ def __new__( ): target_numel = scale_e8m0_bits.numel() * block_size elif elem_dtype == torch.float4_e2m1fn_x2: - assert data_bits.dtype is torch.uint8 # fp4 + assert qdata.dtype is torch.uint8 # fp4 target_numel = scale_e8m0_bits.numel() * block_size / 2 elif elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: - assert data_bits.dtype is torch.uint8 # fp4 + assert qdata.dtype is torch.uint8 # fp4 target_numel = scale_e8m0_bits.numel() * block_size if pack_fp6: target_numel = 3 * target_numel // 4 @@ -534,31 +522,27 @@ def __new__( raise AssertionError("unsupported") if not issubclass( torch._subclasses.fake_tensor.FakeTensor, - type(data_bits), + type(qdata), ): # this check is sometimes broken for FakeTensor # TODO investigate - assert target_numel == data_bits.numel(), ( - f"{target_numel} != {data_bits.numel()}" - ) + assert target_numel == qdata.numel(), f"{target_numel} != {qdata.numel()}" # `_scale_e8m0` has rank 1 and applies to a row-major memory layout of - # `_data` + # `qdata` + self.qdata = qdata self._scale_e8m0 = scale_e8m0_bits - self._data = data_bits self._elem_dtype = elem_dtype self._block_size = block_size self._orig_dtype = orig_dtype - self._use_fp4_custom_triton_dequant_kernel = ( - use_fp4_custom_triton_dequant_kernel - ) self._gemm_kernel_choice = gemm_kernel_choice self._pack_fp6 = pack_fp6 + self.act_quant_kwargs = act_quant_kwargs return self def __repr__(self): # TODO better elem dtype print for fp4 - return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}" # noqa: E501 + return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}" # noqa: E501 @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -579,12 +563,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): def to_dtype(self, target_dtype): return to_dtype( - self._data, + self.qdata, self._scale_e8m0, self._elem_dtype, self._block_size, target_dtype, - self._use_fp4_custom_triton_dequant_kernel, self._pack_fp6, ) @@ -595,9 +578,10 @@ def to_mx( elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, - use_fp4_custom_triton_dequant_kernel: bool = False, + # TODO(future PR): switch default gemm to cublas gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, pack_fp6: bool = False, + act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None, ): scale_e8m0_biased, data_lp = to_mx( data_hp, elem_dtype, block_size, scaling_mode, pack_fp6 @@ -607,14 +591,14 @@ def to_mx( local_scale_e8m0_biased = scale_e8m0_biased.to_local() local_data_lp = data_lp.to_local() inner_mx_tensor = MXTensor( - local_scale_e8m0_biased, local_data_lp, + local_scale_e8m0_biased, elem_dtype, block_size, data_hp.dtype, - use_fp4_custom_triton_dequant_kernel, gemm_kernel_choice, pack_fp6, + act_quant_kwargs, ) return DTensor.from_local( inner_mx_tensor, @@ -625,76 +609,15 @@ def to_mx( stride=data_lp.stride(), ) return MXTensor( - scale_e8m0_biased, data_lp, + scale_e8m0_biased, elem_dtype, block_size, data_hp.dtype, - use_fp4_custom_triton_dequant_kernel, gemm_kernel_choice, pack_fp6, - ) - - def __tensor_flatten__(self): - ctx = { - "_elem_dtype": self._elem_dtype, - "_block_size": self._block_size, - "_orig_dtype": self._orig_dtype, - "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel, - "_gemm_kernel_choice": self._gemm_kernel_choice, - "_pack_fp6": self._pack_fp6, - } - return ["_scale_e8m0", "_data"], ctx - - @staticmethod - def __tensor_unflatten__( - inner_tensors: Dict, - metadata, - outer_size, - outer_stride, - ): - return MXTensor( - inner_tensors["_scale_e8m0"], - inner_tensors["_data"], - metadata["_elem_dtype"], - metadata["_block_size"], - metadata["_orig_dtype"], - metadata["_use_fp4_custom_triton_dequant_kernel"], - metadata["_gemm_kernel_choice"], - metadata["_pack_fp6"], - ) - - def _apply_fn_to_data(self, fn: Callable): - """Applies a fn to all tensor components stored on this class""" - tensor_names, ctx = self.__tensor_flatten__() - - # Apply the function to each tensor component - new_tensors = {} - for name in tensor_names: - new_tensors[name] = fn(getattr(self, name)) - - return self.__class__.__tensor_unflatten__( - new_tensors, - ctx, - None, # outer_size parameter - None, # outer_stride parameter + act_quant_kwargs, ) # Do not force the MXTensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def _same_metadata(cls, self: "MXTensor", src: "MXTensor") -> bool: - return ( - isinstance(self, MXTensor) - and isinstance(src, MXTensor) - and self._elem_dtype == src._elem_dtype - and self._block_size == src._block_size - and self._orig_dtype == src._orig_dtype - and self._use_fp4_custom_triton_dequant_kernel - == src._use_fp4_custom_triton_dequant_kernel - and self._gemm_kernel_choice == src._gemm_kernel_choice - and self._pack_fp6 == src._pack_fp6 - and self._scale_e8m0.shape == src._scale_e8m0.shape - and self._data.shape == src._data.shape - ) diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index ed1b5df1d0..3f2e8eeef3 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -4,8 +4,10 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import sys +from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Dict, Optional +from typing import Any, Dict, Optional import torch from torch.utils._python_dispatch import return_and_correct_aliasing @@ -15,14 +17,18 @@ f4_unpacked_to_f32, f32_to_f4_unpacked, pack_uint4, + triton_quantize_nvfp4, unpack_uint4, ) from torchao.prototype.mx_formats.mx_tensor import ( tensor_size_fp4x2_to_hp, tensor_size_hp_to_fp4x2, ) -from torchao.prototype.mx_formats.utils import to_blocked -from torchao.utils import fill_defaults +from torchao.prototype.mx_formats.utils import from_blocked, to_blocked +from torchao.quantization.quantize_.common import ( + QuantizeTensorKwargs, +) +from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny @@ -36,6 +42,14 @@ class NVFP4MMConfig(Enum): WEIGHT_ONLY = "weight_only" +@dataclass +class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs): + block_size: int = 16 + is_swizzled_scales: bool = False + use_triton_kernel: bool = False + + +# TODO(future PR): move over to TorchAOBaseTensor's dispatch def implements(aten_ops): """Register aten ops to the NVFP4 op table""" @@ -47,62 +61,76 @@ def decorator(func): return decorator -class NVFP4Tensor(torch.Tensor): +class NVFP4Tensor(TorchAOBaseTensor): """NVIDIA FP4 (NVFP4) Tensor subclass. This implements the NVIDIA variant of MX FP4 format, which uses a specific quantization algorithm for FP4 data with UE4M3 scales. Attributes: - _scale_e4m3: Blockwise scales in float8_e4m3fn format + qdata: Packed FP4 data (2 values per byte) + _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled) _per_tensor_scale: Optional global per-tensor scale in float32 format - _data: Packed FP4 data (2 values per byte) - _block_size: Block size for quantization (fixed at 16) - _orig_dtype: Original tensor dtype before quantization - mm_config: Matrix multiplication configuration + _act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation + _block_size (int): Block size for quantization (fixed at 16) + _orig_dtype (torch.dtype): Original tensor dtype before quantization + _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format + use_triton_kernel (bool): Whether to use triton kernels """ - _scale_e4m3: torch.Tensor - _per_tensor_scale: Optional[torch.Tensor] - _data: torch.Tensor - _block_size: int - _orig_dtype: torch.dtype - mm_config: NVFP4MMConfig + tensor_data_names = ["qdata", "_scale_e4m3"] + tensor_attribute_names = [ + "_block_size", + "_orig_dtype", + ] + optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"] + optional_tensor_attribute_names = [ + "_is_swizzled_scales", + "use_triton_kernel", + "act_quant_kwargs", + ] def __new__( cls, + qdata, blockwise_scales, - per_tensor_scale, - data_bits, block_size, orig_dtype, - mm_config=NVFP4MMConfig.DYNAMIC, + _per_tensor_scale=None, + _act_per_tensor_scale=None, + _is_swizzled_scales=False, + use_triton_kernel=False, + act_quant_kwargs=None, ): - # FP4 tensor size handling - new_size = data_bits.size() + # FP4 tensor size handling two paths, contiguous or not + new_size = qdata.size() + new_size = tensor_size_fp4x2_to_hp( new_size, - data_bits.is_contiguous(), + qdata.stride(0) > qdata.stride(1), ) self = torch.Tensor._make_wrapper_subclass( cls, new_size, dtype=orig_dtype, - device=data_bits.device, + device=qdata.device, requires_grad=False, ) + self.qdata = qdata self._scale_e4m3 = blockwise_scales - self._per_tensor_scale = per_tensor_scale - self._data = data_bits self._block_size = block_size self._orig_dtype = orig_dtype - self.mm_config = mm_config + self._per_tensor_scale = _per_tensor_scale + self._act_per_tensor_scale = _act_per_tensor_scale + self._is_swizzled_scales = _is_swizzled_scales + self.use_triton_kernel = use_triton_kernel + self.act_quant_kwargs = act_quant_kwargs return self def __repr__(self): - return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}" + return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}" @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -117,71 +145,54 @@ def to_nvfp4( data_hp: torch.Tensor, block_size: int = 16, per_tensor_scale: Optional[torch.Tensor] = None, - mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC, + act_per_tensor_scale: Optional[torch.Tensor] = None, + is_swizzled_scales: bool = False, + use_triton_kernel: bool = False, + act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None, ): """Convert high precision tensor to NVFP4 format. Args: data_hp: High precision input tensor (bfloat16 or float32) block_size: Block size for quantization (must be 16) - per_tensor_amax: Optional pre-computed absolute maximum for calibration. + per_tensor_scale: Optional pre-computed absolute maximum for calibration. + If provided, uses per-tensor scaling. If None, uses block-wise scaling only. + act_per_tensor_scale: Optional pre-computed absolute maximum for calibration for activation If provided, uses per-tensor scaling. If None, uses block-wise scaling only. + is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication + use_triton_kernel: If True, use Triton kernel for quantization + act_quant_kwargs: If specified, config for quantizing the activation Returns: NVFP4Tensor: Quantized tensor in NVFP4 format """ - blockwise_scales, data_lp = nvfp4_quantize( - data_hp, block_size, per_tensor_scale - ) + if use_triton_kernel: + assert is_swizzled_scales, "Triton kernel only supports swizzled scales" + assert data_hp.shape[1] % 16 == 0, ( + f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}" + ) + blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale) + else: + blockwise_scales, data_lp = nvfp4_quantize( + data_hp, block_size, per_tensor_scale + ) + if is_swizzled_scales: + M, K = data_hp.shape[0], data_hp.shape[1] + scale_shape = (M, K // block_size) + blockwise_scales = to_blocked( + blockwise_scales.view(scale_shape) + ).flatten() + return NVFP4Tensor( - blockwise_scales, - per_tensor_scale, data_lp, + blockwise_scales, block_size, data_hp.dtype, - mm_config, - ) - - def __tensor_flatten__(self): - ctx = { - "_block_size": self._block_size, - "_orig_dtype": self._orig_dtype, - "mm_config": self.mm_config, - } - tensor_list = ["_scale_e4m3", "_data"] - if self._per_tensor_scale is not None: - tensor_list.append("_per_tensor_scale") - return tensor_list, ctx - - def _apply_fn_to_data(self, fn: Callable): - """Applies a fn to all tensor components stored on this class""" - tensor_names, ctx = self.__tensor_flatten__() - new_tensors = {} - for name in tensor_names: - new_tensors[name] = fn(getattr(self, name)) - if "_per_tensor_scale" not in tensor_names: - new_tensors["_per_tensor_scale"] = None - return self.__class__.__tensor_unflatten__( - new_tensors, - ctx, - None, - None, - ) - - @staticmethod - def __tensor_unflatten__( - inner_tensors, - metadata, - outer_size, - outer_stride, - ): - return NVFP4Tensor( - inner_tensors["_scale_e4m3"], - inner_tensors.get("_per_tensor_scale", None), - inner_tensors["_data"], - metadata["_block_size"], - metadata["_orig_dtype"], - metadata["mm_config"], + per_tensor_scale, + act_per_tensor_scale, + is_swizzled_scales, + use_triton_kernel, + act_quant_kwargs, ) # Do not force the NVFP4Tensor type on the returned tensor @@ -196,12 +207,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor: Returns: torch.Tensor: Dequantized tensor in the target dtype """ - is_transposed = not self._data.is_contiguous() + is_transposed = self.qdata.stride(0) < self.qdata.stride(1) if is_transposed: M, K = self.shape[1], self.shape[0] else: M, K = self.shape[0], self.shape[1] - data = self._data.t() if is_transposed else self._data + data = self.qdata.t() if is_transposed else self.qdata data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8)) data_f32 = f4_unpacked_to_f32(data_unpacked) @@ -221,10 +232,21 @@ def get_hp_scales(self) -> torch.Tensor: Returns: torch.Tensor: Scales of the NVFP4Tensor """ + is_transposed = self.qdata.stride(0) < self.qdata.stride(1) + if is_transposed: + M, K = self.shape[1], self.shape[0] + else: + M, K = self.shape[0], self.shape[1] + + if self._is_swizzled_scales: + scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size) + else: + scale_e4m3 = self._scale_e4m3 + return ( - self._scale_e4m3.to(self._orig_dtype) + scale_e4m3.to(self._orig_dtype) if not self._per_tensor_scale - else self._per_tensor_scale * self._scale_e4m3.to(self._orig_dtype) + else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype) ) @classmethod @@ -238,19 +260,24 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool: Returns: bool: True if both tensors have identical metadata, False otherwise """ - # Check per_tensor_scale equality per_tensor_scale_equal = ( self._per_tensor_scale is None and src._per_tensor_scale is None ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape) + act_per_tensor_scale_equal = ( + self._act_per_tensor_scale is None and src._act_per_tensor_scale is None + ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape) return ( isinstance(self, NVFP4Tensor) and isinstance(src, NVFP4Tensor) and self._block_size == src._block_size and self._orig_dtype == src._orig_dtype + and self._is_swizzled_scales == src._is_swizzled_scales and self._scale_e4m3.shape == src._scale_e4m3.shape and per_tensor_scale_equal - and self._data.shape == src._data.shape + and act_per_tensor_scale_equal + and self.qdata.shape == src.qdata.shape + and self.act_quant_kwargs == src.act_quant_kwargs ) @@ -278,7 +305,6 @@ def nvfp4_to_copy(func, types, args, kwargs): # Handle device parameter device = kwargs.pop("device", None) if device is not None: - # Apply device change using _apply_fn_to_data tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device)) tensor = return_and_correct_aliasing(func, args, {}, tensor) else: @@ -286,12 +312,15 @@ def nvfp4_to_copy(func, types, args, kwargs): if dtype is not None: res = NVFP4Tensor( + tensor.qdata, tensor._scale_e4m3, - tensor._per_tensor_scale, - tensor._data, tensor._block_size, dtype, - tensor.mm_config, + tensor._per_tensor_scale, + tensor._act_per_tensor_scale, + tensor._is_swizzled_scales, + tensor.use_triton_kernel, + tensor.act_quant_kwargs, ) return res @@ -332,78 +361,206 @@ def nvfp4_slice(func, types, args, kwargs): if step != 1: raise ValueError("Only support aten.slice with step=1") - assert x._data.is_contiguous(), "Only support contiguous data for now" + assert x.qdata.is_contiguous(), "Only support contiguous data for now" M, K = x.shape[0], x.shape[1] - scale_shaped = x._scale_e4m3.view(M, K // x._block_size) - - if dim == 0: - # Slicing along the first dimension (rows) - sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step).flatten() - sliced_data = aten.slice.Tensor(x._data, dim, start, end, step) - elif dim == 1: - # Slicing along reduction dim - must align with block boundaries - if start is not None: - assert start % x._block_size == 0, ( - f"Start index {start} must be a multiple of block_size {x._block_size}" - ) - if end is not None: - assert end % x._block_size == 0, ( - f"End index {end} must be a multiple of block_size {x._block_size}" + if x._is_swizzled_scales: + scale_rows = M + scale_cols = K // x._block_size + n_row_blocks = ceil_div(scale_rows, 128) + n_col_blocks = ceil_div(scale_cols, 4) + elements_per_block = 32 * 16 # 512 elements + + if dim == 0: + # Row slicing + # Handle sys.maxsize (default slice end) + if end == sys.maxsize: + end = M + + # Check if start/end align with 128-row boundaries + if start is not None and start % 128 != 0: + raise RuntimeError( + f"Row slicing of NVFP4Tensor with swizzled scales requires " + f"start index to be a multiple of 128, got {start}" + ) + if end is not None and end != M and end % 128 != 0: + raise RuntimeError( + f"Row slicing of NVFP4Tensor with swizzled scales requires " + f"end index to be a multiple of 128 or equal to tensor size {M}, got {end}" + ) + + # Calculate which row blocks to keep + start_block = 0 if start is None else start // 128 + end_block = n_row_blocks if end is None or end >= M else end // 128 + + # The swizzled tensor has shape (n_row_blocks * n_col_blocks * 32 * 16,) + blocks_per_row = n_col_blocks + start_idx = start_block * blocks_per_row * elements_per_block + end_idx = ( + end_block * blocks_per_row * elements_per_block + if end_block < n_row_blocks + else None ) - sliced_data = aten.slice.Tensor(x._data, dim, start, end, step) + sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1) + sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step) + + elif dim == 1: + # Column slicing + # Handle sys.maxsize (default slice end) + if end == sys.maxsize: + end = K + + # Check if start/end align with 64-column boundaries (4 scale columns * 16 block_size) + if start is not None and start % 64 != 0: + raise RuntimeError( + f"Column slicing of NVFP4Tensor with swizzled scales requires " + f"start index to be a multiple of 64, got {start}" + ) + if end is not None and end != K and end % 64 != 0: + raise RuntimeError( + f"Column slicing of NVFP4Tensor with swizzled scales requires " + f"end index to be a multiple of 64 or equal to tensor size {K}, got {end}" + ) + + # Also check FP4 packing alignment + if start is not None and start % 2 != 0: + raise RuntimeError(f"Start index {start} must be even for FP4 packing") + if end is not None and end != K and end % 2 != 0: + raise RuntimeError(f"End index {end} must be even for FP4 packing") + + # Calculate which column blocks to keep + start_scale_col = 0 if start is None else start // 16 + end_scale_col = scale_cols if end is None or end >= K else end // 16 + + start_col_block = start_scale_col // 4 + end_col_block = end_scale_col // 4 + + # Verify the end aligns with block boundary + if end_scale_col % 4 != 0: + raise RuntimeError( + f"Column slicing end index {end} does not align with scale block boundaries. " + f"End must result in a multiple of 4 scale columns (64 data columns)." + ) + + if start_col_block == 0 and end_col_block == n_col_blocks: + # Full width - no slicing needed + sliced_scale = x._scale_e4m3 + else: + # Extract specific column blocks from each row block + # Each row block in swizzled format contains n_col_blocks chunks of (32, 16) + elements_per_row_block = n_col_blocks * elements_per_block + + # Build list of slices to extract + slices_to_extract = [] + for row_block in range(n_row_blocks): + row_start = row_block * elements_per_row_block + col_start = row_start + start_col_block * elements_per_block + col_end = row_start + end_col_block * elements_per_block + slices_to_extract.append(x._scale_e4m3[col_start:col_end]) + + # Concatenate all the slices + sliced_scale = torch.cat(slices_to_extract, dim=0) + + # Slice the data tensor + packed_start = None if start is None else start // 2 + packed_end = None if end is None else end // 2 + sliced_data = aten.slice.Tensor( + x.qdata, dim, packed_start, packed_end, step + ) - # Calculate which scale blocks to keep - start_block = 0 if start is None else start // x._block_size - end_block = None if end is None else end // x._block_size + else: + raise ValueError( + f"NVFP4Tensor only supports slicing along dimensions 0 and 1, got dim={dim}" + ) - # Slice the scale tensor accordingly - sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block, step) else: - raise ValueError( - f"NVFP4Tensor only supports slicing along dimensions 0 and 1, got dim={dim}" - ) + scale_shaped = x._scale_e4m3.view(M, K // x._block_size) + + if dim == 0: + sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) + sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) + + elif dim == 1: + if start is not None: + assert start % x._block_size == 0, ( + f"Start index {start} must be a multiple of block_size {x._block_size}" + ) + assert start % 2 == 0, ( + f"Start index {start} must be even for FP4 packing" + ) + + if end is not None and end != sys.maxsize: + assert end % x._block_size == 0, ( + f"End index {end} must be a multiple of block_size {x._block_size}" + ) + assert end % 2 == 0, f"End index {end} must be even for FP4 packing" + + packed_start = None if start is None else start // 2 + packed_end = None if end is None else end // 2 + sliced_data = aten.slice.Tensor( + x.qdata, dim, packed_start, packed_end, step + ) - return NVFP4Tensor( - sliced_scale, - x._per_tensor_scale, # Unchanged per-tensor scale + start_block = 0 if start is None else start // x._block_size + end_block = None if end is None else end // x._block_size + sliced_scale = aten.slice.Tensor( + scale_shaped, 1, start_block, end_block, step + ) + + sliced_scale = sliced_scale.flatten() + + # Create result tensor + result = NVFP4Tensor( sliced_data, + sliced_scale, x._block_size, x._orig_dtype, - x.mm_config, + x._per_tensor_scale, + x._act_per_tensor_scale, + x._is_swizzled_scales, + x.use_triton_kernel, + x.act_quant_kwargs, ) + return return_and_correct_aliasing(func, args, kwargs, result) + @implements([aten.t.default]) def nvfp4_t(func, types, args, kwargs): # For now, only transpose(input, 0, 1) is supported. old = args[0] new = NVFP4Tensor( + old.qdata.t(), old._scale_e4m3, - old._per_tensor_scale, - old._data.t(), old._block_size, old._orig_dtype, - old.mm_config, + old._per_tensor_scale, + old._act_per_tensor_scale, + old._is_swizzled_scales, + old.use_triton_kernel, + old.act_quant_kwargs, ) return new @implements([aten.view.default]) def nvfp4_view_op(func, types, args, kwargs): - data = args[0]._data + data = args[0].qdata new_size = args[1] new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) new_data = func(data, new_size, *args[2:], **kwargs) return NVFP4Tensor( - args[0]._scale_e4m3, - args[0]._per_tensor_scale, new_data, + args[0]._scale_e4m3, args[0]._block_size, args[0]._orig_dtype, - args[0].mm_config, + args[0]._per_tensor_scale, + args[0]._act_per_tensor_scale, + args[0]._is_swizzled_scales, + args[0].use_triton_kernel, + args[0].act_quant_kwargs, ) @@ -414,8 +571,8 @@ def _addmm_nvfp4_dispatch( Core implementation shared between nvfp4_mm, nvfp4_addmm, and nvfp4_linear. The only difference is whether bias is None or not. """ - assert a._data.is_contiguous() - assert b._data.t().is_contiguous() + assert a.qdata.is_contiguous() + assert b.qdata.t().is_contiguous() assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}" assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}" @@ -423,10 +580,17 @@ def _addmm_nvfp4_dispatch( N = b.shape[1] # Swizzle Dizzle - a_scale = a._scale_e4m3.view(M, K // a._block_size) - b_scale = b._scale_e4m3.view(N, K // b._block_size) - a_scale_blocked = to_blocked(a_scale) - b_scale_blocked = to_blocked(b_scale) + if a._is_swizzled_scales: + a_scale_blocked = a._scale_e4m3 # Already swizzled + else: + a_scale = a._scale_e4m3.view(M, K // a._block_size) + a_scale_blocked = to_blocked(a_scale) + + if b._is_swizzled_scales: + b_scale_blocked = b._scale_e4m3 # Already swizzled + else: + b_scale = b._scale_e4m3.view(N, K // b._block_size) + b_scale_blocked = to_blocked(b_scale) # Merge double quant scales into 1 scale for Scale_In^D if a._per_tensor_scale is not None: @@ -441,10 +605,11 @@ def _addmm_nvfp4_dispatch( # When we have per-tensor scaling, we need to apply it before bias # since bias is not quantized should_add_bias_separately = (scale_result is not None) and (bias is not None) + # should_add_bias_separately = bias is not None result = torch._scaled_mm( - a._data.view(torch.float4_e2m1fn_x2), - b._data.view(torch.float4_e2m1fn_x2), + a.qdata.view(torch.float4_e2m1fn_x2), + b.qdata.view(torch.float4_e2m1fn_x2), a_scale_blocked.view(torch.float8_e4m3fn), b_scale_blocked.view(torch.float8_e4m3fn), bias=None if should_add_bias_separately else bias, @@ -473,14 +638,21 @@ def nvfp4_linear(func, types, args, kwargs): if not isinstance(weight_tensor, NVFP4Tensor): raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor") - config = weight_tensor.mm_config - - if config == NVFP4MMConfig.WEIGHT_ONLY: + if weight_tensor.act_quant_kwargs is None: + # weight_only quant weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype) return torch.nn.functional.linear(input_tensor, weight_dequant, bias) else: - input_quant = NVFP4Tensor.to_nvfp4(input_tensor, mm_config=config) - return _addmm_nvfp4_dispatch(input_quant, weight_tensor, func, bias=bias) + # dynamic quant + k = weight_tensor.act_quant_kwargs + input_tensor = NVFP4Tensor.to_nvfp4( + input_tensor, + block_size=k.block_size, + per_tensor_scale=weight_tensor._act_per_tensor_scale, + is_swizzled_scales=k.is_swizzled_scales, + use_triton_kernel=k.use_triton_kernel, + ) + return _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias) @implements([aten.mm.default, aten.matmul.default]) @@ -490,9 +662,7 @@ def nvfp4_mm(func, types, args, kwargs): if not isinstance(weight_tensor, NVFP4Tensor): raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor") - config = weight_tensor.mm_config - - if config == NVFP4MMConfig.WEIGHT_ONLY: + if weight_tensor.act_quant_kwargs is None: weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype) if isinstance(input_tensor, NVFP4Tensor): input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype) @@ -501,7 +671,14 @@ def nvfp4_mm(func, types, args, kwargs): return func(input_tensor, weight_dequant) else: if not isinstance(input_tensor, NVFP4Tensor): - input_tensor = NVFP4Tensor.to_nvfp4(input_tensor, mm_config=config) + k = weight_tensor.act_quant_kwargs + input_tensor = NVFP4Tensor.to_nvfp4( + input_tensor, + block_size=k.block_size, + per_tensor_scale=weight_tensor._act_per_tensor_scale, + is_swizzled_scales=k.is_swizzled_scales, + use_triton_kernel=k.use_triton_kernel, + ) return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func) @@ -512,9 +689,7 @@ def nvfp4_addmm(func, types, args, kwargs): if not isinstance(weight_tensor, NVFP4Tensor): raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor") - config = weight_tensor.mm_config - - if config == NVFP4MMConfig.WEIGHT_ONLY: + if weight_tensor.act_quant_kwargs is None: weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype) if isinstance(input_tensor, NVFP4Tensor): input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype) @@ -523,23 +698,31 @@ def nvfp4_addmm(func, types, args, kwargs): return torch.addmm(bias, input_tensor, weight_dequant) else: if not isinstance(input_tensor, NVFP4Tensor): - input_tensor = NVFP4Tensor.to_nvfp4(input_tensor, mm_config=config) + k = weight_tensor.act_quant_kwargs + input_tensor = NVFP4Tensor.to_nvfp4( + input_tensor, + block_size=k.block_size, + per_tensor_scale=weight_tensor._act_per_tensor_scale, + is_swizzled_scales=k.is_swizzled_scales, + use_triton_kernel=k.use_triton_kernel, + ) return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func, bias=bias) def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor: - """Convert per-tensor amax to per-tensor scale. - Used to scale fp32 scales down to fp8 scales + """Convert per-tensor amax to per-tensor scale for NVFP4 quantization. + + Divides by both F8E4M3_MAX and F4_E2M1_MAX to ensure block scales can utilize + the full FP8 E4M3 range (up to 448) when block_max equals tensor_max. + Without F4_E2M1_MAX, the maximum scale would only reach FP8_MAX / FP4_MAX. Args: - amax: Per-tensor amax tensor + amax: Per-tensor absolute maximum value from calibration Returns: - torch.Tensor: Per-tensor scale tensor + torch.Tensor: Per-tensor scale for two-level NVFP4 scaling """ - return torch.clamp(amax / F8E4M3_MAX, min=E4M3_EPS, max=F8E4M3_MAX).to( - torch.float32 - ) + return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX) def nvfp4_quantize( @@ -568,15 +751,40 @@ def nvfp4_quantize( AssertionError: If input dtype is not supported, tensor size is not divisible by block_size, tensor is not contiguous, or block_size != 16 """ + return _nvfp4_quantize(data_hp, block_size, per_tensor_scale) + + +class _Float8Round(torch.autograd.Function): + """ + Cast a tensor to float8 and back to float32 with backward STE. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor) -> torch.Tensor: + return x.to(torch.float8_e4m3fn).to(torch.float32) + + @staticmethod + def backward(ctx, gy: torch.Tensor) -> torch.Tensor: + return gy + + +def _nvfp4_quantize( + data_hp: torch.Tensor, + block_size: int = 16, + per_tensor_scale: Optional[torch.Tensor] = None, + skip_dtype_cast_and_packing: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: assert data_hp.dtype in (torch.bfloat16, torch.float), ( f"{data_hp.dtype} not supported" ) - assert data_hp.numel() % block_size == 0, "unsupported" - assert data_hp.is_contiguous(), "unsupported" + assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size" + assert data_hp.is_contiguous(), "Only support contiguous data for now" assert block_size == 16, "NVFP4 requires block_size=16" + orig_dtype = data_hp.dtype orig_shape = data_hp.shape - data_hp = data_hp.reshape(orig_shape[0], -1, block_size) + # Convert to float32 early for consistent precision with Triton implementation + data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size) max_abs = torch.amax(torch.abs(data_hp), dim=-1) # These scales are currently in fp32, we are going to `quantize` them to e4m3 @@ -585,10 +793,8 @@ def nvfp4_quantize( out_scales = None if per_tensor_scale is None: # We are doing single level scaling - block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to( - torch.float8_e4m3fn - ) - block_scale_fp32 = block_scale_fp8.to(torch.float32) + block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX) + block_scale_fp32 = _Float8Round.apply(block_scale_fp8) data_scaled = data_hp / block_scale_fp32.unsqueeze(-1) out_scales = block_scale_fp8 else: @@ -600,8 +806,8 @@ def nvfp4_quantize( scaled_block_scales = block_scale_fp32 / per_tensor_scale scaled_block_scales_fp8 = torch.clamp( scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX - ).to(torch.float8_e4m3fn) - scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32) + ) + scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8) # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale # To apply to data total_scale = per_tensor_scale * scaled_block_scales_fp32 @@ -610,8 +816,11 @@ def nvfp4_quantize( data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX) data_scaled = data_scaled.view(orig_shape) - data_lp = f32_to_f4_unpacked(data_scaled.float()) - # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2' - # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2) - data_lp = pack_uint4(data_lp) - return out_scales, data_lp + if skip_dtype_cast_and_packing: + return out_scales.to(torch.float32), data_scaled.to(orig_dtype) + else: + data_lp = f32_to_f4_unpacked(data_scaled) + # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2' + # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2) + data_lp = pack_uint4(data_lp) + return out_scales.to(torch.float8_e4m3fn), data_lp diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index e4777d3899..2802888980 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -5,8 +5,18 @@ # LICENSE file in the root directory of this source tree. import torch - -from torchao.prototype.mx_formats.kernels import triton_mx_block_rearrange +from torch.distributed._tensor import DTensor + +from torchao.prototype.mx_formats.config import ( + MXFP8Dim1CastKernelChoice, + ScaleCalculationMode, +) +from torchao.prototype.mx_formats.kernels import ( + mxfp8_quantize_cuda, + triton_mx_block_rearrange, + triton_to_mxfp8_dim1, +) +from torchao.prototype.mx_formats.mx_tensor import MXTensor Tensor = torch.Tensor @@ -15,7 +25,7 @@ def ceil_div(a, b): return (a + b - 1) // b -def to_blocked(input_matrix, use_triton_kernel: bool = True) -> Tensor: +def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor: """ Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. @@ -58,6 +68,38 @@ def to_blocked(input_matrix, use_triton_kernel: bool = True) -> Tensor: return rearranged.flatten() +def from_blocked( + blocked_tensor: Tensor, original_rows: int, original_cols: int +) -> Tensor: + """ + Inverse of to_blocked: convert from blocked layout back to regular row-major layout. + + Args: + blocked_tensor: Flattened blocked tensor from to_blocked() + original_rows: Original number of rows before blocking + original_cols: Original number of columns before blocking + + Returns: + Tensor of shape (original_rows, original_cols) in regular layout + """ + n_row_blocks = ceil_div(original_rows, 128) + n_col_blocks = ceil_div(original_cols, 4) + + rearranged = blocked_tensor.view(n_row_blocks * n_col_blocks, 32, 16) + + temp = rearranged.reshape(n_row_blocks * n_col_blocks, 32, 4, 4) + + temp = temp.transpose(1, 2) + + blocks = temp.reshape(n_row_blocks, n_col_blocks, 128, 4) + + padded_view = blocks.permute(0, 2, 1, 3) + + padded = padded_view.reshape(n_row_blocks * 128, n_col_blocks * 4) + + return padded[:original_rows, :original_cols] + + def _to_blocked_single(scales: Tensor) -> Tensor: """Assume that we have a 128x4 block of scales in K Major order @@ -67,3 +109,65 @@ def _to_blocked_single(scales: Tensor) -> Tensor: assert scales.shape == (128, 4) scales_tiled = scales.view(4, 32, 4) # view as 4 - (32, 4) tiles return scales_tiled.transpose(0, 1).reshape(32, 16) # Interleave tiles + + +def _to_mxfp8_dim1_kernel_wrapper( + a, + block_size, + elem_dtype, + hp_dtype, + gemm_kernel_choice, + cast_kernel_choice, + scale_calculation_mode: ScaleCalculationMode, +): + if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON: + assert scale_calculation_mode == ScaleCalculationMode.FLOOR + a_data, a_scale = triton_to_mxfp8_dim1(a, block_size) + elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA: + assert scale_calculation_mode in ( + ScaleCalculationMode.FLOOR, + ScaleCalculationMode.RCEIL, + ) + _, a_data, _, a_scale = mxfp8_quantize_cuda( + a, + rowwise=False, + colwise=True, + scaling_mode=scale_calculation_mode.value, + ) + else: + raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}") + + if isinstance(a_data, DTensor): + assert isinstance(a_scale, DTensor) + a_data_local = a_data.to_local() + a_scale_local = a_scale.to_local() + inner = MXTensor( + a_data_local.t(), + a_scale_local, + elem_dtype, + block_size, + hp_dtype, + gemm_kernel_choice, + False, + None, + ) + mx_tensor = DTensor.from_local( + inner, + a_data.device_mesh, + a_data.placements, + run_check=False, + shape=a_data.t().size(), + stride=a_data.t().stride(), + ) + else: + mx_tensor = MXTensor( + a_data.t(), + a_scale, + elem_dtype, + block_size, + hp_dtype, + gemm_kernel_choice, + False, + None, + ) + return mx_tensor diff --git a/torchao/prototype/parq/README.md b/torchao/prototype/parq/README.md index 045f4fa59d..d5f02ded84 100644 --- a/torchao/prototype/parq/README.md +++ b/torchao/prototype/parq/README.md @@ -48,17 +48,14 @@ optimizer = QuantOptimizer( ```python -from torchao.quantization import quantize_ -from torchao.quantization.qat import ( - FakeQuantizeConfig, - intx_quantization_aware_training, +from torchao.quantization import ( + quantize_, + Int8DynamicActivationInt4WeightConfig, ) +from torchao.quantization.qat import QATConfig -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) -quantize_( - model, - intx_quantization_aware_training(weight_config=weight_config), -) +base_config = Int4WeightOnlyConfig(group_size=32) +quantize_(model, QATConfig(base_config, step="prepare")) ``` @@ -68,13 +65,7 @@ quantize_( ```python -from torchao.quantization import IntxWeightOnlyConfig, quantize_ - -config = IntxWeightOnlyConfig( - weight_dtype=torch.int4, granularity=PerGroup(32) -) -optimizer.restore_latent_params() -quantize_(model, config, filter_fn=optimizer.get_filter_fn(model)) +optimizer.torchao_convert(model, weight_only=True) ``` @@ -82,9 +73,9 @@ quantize_(model, config, filter_fn=optimizer.get_filter_fn(model)) ```python from torchao.quantization import quantize_ -from torchao.quantization.qat import from_intx_quantization_aware_training +from torchao.quantization.qat import QATConfig -quantize_(model, from_intx_quantization_aware_training()) +quantize_(model, QATConfig(base_config, step="convert")) ``` @@ -93,6 +84,15 @@ quantize_(model, from_intx_quantization_aware_training()) Note that `UnifTorchaoQuantizer` calls the same quantization primitives as torchao to match the numerics (see [Affine Quantization Details](../../quantization#affine-quantization-details)). +To apply 8-bit dynamic activation quantization with PARQ, add the below to the prepare stage. +```python +from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig + +activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +quantize_(self.model, QATConfig(activation_config, step="prepare")) +``` +For the convert stage, call `optimizer.torchao_convert(model)`. The resulting quantized model corresponds to `Int8DynamicActivationInt4WeightConfig` in torchao. + ## QAT arguments | | description | choices | diff --git a/torchao/prototype/parq/__init__.py b/torchao/prototype/parq/__init__.py index d254b1395a..7695c6b147 100644 --- a/torchao/prototype/parq/__init__.py +++ b/torchao/prototype/parq/__init__.py @@ -20,3 +20,8 @@ UnifQuantizer, UnifTorchaoQuantizer, ) +from .quant.config_torchao import StretchedIntxWeightConfig + +__all__ = [ + "StretchedIntxWeightConfig", +] diff --git a/torchao/prototype/parq/optim/quantopt.py b/torchao/prototype/parq/optim/quantopt.py index 2cdd34536d..bfa651dcc9 100644 --- a/torchao/prototype/parq/optim/quantopt.py +++ b/torchao/prototype/parq/optim/quantopt.py @@ -7,13 +7,21 @@ from collections import defaultdict from collections.abc import Callable from functools import partial -from typing import Any, Optional +from typing import Any, Generator, Optional import torch -from torch import Tensor +from torch import Tensor, nn from torch.optim import Optimizer -from ..quant import Quantizer +from torchao.quantization import quantize_ +from torchao.quantization.quant_api import _is_linear + +from ..quant import Quantizer, UnifTorchaoQuantizer +from ..quant.config_torchao import ( + _attach_hf_quantization_config, + _get_config_from_quantizer, + _is_hf_model, +) from ..utils import HAS_DTENSOR, is_dtensor from .proxmap import ProxMap @@ -91,6 +99,23 @@ def __repr__(self) -> str: def state(self) -> defaultdict[Tensor, Any]: # pyre-ignore[3] return self._state if hasattr(self, "_state") else self.base_optimizer.state + @property + def num_steps(self) -> int: + for group in self.regularized_param_groups(): + return group.setdefault("num_steps", 0) + + @num_steps.setter + def num_steps(self, value: int) -> None: + for group in self.regularized_param_groups(): + group["num_steps"] = value + return + + @num_steps.deleter + def num_steps(self) -> None: + for group in self.regularized_param_groups(): + group.pop("num_steps", None) + return + @staticmethod def quantize_( p: Tensor, @@ -106,52 +131,92 @@ def quantize_( quants.copy_(Q) return q - def regularized_param_groups(self): # pyre-ignore[3] + def regularized_param_groups(self) -> Generator[dict[str, Any], None, None]: """Yield parameter groups that need to be quantized.""" for group in self.param_groups: if group.get("quant_bits", 16) < 16: yield group - @property - def _param_set(self) -> set[int]: - return { - p.data_ptr() - for group in self.regularized_param_groups() - for p in group["params"] - } - - def get_filter_fn( - self, module: torch.nn.Module - ) -> Callable[[torch.nn.Module], bool]: - param_set = self._param_set - - def _filter_fn(module: torch.nn.Module, *args) -> bool: + def _param_sets(self) -> Generator[set[int], None, None]: + for group in self.regularized_param_groups(): + yield {p.data_ptr() for p in group["params"]} + + def get_filter_fns( + self, module: nn.Module + ) -> Generator[Callable[[nn.Module], bool], None, None]: + def _filter_fn(module: nn.Module, *args, param_set) -> bool: for p in module.parameters(recurse=False): if p.data_ptr() in param_set: return True return False - return _filter_fn + for param_set in self._param_sets(): + yield partial(_filter_fn, param_set=param_set) + + def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None: + """Converts model parameters to torchao quantized tensor subclasses.""" + model.eval() + self.restore_latent_params() + + # TODO(lvj): find more robust way to identify embedding layers + embed_data_ptrs = set() + linear_data_ptrs = set() + for module in model.modules(): + if isinstance(module, nn.Embedding): + embed_data_ptrs.add(module.weight.data_ptr()) + elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs: + linear_data_ptrs.add(module.weight.data_ptr()) + + filter_fns = [] + configs = [] + attach_hf_config = _is_hf_model(model) + all_linear_layers_idx = -1 + for i, (group, filter_fn) in enumerate( + zip(self.regularized_param_groups(), self.get_filter_fns(model)) + ): + filter_fns.append(filter_fn) + quantizer = group.get("quantizer", self.quantizer) + if not isinstance(quantizer, UnifTorchaoQuantizer) or not group["params"]: + configs.append(None) + continue + + if set((p.data_ptr() for p in group["params"])) == linear_data_ptrs: + all_linear_layers_idx = i + + device = group["params"][0].device + any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"]) + config = _get_config_from_quantizer( + quantizer, + weight_only or any_embed, + device, + group["quant_bits"], + group.get("quant_block_size"), + ) + configs.append(config) + + filter_fns_orig = filter_fns[:] + configs_orig = configs[:] + + # If one group has all the linear layers, then set its config as default + if all_linear_layers_idx > -1: + module_to_config = {"_default": configs[all_linear_layers_idx]} + del filter_fns[all_linear_layers_idx] + del configs[all_linear_layers_idx] + else: + module_to_config = None + + if attach_hf_config: + _attach_hf_quantization_config(model, filter_fns, configs, module_to_config) + + for config, filter_fn in zip(configs_orig, filter_fns_orig): + quantize_(model, config, filter_fn=filter_fn) @torch._disable_dynamo def state_dict(self) -> dict[str, Any]: - state_dict = self.base_optimizer.state_dict() - state_dict["qat_state"] = {"num_steps": self.num_steps} - # quantizer and prox_map may also need to save states, can add here - return state_dict + return self.base_optimizer.state_dict() @torch._disable_dynamo - def load_state_dict( - self, state_dict: dict[str, Any], start_step: Optional[int] = None - ) -> None: - qat_state = state_dict.get("qat_state") - # resume from check points usually not corresponds to saved num_steps - # so allow explicit start_step computed from epochs * steps_per_epoc - if start_step is not None: - self.num_steps = start_step - elif qat_state is not None: - # hope discrepancy in num_steps does not cause major problem! - self.num_steps = qat_state["num_steps"] + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.base_optimizer.load_state_dict(state_dict) @torch.no_grad() @@ -191,6 +256,10 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] quant_update = False for group in self.regularized_param_groups(): + # Override quantizer if specified in the group + quantizer = group.get("quantizer", self.quantizer) + assert isinstance(quantizer, Quantizer), f"Invalid {quantizer=}" + # AProx in practice: ensure shrinkage coefficient >= 1 group["cumu_lr"] += group["lr"] gamma = max(1.0, group["cumu_lr"]) @@ -224,7 +293,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] # update quantization targets periodically per_channel = self.quant_per_channel and p.dim() > 1 if quant_update: - quant_size = self.quantizer.get_quant_size(b) + quant_size = quantizer.get_quant_size(b) if per_channel: quant_size = (p.size(0), quant_size) @@ -242,9 +311,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] q = None if quant_update: - qfunc = partial( - self.quantize_, quantizer=self.quantizer, b=b, dim=dim - ) + qfunc = partial(self.quantize_, quantizer=quantizer, b=b, dim=dim) if is_dtensor(p): qfunc = local_map( qfunc, diff --git a/torchao/prototype/parq/quant/__init__.py b/torchao/prototype/parq/quant/__init__.py index c8b8365725..4542554298 100644 --- a/torchao/prototype/parq/quant/__init__.py +++ b/torchao/prototype/parq/quant/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from .config_torchao import StretchedIntxWeightConfig # noqa: F401 from .lsbq import LSBQuantizer # noqa: F401 from .quantizer import Quantizer # noqa: F401 from .uniform import ( # noqa: F401 @@ -13,5 +14,6 @@ ) from .uniform_torchao import ( # noqa: F401 Int4UnifTorchaoQuantizer, + StretchedUnifTorchaoQuantizer, UnifTorchaoQuantizer, ) diff --git a/torchao/prototype/parq/quant/config_torchao.py b/torchao/prototype/parq/quant/config_torchao.py new file mode 100644 index 0000000000..f327042b6b --- /dev/null +++ b/torchao/prototype/parq/quant/config_torchao.py @@ -0,0 +1,209 @@ +import types +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +from torch import nn + +from torchao.core.config import AOBaseConfig +from torchao.dtypes import Int4CPULayout, Layout, QDQLayout +from torchao.quantization import MappingType, PerAxis, PerGroup +from torchao.quantization.quant_api import ( + Granularity, + Int4WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + ModuleFqnToConfig, + _linear_extra_repr, +) +from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor +from torchao.quantization.transform_module import register_quantize_module_handler +from torchao.utils import check_cpu_version + +from .quant_api import choose_qparams_stretched_affine, quantize_stretched_affine +from .uniform_torchao import ( + _BIT_WIDTH_TO_DTYPE, + Int4UnifTorchaoQuantizer, + StretchedUnifTorchaoQuantizer, +) + +try: + from transformers import PretrainedConfig, TorchAoConfig + + TRANSFORMERS_AVAIL = True +except ImportError: + TRANSFORMERS_AVAIL = False + + +@dataclass +class StretchedIntxWeightConfig(AOBaseConfig): + granularity: Granularity = PerAxis(0) + scale_dtype: Optional[torch.dtype] = None + layout: Layout = QDQLayout() + version: int = 2 + b: Optional[int] = None + quant_min: Optional[int] = None + quant_max: Optional[int] = None + activation_quantization: Optional[str] = "int8_asym_per_token" + + +@register_quantize_module_handler(StretchedIntxWeightConfig) +def _int8_dynamic_activation_stretched_intx_transform( + module: nn.Module, config: StretchedIntxWeightConfig +) -> nn.Module: + weight = module.weight + granularity = config.granularity + mapping_type = MappingType.ASYMMETRIC + + if config.version != 2: + raise NotImplementedError(f"Unsupported {config.version=}") + + assert weight.dim() == 2, ( + f"StretchedIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" + ) + if isinstance(granularity, PerGroup): + group_size = granularity.group_size + elif isinstance(granularity, PerAxis): + assert granularity.axis == 0, ( + f"axis must be 0 with PerAxis, but got {granularity.axis}" + ) + group_size = weight.shape[-1] + else: + raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}") + + block_size = (1, group_size) + target_dtype = torch.int8 + q_args = (weight, mapping_type, block_size, target_dtype, config.b) + scale, zero_point = choose_qparams_stretched_affine( + *q_args, + quant_min=config.quant_min, + quant_max=config.quant_max, + ) + qdata = quantize_stretched_affine( + weight, + block_size, + scale, + zero_point, + target_dtype, + quant_min=config.quant_min, + quant_max=config.quant_max, + ) + n_blocks = [qdata.shape[i] // block_size[i] for i in range(len(block_size))] + scale = scale.reshape(*n_blocks) + zero_point = zero_point.reshape(*n_blocks) + + weight = IntxUnpackedToInt8Tensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=getattr(torch, f"int{config.b}"), + block_size=block_size, + dtype=weight.dtype, + activation_quantization=config.activation_quantization, + ) + module.weight = nn.Parameter(weight, requires_grad=False) + + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + + return module + + +def _get_config_from_quantizer( + quantizer, + weight_only: bool, + device: torch.device, + b: int, + block_size: Optional[int], + version: int = 2, +) -> AOBaseConfig: + granularity = PerGroup(block_size) if block_size is not None else PerAxis(0) + weight_dtype = _BIT_WIDTH_TO_DTYPE[b] + if isinstance(quantizer, Int4UnifTorchaoQuantizer): + config = Int4WeightOnlyConfig( + group_size=block_size, + version=version, + ) + if check_cpu_version(device): + config.layout = Int4CPULayout() + config.version = 1 + elif isinstance(quantizer, StretchedUnifTorchaoQuantizer): + config = StretchedIntxWeightConfig( + b=b, + quant_min=quantizer.quant_min, + quant_max=quantizer.quant_max, + granularity=granularity, + version=version, + ) + if weight_only: + config.activation_quantization = None + elif weight_only: + config = IntxWeightOnlyConfig( + weight_dtype=weight_dtype, + granularity=granularity, + mapping_type=quantizer.mapping_type, + version=version, + ) + else: + config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=granularity, + weight_mapping_type=quantizer.mapping_type, + act_mapping_type=MappingType.ASYMMETRIC, + version=version, + ) + return config + + +def _is_hf_model(model: nn.Module) -> bool: + return TRANSFORMERS_AVAIL and isinstance( + getattr(model, "config", None), PretrainedConfig + ) + + +def _attach_hf_quantization_config( + model: nn.Module, + filter_fns: list[Callable[nn.Module, bool]], + configs: list[AOBaseConfig], + module_to_config: Optional[dict[str, AOBaseConfig]] = None, +) -> None: + """Attaches torchao quantization config(s) to Hugging Face model. + + Args: + model: nn.Module - Hugging Face model. + filter_fns: list[Callable[nn.Module, bool]] - Callables that correspond + to `configs`. Each `filter_fns[i]` returns whether the input module + should be quantized with `configs[i]`. A module can map to at most + one config. + configs: list[AOBaseConfig] - torchao quantization configs inferred by + `QuantOptimizer`. Each config corresponds to a param group returned + by `optimizer.regularized_param_groups()`. + """ + assert _is_hf_model(model), "model is not a Hugging Face model" + assert len(filter_fns) == len(configs), ( + "filter_fns and configs must have the same length" + ) + + if module_to_config is None: + module_to_config = {} + + tied_weights_keys = set(getattr(model, "_tied_weights_keys", [])) + modules_to_not_convert = [] + for name, module in model.named_modules(): + if not hasattr(module, "weight"): + continue + + # Do not quantize pointers to tied weights or normalization layers + if f"{name}.weight" in tied_weights_keys or "norm" in name: + modules_to_not_convert.append(name) + continue + + for i, filter_fn in enumerate(filter_fns): + if filter_fn(module): + module_to_config[name] = configs[i] + + model.config.quantization_config = TorchAoConfig( + quant_type=ModuleFqnToConfig(module_to_config), + include_input_output_embeddings=True, + modules_to_not_convert=modules_to_not_convert, + ) diff --git a/torchao/prototype/parq/quant/lsbq.py b/torchao/prototype/parq/quant/lsbq.py index 2d9f4e4c1e..0154f3c543 100644 --- a/torchao/prototype/parq/quant/lsbq.py +++ b/torchao/prototype/parq/quant/lsbq.py @@ -70,7 +70,7 @@ def compute_v_per_channel(p: Tensor, dim: Optional[int] = None, ternary: bool = r = r.sub(v * binary_sign(r)) # compute least squares error, then select the `v` minimizes it - costs = r.norm(dim=dim) + costs = torch.linalg.vector_norm(r, dim=dim) indices = costs.argmin(dim=dim, keepdim=True) v_best = v_cands.gather(1, indices) return v_best @@ -196,10 +196,10 @@ def quantize_optimal_2bits( V1V2.append((v1, v2)) assert len(V1V2) > 0, "LSBQ 2-bit optimal: No solution found." # find the best solution with least-square quantization error - min_error = p.norm() + min_error = torch.linalg.vector_norm(p) for v1v2 in V1V2: r = binary_quant_residue(p, v1v2) - error = r.norm() + error = torch.linalg.vector_norm(r) if error < min_error: min_error = error q = p - r @@ -244,14 +244,14 @@ def quantize_optimal_ternary( v_feasible.append(v) assert len(v_feasible) > 0, "LSBQ ternary optimal: No solution found." # find the best solution with least-square quantization error - min_error = p.norm() + min_error = torch.linalg.vector_norm(p) q_best = torch.zeros_like(p) v_best = torch.zeros_like(v) for v in v_feasible: Q = v * torch.tensor([-1.0, 0.0, 1.0], device=p.device) boundaries = v * torch.tensor([-0.5, 0.5], device=p.device) q = Q[torch.bucketize(p, boundaries)] - error = torch.linalg.norm(p - q) + error = torch.linalg.vector_norm(p - q) if error < min_error: min_error = error q_best = q diff --git a/torchao/prototype/parq/quant/quant_api.py b/torchao/prototype/parq/quant/quant_api.py new file mode 100644 index 0000000000..608fd9570e --- /dev/null +++ b/torchao/prototype/parq/quant/quant_api.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Union + +import torch + +from torchao.quantization import ( + MappingType, +) +from torchao.quantization.quant_primitives import ( + _SUB_BYTE_UINT_BOUNDS, + _get_reduction_params, +) + + +def choose_qparams_stretched_affine( + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + b: int, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale_dtype is None: + scale_dtype = input_float.dtype + if eps is None: + eps = torch.finfo(input_float.dtype).eps + if zero_point_dtype is None: + zero_point_dtype = input_float.dtype + + assert len(block_size) == input_float.dim(), f"Got {input.dim()=}, {block_size=}" + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input_float.size() + ) + input_float = input_float.view(shape_for_reduction) + + q_abs = input_float.abs() + max_val = torch.minimum( + b * q_abs.mean(dim=reduction_dims, keepdim=True), + torch.amax(q_abs, dim=reduction_dims, keepdim=True), + ).clamp_(min=eps) + + scale = max_val / quant_max + scale = scale.to(dtype=scale_dtype, device=input_float.device) + zero_point = torch.full_like(scale, -0.5, dtype=zero_point_dtype) + return scale, zero_point + + +def quantize_stretched_affine( + input_float: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: torch.Tensor, + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, +) -> torch.Tensor: + if target_dtype in _SUB_BYTE_UINT_BOUNDS: + target_dtype = torch.uint8 + assert input_float.dtype in (torch.float32, torch.float16, torch.bfloat16), ( + f"Unsupported input_float dtype: {input_float.dtype}" + ) + assert len(block_size) == input_float.dim(), ( + f"Got {input_float.dim()=}, {block_size=}" + ) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input_float.size() + ) + original_shape = input_float.shape + input_float = input_float.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + + if zero_point is not None and zero_point.numel() > 0: + zero_point = zero_point.view(shape_after_reduction) + else: + zero_point = None + + max_val = scale.mul(quant_max) + input_float = input_float.clamp(min=-max_val, max=max_val) + with torch.no_grad(): + # difference from quantize_affine: add zero_point before rounding + quant = torch.round(input_float / scale + zero_point) + quant = quant.to(dtype=target_dtype).view(original_shape) + return quant diff --git a/torchao/prototype/parq/quant/uniform_torchao.py b/torchao/prototype/parq/quant/uniform_torchao.py index ebe4e775e6..ad58bf6592 100644 --- a/torchao/prototype/parq/quant/uniform_torchao.py +++ b/torchao/prototype/parq/quant/uniform_torchao.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import math +from functools import partial from typing import Optional, Union import torch @@ -25,6 +27,7 @@ quantize_affine, ) +from .quant_api import choose_qparams_stretched_affine, quantize_stretched_affine from .quantizer import Quantizer _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} @@ -56,17 +59,16 @@ def __init__( self._quantize = quantize_affine self._dequantize = dequantize_affine - if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: - self._choose_qparams = _choose_qparams_affine_tinygemm - self._quantize = _quantize_affine_tinygemm - self._dequantize = _dequantize_affine_tinygemm - elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: - self._choose_qparams = _choose_qparams_affine_dont_preserve_zero - self._quantize = quantize_affine - self._dequantize = dequantize_affine - elif zero_point_domain == ZeroPointDomain.NONE: + if zero_point_domain == ZeroPointDomain.NONE and not preserve_zero: self._quantize = _quantize_affine_no_zero_point self._dequantize = _dequantize_affine_no_zero_point + elif mapping_type == MappingType.ASYMMETRIC: + if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: + self._choose_qparams = _choose_qparams_affine_tinygemm + self._quantize = _quantize_affine_tinygemm + self._dequantize = _dequantize_affine_tinygemm + elif zero_point_domain == ZeroPointDomain.INT and not preserve_zero: + self._choose_qparams = _choose_qparams_affine_dont_preserve_zero def _init_quant_min_max(self, b: int) -> None: if self.quant_min is None or self.quant_max is None: @@ -113,9 +115,12 @@ def quantize( quant_max=self.quant_max, ) - Q = torch.arange( - self.quant_min, self.quant_max + 1, dtype=self.target_dtype, device=p.device - ) + Q = torch.arange(self.quant_min, self.quant_max + 1e-5, device=p.device) + + if isinstance(self.quant_min, float): + Q = Q.floor() + Q = Q.to(dtype=self.target_dtype) + if dim is not None: Q = Q.view(1, -1).expand(q.size(0), -1) block_size = (1, Q.size(-1)) @@ -133,6 +138,27 @@ def quantize( return q, Q +class StretchedUnifTorchaoQuantizer(UnifTorchaoQuantizer): + def __init__(self, b: int, int_shift: float = 0.5, **kwargs) -> None: + quant_absmax = 2 ** (b - 1) - int_shift + self.quant_min = -quant_absmax + self.quant_max = quant_absmax + self.int_shift = int_shift + + super().__init__( + mapping_type=MappingType.ASYMMETRIC, + quant_min=self.quant_min, + quant_max=self.quant_max, + **kwargs, + ) + + self._choose_qparams = partial(choose_qparams_stretched_affine, b=b) + self._quantize = quantize_stretched_affine + + def get_quant_size(self, b: int) -> int: + return math.floor(2**b - 2 * self.int_shift) + 1 + + class Int4UnifTorchaoQuantizer(UnifTorchaoQuantizer): """Based on torchao.quantization.quant_api._int4_weight_only_transform""" diff --git a/torchao/prototype/parq/utils.py b/torchao/prototype/parq/utils.py index ac5024fb5d..d4c0a603b6 100644 --- a/torchao/prototype/parq/utils.py +++ b/torchao/prototype/parq/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from importlib import import_module + import torch from torch import Tensor @@ -15,6 +17,10 @@ HAS_DTENSOR = False +def instantiate_module(module_path, module_suffix): + return getattr(import_module(module_path), module_suffix) + + def is_dtensor(x): return HAS_DTENSOR and isinstance(x, DTensor) diff --git a/torchao/prototype/qat/__init__.py b/torchao/prototype/qat/__init__.py new file mode 100644 index 0000000000..0727a1c673 --- /dev/null +++ b/torchao/prototype/qat/__init__.py @@ -0,0 +1,12 @@ +# Temporary location for prototype QAT features that will +# eventually live in torchao/quantization/qat + +from .nvfp4 import ( + NVFP4FakeQuantizeConfig, + NVFP4FakeQuantizer, +) + +__all__ = [ + "NVFP4FakeQuantizeConfig", + "NVFP4FakeQuantizer", +] diff --git a/torchao/prototype/qat/nvfp4.py b/torchao/prototype/qat/nvfp4.py new file mode 100644 index 0000000000..ed709dba1d --- /dev/null +++ b/torchao/prototype/qat/nvfp4.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass + +import torch + +from torchao.prototype.mx_formats.nvfp4_tensor import ( + _nvfp4_quantize, + per_tensor_amax_to_scale, +) +from torchao.quantization.qat import ( + FakeQuantizeConfigBase, + FakeQuantizerBase, +) + + +@dataclass +class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for fake quantizing weights or activations to NVIDIA's NVFP4 format + according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/. + + Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py. + + Args: + use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling + after the initial fp8 (e4m3) block-wise scaling (default True) + """ + + use_per_tensor_scale: bool = True + + +class NVFP4FakeQuantizer(FakeQuantizerBase): + """ + (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config. + """ + + def __init__(self, config: NVFP4FakeQuantizeConfig): + super().__init__() + torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer") + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + block_size = 16 + original_shape = x.shape + if x.dim() == 3: + x = x.view(-1, x.shape[-1]) + if self.config.use_per_tensor_scale: + tensor_amax = torch.max(torch.abs(x)) + per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) + else: + per_tensor_scale = None + + # quantize + scale, q = _nvfp4_quantize( + x, + block_size=block_size, + per_tensor_scale=per_tensor_scale, + skip_dtype_cast_and_packing=True, + ) + if self.config.use_per_tensor_scale: + scale = scale * per_tensor_scale + assert q.dtype == x.dtype + assert scale.dtype == torch.float32 + + # dequantize + M, K = q.shape[0], q.shape[1] + q = q.view(M, K // block_size, block_size) + scale = scale.view(M, K // block_size, 1) + dq = q * scale + return dq.view(original_shape).to(x.dtype) diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index 8966bd5226..1240bbacd0 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -45,10 +45,8 @@ Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) -from torchao.quantization.utils import quantize_activation_per_token_absmax +from torchao.quantization.utils import _quantize_activation_per_token_absmax from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, is_sm_at_least_89, is_sm_at_least_90, @@ -74,7 +72,7 @@ def _is_linear(mod, *args): # avoid circular dependencies from torchao.quantization.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, + _AffineFakeQuantizedTensor, ) # adding weight tensor subclass isinstance check to make sure the weight is only quantized once @@ -86,7 +84,7 @@ def _is_linear(mod, *args): and not isinstance(mod.weight, AutoQuantizableLinearWeightV1) and not isinstance(mod.weight, AffineQuantizedTensor) and not isinstance(mod.weight, LinearActivationQuantizedTensor) - and not isinstance(mod.weight, AffineFakeQuantizedTensor) + and not isinstance(mod.weight, _AffineFakeQuantizedTensor) and not isinstance(mod, torch.nn.modules.linear.NonDynamicallyQuantizableLinear) ) @@ -110,7 +108,7 @@ def _graph_equals(g1, g2): aten = torch.ops.aten -AUTOQUANT_CACHE = {} +_AUTOQUANT_CACHE = {} # This is a flag to control whether we do some rewrite for graph # to account for different batch sizes, it's a temporary solution for llama model @@ -119,15 +117,15 @@ def _graph_equals(g1, g2): def check_cache(gm, cls, shapes_and_dtype): - for gm_, cls_, shapes_and_dtype_ in AUTOQUANT_CACHE.keys(): + for gm_, cls_, shapes_and_dtype_ in _AUTOQUANT_CACHE.keys(): graph_equals = _graph_equals(gm_.graph, gm.graph) if graph_equals and cls_ is cls and shapes_and_dtype_ == shapes_and_dtype: - return AUTOQUANT_CACHE[(gm_, cls_, shapes_and_dtype_)] + return _AUTOQUANT_CACHE[(gm_, cls_, shapes_and_dtype_)] return None def update_cache(gm, cls, shapes_and_dtype, res): - AUTOQUANT_CACHE[(gm, cls, shapes_and_dtype)] = res + _AUTOQUANT_CACHE[(gm, cls, shapes_and_dtype)] = res # adjust each input's bsz to target_bsz @@ -469,6 +467,8 @@ def do_autoquant_bench(op, *args, **kwargs): """ runs benchmark op(*args, **kwargs) avoiding torch.compile overhead """ + from torch._inductor.runtime.benchmarking import benchmarker + rep = kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -483,24 +483,9 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AT_LEAST_2_5: - from torch._inductor.runtime.benchmarking import benchmarker - - res = benchmarker.benchmark_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - elif TORCH_VERSION_AT_LEAST_2_3: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - - res = do_bench_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - else: - from torch._inductor.utils import do_bench - - res = do_bench( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) return res @@ -638,7 +623,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): # SAM best is between .8 and 1, SDXL also performs best in this range INTERPOLATION_CONSTANT = mode[1] w_qtensor = cls.from_float(weight) - x_vals_int8, x_scales = quantize_activation_per_token_absmax( + x_vals_int8, x_scales = _quantize_activation_per_token_absmax( act_mat.reshape(-1, act_mat.shape[-1]) ) quantized_matmul = ( diff --git a/torchao/prototype/quantization/codebook/codebook_ops.py b/torchao/prototype/quantization/codebook/codebook_ops.py index ca81ce0453..201dc30f27 100644 --- a/torchao/prototype/quantization/codebook/codebook_ops.py +++ b/torchao/prototype/quantization/codebook/codebook_ops.py @@ -198,8 +198,8 @@ def choose_qparams_codebook( dim=(-1), keepdim=True ).values # Shape: [*input_size[:-1], num_scale_blocks, 1] else: - scales = input.norm( - dim=(-1), keepdim=True + scales = torch.linalg.vector_norm( + input, dim=-1, keepdim=True ) # Shape: [*input_size[:-1], num_scale_blocks, 1] scales = torch.clamp(scales, min=1e-9) @@ -228,12 +228,14 @@ def _kmeans_greedy_init(data: torch.Tensor, k: int) -> torch.Tensor: running_min_distances = torch.full( (data.shape[0],), torch.inf, device=data.device, dtype=data.dtype ) - data_norm_squared = data.norm(p=2, dim=1).square() + data_norm_squared = torch.linalg.vector_norm(data, dim=1).square() for i in range(k): clusters[i] = data[running_min_distances.argmax()] distances_to_cluster_i = ( - data_norm_squared - 2 * data @ clusters[i] + clusters[i].norm().square() + data_norm_squared + - 2 * data @ clusters[i] + + torch.linalg.vector_norm(clusters[i]).square() ) running_min_distances = torch.minimum( running_min_distances, distances_to_cluster_i, out=running_min_distances diff --git a/torchao/prototype/quantization/codebook_coreml/__init__.py b/torchao/prototype/quantization/codebook_coreml/__init__.py new file mode 100644 index 0000000000..d0da8fcaf1 --- /dev/null +++ b/torchao/prototype/quantization/codebook_coreml/__init__.py @@ -0,0 +1,13 @@ +from .api import CodebookWeightOnlyConfig +from .codebook_ops import ( + choose_qparams_and_quantize_codebook_coreml, + dequantize_codebook, +) +from .codebook_quantized_tensor import CodebookQuantizedTensor + +__all__ = [ + "CodebookQuantizedTensor", + "CodebookWeightOnlyConfig", + "choose_qparams_and_quantize_codebook_coreml", + "dequantize_codebook", +] diff --git a/torchao/prototype/quantization/codebook_coreml/api.py b/torchao/prototype/quantization/codebook_coreml/api.py new file mode 100644 index 0000000000..36fa0d299f --- /dev/null +++ b/torchao/prototype/quantization/codebook_coreml/api.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List + +import torch + +from torchao.core.config import AOBaseConfig +from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import ( + CodebookQuantizedTensor, +) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) +from torchao.utils import is_package_at_least + + +@dataclass +class CodebookWeightOnlyConfig(AOBaseConfig): + dtype: torch.dtype + block_size: List[int] + + +@register_quantize_module_handler(CodebookWeightOnlyConfig) +def _codebook_weight_only_transform( + module: torch.nn.Module, + config: CodebookWeightOnlyConfig, +): + """ + Applies codebook weight-only quantization to linear layers. + + Args: + dtype: torch.uint1 to torch.uint8, torch.int32 supported. + Returns: + Callable for quantization transformation. + """ + if not is_package_at_least("coremltools", "8.3.0"): + raise ImportError("Requires coremltools >= 8.3.0") + + dtype = config.dtype + weight = module.weight + + quantized_weight = CodebookQuantizedTensor.from_float( + weight, + dtype, + config.block_size, + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + return module diff --git a/torchao/prototype/quantization/codebook_coreml/codebook_ops.py b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py new file mode 100644 index 0000000000..c5f56c9d62 --- /dev/null +++ b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional, Tuple + +import torch + +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_BIT_WIDTH, + _SUB_BYTE_UINT_BOUNDS, +) +from torchao.utils import _register_custom_op + +quant_lib = torch.library.Library("quant", "FRAGMENT") +register_custom_op = _register_custom_op(quant_lib) + + +# wrapper around coreml util: https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/models/neural_network/quantization_utils.py#L363 +@torch.no_grad +@register_custom_op +def choose_qparams_and_quantize_codebook_coreml( + input_tensor: torch.Tensor, + code_dtype: torch.dtype, + block_size: List[int], + force_kmeans1d: bool = False, + cluster_dim: int = 1, + vector_axis: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Initialize the codebook using k-means clustering on blocks of the input tensor. + + Args: + input_tensor (torch.Tensor): The input tensor to be quantized. + code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8] + block_size (List[int]): block sizes for how many elements in each dimension share + the same lookup table (len(block_size) == input_tensor.dim()) + Each dimension of input_tensor must be divisible by the corresponding element of block_size + Look up tables are indexed by {(di // bi) for i in input_tensor.dim()} + For example, if the input tensor has shape (N, K), and block_size is (N, group_size), this means + there is a lookup table for group_size columns, i.e., (K // group_size) total look up tables + force_kmeans1d (bool): Use kmeans1d regardless of number of weights + cluster_dim (int): this means the size of the vector for vector lookup table quantization + e.g. when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize + the tensor in a unit of 4 element vectors, a vector of original tensor will be mapped to + a vector in the codebook (lookup table) based on the indices. + vector_axis (Optional[int]): used in vector quantization, see more docs in https://github.com/apple/coremltools/blob/1c0e5cb1c1e3ab759af107b54f2be18b7c03f8aa/coremltools/optimize/_utils.py#L371 + + Returns: + Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8) + The LUT table has dimension (g0, .., g(N-1), 2**nbits, vec_dim), where: + * The first N dimensions index over the different tables (gi = input_tensor.shape[i] // block_size[i] in each dimension) + * The N + 1 dimension indexes over the nbit indices (2 ** nbits) + * The N + 2 dimension indexes over the look up values (shape = 1 for scalar) + """ + assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8] + nbits = _DTYPE_TO_BIT_WIDTH[code_dtype] + assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}" + assert input_tensor.dim() == 2, "Currently only rank 2 tensors are supported" + assert cluster_dim == 1, ( + f"only cluster_dim == 1 is supported right now, got {cluster_dim}" + ) + + original_shape = input_tensor.shape + N, K = original_shape + input_tensor = input_tensor.detach() + + # --- Process block_size --- + assert len(block_size) == 2 + processed_block_size = block_size.copy() + if processed_block_size[0] == -1: + processed_block_size[0] = N + if processed_block_size[1] == -1: + processed_block_size[1] = K + + row_block_size, col_block_size = processed_block_size + assert N % row_block_size == 0, ( + f"Tensor rows ({N}) not divisible by row block size ({row_block_size})" + ) + assert K % col_block_size == 0, ( + f"Tensor cols ({K}) not divisible by col block size ({col_block_size})" + ) + + # --- Determine and execute grouping strategy --- + assert row_block_size == N or col_block_size == K + is_col_grouping = row_block_size == N + + res_lut_list = [] + from coremltools.models.neural_network.quantization_utils import ( + _get_kmeans_lookup_table_and_weight, + ) + + if is_col_grouping: + # STRATEGY 1: Group by COLUMNS + num_luts = K // col_block_size + reshaped_tensor = input_tensor.reshape(N, num_luts, col_block_size) + res_codes = torch.zeros_like(reshaped_tensor, dtype=torch.uint8) + + for i in range(num_luts): + block_to_quantize = reshaped_tensor[:, i, :] + lut, w = _get_kmeans_lookup_table_and_weight( + nbits, block_to_quantize, force_kmeans1d, cluster_dim, vector_axis + ) + res_lut_list.append(torch.from_numpy(lut)) + res_codes[:, i, :] = torch.from_numpy(w.reshape(N, col_block_size)) + + # Shape to match CoreML spec: (1, num_luts, 2**nbits, 1) + final_luts = torch.stack(res_lut_list, dim=0).reshape(1, num_luts, 2**nbits, 1) + + else: # is_row_grouping + # STRATEGY 2: Group by ROWS + num_luts = N // row_block_size + reshaped_tensor = input_tensor.reshape(num_luts, row_block_size, K) + res_codes = torch.zeros_like(reshaped_tensor, dtype=torch.uint8) + + for i in range(num_luts): + block_to_quantize = reshaped_tensor[i, :, :] + lut, w = _get_kmeans_lookup_table_and_weight( + nbits, block_to_quantize, force_kmeans1d, cluster_dim, vector_axis + ) + res_lut_list.append(torch.from_numpy(lut)) + res_codes[i, :, :] = torch.from_numpy(w.reshape(row_block_size, K)) + + final_luts_stacked = torch.stack( + res_lut_list, dim=0 + ) # Shape: (num_luts, 2**nbits, 1) + + # Reshape to the consistent 4D format + # The shape is (num_row_groups, 1, 2**nbits, 1) + final_luts = final_luts_stacked.reshape(num_luts, 1, 2**nbits, 1) + + # Reshape codes back to the original tensor shape + final_codes = res_codes.reshape(*original_shape) + + return final_luts, final_codes + + +@register_custom_op +def dequantize_codebook( + codes: torch.Tensor, + codebook: torch.Tensor, + nbits: int, + block_size: List[int], + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Reconstructs the original tensor from codes and the codebook. + + Args: + codes (torch.Tensor): Indices of codebook entries for each element + General shape: (d0, d1, d2, ..., dN) + Simple example shape: (N, K) + codebook (torch.Tensor): Codebook tensor used for quantization + General shape: (d0 // block_size[0], ..., dN // block_size[N], 2**nbits, vec_dim), where vec_dim = 1 for scalar look up values + Simple example shape: (1, group_size, 2 ** nbits, 1) for scalar look up values, with 1 table per group_size columns + nbits: int: number of bits for the quantization + block_size (List[int]): a slice of elements with shape block_size will share the same lookup table. + If block_size[i] == -1, then the entire dimension is used. + output_dtype (torch.dtype): dtype for the output tensor. + + Returns: + dequant (torch.Tensor): Reconstructed tensor, shape (N, K) + + """ + assert output_dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], f"Unsupported output dtype: {output_dtype}" + + assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}" + + assert len(block_size) == codes.dim() + block_size = block_size.copy() + for i in range(len(block_size)): + if block_size[i] == -1: + block_size[i] = codes.shape[i] + assert block_size[i] >= 1 and codes.shape[i] % block_size[i] == 0, ( + "block_size[i] must divide codes.shape[i]" + ) + + assert codebook.dim() == codes.dim() + 2 + codebook_shape = codebook.shape + vec_dim = codebook_shape[-1] + quant_levels = 2**nbits + + # Check that last two dimensions of codebook are [quant_levels, vec_dim] + assert codebook_shape[-2] == quant_levels, "Codebook shape mismatch with nbits" + + # Compute shape of lookup group indices from codes shape and block size + code_shape = codes.shape + ndim = codes.ndim + assert len(block_size) == ndim, "block_size must match dimensionality of codes" + + # Compute which codebook slice to use for each element + group_indices = [] + for i in range(ndim): + assert block_size[i] >= 1 and code_shape[i] % block_size[i] == 0, ( + f"dimension {code_shape[i]} not divisible by block size {block_size[i]}" + ) + + # Index of block + idx = ( + torch.arange(code_shape[i], device=codes.device) // block_size[i] + ) # shape (di,) + + # Reshape idx to broadcast along all other dims + shape = [1] * ndim + shape[i] = code_shape[i] + idx = idx.view(*shape) # shape (1, ..., 1, di, 1, ..., 1) + idx = idx.expand(code_shape) # shape (d0, ..., dN) + group_indices.append(idx) + + # Stack the broadcasted group indices + # group_index_tensor at (i0, i1, ..., iN) is the gives the group indices (g0, ..., gN) + # for the element at (i0, i1, ..., iN) in the original code + # If code.shape = (d1, d2, d3), then group_index_tensor.shape = (d1, d2, d3, 3) + group_index_tensor = torch.stack( + group_indices, dim=-1 + ) # shape (d0, d1, ..., dN, ndim) + + # Flatten everything to index efficiently + flat_codes = codes.reshape(-1) # shape (numel,) + flat_groups = group_index_tensor.reshape(-1, ndim) # (numel, ndim) + + # Compute dequantized values via indexing + # index into codebook with (*group_index, code_index, :) + gathered = codebook[(*flat_groups.T, flat_codes)] # shape (numel, vec_dim) + dequant = gathered.reshape(*code_shape, vec_dim) + + if vec_dim == 1: + dequant = dequant.squeeze(-1) + + return dequant.to(dtype=output_dtype) diff --git a/torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py new file mode 100644 index 0000000000..7283a23918 --- /dev/null +++ b/torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.prototype.quantization.codebook_coreml.codebook_ops import ( + choose_qparams_and_quantize_codebook_coreml, + dequantize_codebook, +) +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_BIT_WIDTH, +) +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + + +class CodebookQuantizedTensor(TorchAOBaseTensor): + """ + Codebook quantized tensor subclass. + + Codebook (lookup table) quantization involves partitioning the input tensor into blocks, and replacing each block + with the index of the closest entry in a predefined codebook. + + Fields: + codes (torch.Tensor): Tensor of indices representing blocks in the original tensor. Each index + maps to a corresponding codebook entry, torch.uint8 dtype. + codebook (torch.Tensor): Tensor representing the quantization codebook, where each entry + corresponds to a block in the original tensor. Shape is `(codebook_size, out_block_size, in_block_size)`. + code_dtype (torch.dtype): The logical dtype for the codes, [torch.uint1, ..., torch.uint8] + Note that codes is stored in torch.uint8, this is just addtional information for dequantize op + block_size (Tuple[int, ...]): Granularity of quantization, specifying the dimensions of tensor + blocks that share the same quantization parameters. + shape (torch.Size): Shape of the original high-precision tensor. + dtype (torch.dtype): dtype of the original high-precision tensor. + """ + + tensor_data_attrs = ["codes", "codebook"] + tensor_attributes = ["code_dtype", "block_size", "shape", "dtype"] + + @staticmethod + def __new__( + cls, + codes: torch.Tensor, + codebook: torch.Tensor, + code_dtype: torch.dtype, + block_size: List[int], + shape: torch.Size, + dtype=None, + ): + kwargs = {} + kwargs["device"] = codes.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else codes.layout + ) + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + codes: torch.Tensor, + codebook: torch.Tensor, + code_dtype: torch.dtype, + block_size: List[int], + shape: torch.Size, + dtype=None, + ): + self.codes = codes + self.codebook = codebook + self.code_dtype = code_dtype + self.block_size = block_size + + def __repr__(self): + return ( + f"{self.__class__.__name__}(codes={self.codes}, codebook={self.codebook}, code_dtype={self.code_dtype}, block_size={self.block_size} " + f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) + + def _quantization_type(self): + return f"shape={self.shape}, codebook_shape={self.codebook.shape}, code_dtype={self.code_dtype}, block_size={self.block_size}, device={self.device}" + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + codes = self.codes + if codes.dtype != torch.int32: + # TODO: Investigate and support not casting to torch.int32 for indexing to improve performance + codes = codes.to(torch.int32) + + # Note: code_dtype is just for lowering pass to understand the range of values in codes + return dequantize_codebook( + codes, + self.codebook, + _DTYPE_TO_BIT_WIDTH[self.code_dtype], + self.block_size, + output_dtype=output_dtype, + ) + + def __tensor_flatten__(self): + return self.tensor_data_attrs, [ + getattr(self, attr) for attr in self.tensor_attributes + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_data_attrs], + *tensor_attributes, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], + *[getattr(self, attr) for attr in self.tensor_attributes], + ) + + @classmethod + def from_float( + cls, + input_tensor: torch.Tensor, + code_dtype: torch.dtype, + block_size: List[int], + ): + """ + Creates a CodebookQuantizedTensor from a floating-point tensor by performing codebook quantization. + + Args: + input_tensor (torch.Tensor): The input floating-point tensor to quantize. + code_dtype (torch.dtype): The dtype of the codes, Note the codes Tensor is stored in uint8 + chunk_size (int): The chunk size to use during quantization (to control memory usage). + """ + codebook, codes = choose_qparams_and_quantize_codebook_coreml( + input_tensor, code_dtype, block_size + ) + + assert codes.dtype == torch.uint8, "Only support using uint8 for codes for now" + + return cls( + codes, + codebook, + code_dtype, + block_size, + input_tensor.shape, + dtype=input_tensor.dtype, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + *[getattr(self, attr).to(device) for attr in self.tensor_data_attrs], + *[getattr(self, attr) for attr in self.tensor_attributes], + **kwargs, + ) + + +implements = CodebookQuantizedTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + weight_tensor = weight_tensor.dequantize() + return func(input_tensor, weight_tensor, bias) + + +@implements([torch.nn.functional.embedding, aten.embedding.default]) +def _(func, types, args, kwargs): + assert len(args) == 2 + indices, weight_tensor = ( + args[0], + args[1], + ) + weight_tensor = weight_tensor.dequantize() + return func(indices, weight_tensor, **kwargs) + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) diff --git a/torchao/prototype/quantization/codebook_groupwise/__init__.py b/torchao/prototype/quantization/codebook_groupwise/__init__.py new file mode 100644 index 0000000000..8cf56240cd --- /dev/null +++ b/torchao/prototype/quantization/codebook_groupwise/__init__.py @@ -0,0 +1,9 @@ +from .api import GroupwiseLutWeightConfig +from .codebook_quantized_tensor import CodebookQuantizedPackedTensor + +__all__ = [ + "CodebookQuantizedPackedTensor", + "GroupwiseLutWeightConfig", + "QuantizedLutEmbedding", + "EmbeddingLutQuantizer", +] diff --git a/torchao/prototype/quantization/codebook_groupwise/api.py b/torchao/prototype/quantization/codebook_groupwise/api.py new file mode 100644 index 0000000000..ff8f17b4d7 --- /dev/null +++ b/torchao/prototype/quantization/codebook_groupwise/api.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import types +from dataclasses import dataclass, field +from typing import List, Optional + +import torch + +from torchao.core.config import AOBaseConfig +from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import ( + CodebookQuantizedTensor, +) +from torchao.prototype.quantization.codebook_groupwise.codebook_quantized_tensor import ( + CodebookQuantizedPackedTensor, +) +from torchao.quantization.transform_module import register_quantize_module_handler + + +def _get_linear_extra_repr_for_lut(self) -> str: + """ + Custom __repr__ for a linear module quantized with GroupwiseLutQuantizedTensor. + """ + out_features, in_features = self.weight.shape + + # Access metadata from the custom tensor + bit_width = self.weight.bit_width + lut_group_size = self.weight.lut_group_size + scale_group_size = self.weight.scale_group_size + + # The original bias is fused into the packed weight, so self.bias is None. + has_bias = self.bias is not None + + return ( + f"in_features={in_features}, out_features={out_features}, bias={has_bias}, " + f"quant=GroupwiseLut(bit_width={bit_width}, lut_gs={lut_group_size}, " + f"scale_gs={scale_group_size}')" + ) + + +@dataclass +class GroupwiseLutWeightConfig(AOBaseConfig): + """ + The primary configuration for groupwise Look-Up Table (LUT) quantization. + + This config uses a `block_shape` to define the quantization strategy, + allowing for flexible grouping by either rows or columns. + + Args: + code_dtype (torch.dtype): The target logical dtype for the LUT indices + (e.g., torch.uint4, torch.int4). This determines the codebook size. + weight_dtype (torch.dtype): The target dtype for the raw weight (e.g., torch.float32). + + lut_block_shape (List[int]): Defines the grouping for the look-up table. + This is the key parameter for controlling quantization granularity. + - To group by N rows: use `[N, -1]`. Example: `[2, -1]` means + every 2 rows share a single LUT. + - To group by K columns: use `[-1, K]`. Example: `[-1, 64]` means + every 64 columns share a single LUT. + + scale_block_shape (Optional[List[int]]): Defines grouping for scale factors, + used only by the 'scale' backend. If provided, the 'scale' backend + is automatically selected. The same `[N, -1]` or `[-1, K]` pattern applies. + has_scale (bool): Whether to use scale factors. Defaults to False. + target (str): The backend target for the C++ kernel (e.g., "auto", "aten"). + """ + + # --- Attributes --- + code_dtype: torch.dtype = torch.int4 + weight_dtype: torch.dtype = torch.float32 + backend: str = "auto" + + lut_block_shape: List[int] = field(default_factory=lambda: [2, -1]) + + scale_block_shape: Optional[List[int]] = None + + use_qdq_reference: bool = False + target: Optional[str] = None + cache_dir: Optional[str] = None + has_scale: bool = False + + def __post_init__(self): + """Validate the configuration after initialization.""" + # 1. Validate backend string + if self.backend not in ["auto", "scale", "coreml"]: + raise ValueError(f"Invalid backend: {self.backend}") + + # 2. Validate lut_block_shape + if not ( + isinstance(self.lut_block_shape, list) and len(self.lut_block_shape) == 2 + ): + raise ValueError( + "`lut_block_shape` must be a list of length 2 (e.g., [N, -1] or [-1, K])." + ) + if self.lut_block_shape.count(-1) != 1: + raise ValueError( + "`lut_block_shape` must contain exactly one '-1' to specify the grouping dimension." + ) + if self.has_scale == True: + raise ValueError("currently only support lut quantization without scale") + + # 3. Validate scale_block_shape if it exists + if self.has_scale and self.scale_block_shape is not None: + if not ( + isinstance(self.scale_block_shape, list) + and len(self.scale_block_shape) == 2 + ): + raise ValueError( + "`scale_block_shape` must be a list of length 2 if provided." + ) + + +@register_quantize_module_handler(GroupwiseLutWeightConfig) +def _groupwise_lut_weight_transform( + module: torch.nn.Module, config: GroupwiseLutWeightConfig +) -> torch.nn.Module: + """ + Transforms a linear module by applying groupwise LUT-based weight quantization. + Automatically caches results if config.cache_dir is set, using a hash of + the weight tensor for a unique key. + """ + assert isinstance(module, torch.nn.Linear), ( + "This transform only applies to torch.nn.Linear modules." + ) + weight = module.weight.data + + quantized_tensor = CodebookQuantizedTensor.from_float( + weight, code_dtype=config.code_dtype, block_size=config.lut_block_shape + ) + + if not config.use_qdq_reference: + packed_weight = CodebookQuantizedPackedTensor.from_codebook_quantized_tensor( + tensor=quantized_tensor, bias=module.bias + ) + module.weight = torch.nn.Parameter(packed_weight, requires_grad=False) + if module.bias is not None: + module.bias = None + module.extra_repr = types.MethodType(_get_linear_extra_repr_for_lut, module) + + else: # For reference, dequantize back to float + dequantized_weight = quantized_tensor.dequantize(config.weight_dtype) + module.weight.data.copy_(dequantized_weight) + + return module diff --git a/torchao/prototype/quantization/codebook_groupwise/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook_groupwise/codebook_quantized_tensor.py new file mode 100644 index 0000000000..8a66434685 --- /dev/null +++ b/torchao/prototype/quantization/codebook_groupwise/codebook_quantized_tensor.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.prototype.quantization.codebook_coreml.codebook_quantized_tensor import ( + CodebookQuantizedTensor, +) +from torchao.prototype.quantization.codebook_utils.codebook_utils import ( + block_shape_to_group_size, +) +from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH +from torchao.utils import TorchAOBaseTensor + +# --- C++ Op Accessor Functions --- + + +def get_pack_op(weight_nbit: int): + """Gets the C++ packing function from the 'torchao' namespace.""" + op_name = f"_pack_groupwise_{weight_nbit}bit_weight_with_lut" + if not hasattr(torch.ops.torchao, op_name): + raise NotImplementedError(f"Packing op for {weight_nbit}-bit not found.") + return getattr(torch.ops.torchao, op_name) + + +def get_linear_op(weight_nbit: int): + """Gets the C++ fused linear function from the 'torchao' namespace.""" + op_name = f"_linear_groupwise_{weight_nbit}bit_weight_with_lut" + if not hasattr(torch.ops.torchao, op_name): + raise NotImplementedError(f"Linear op for {weight_nbit}-bit not found.") + return getattr(torch.ops.torchao, op_name) + + +aten = torch.ops.aten + + +class CodebookQuantizedPackedTensor(TorchAOBaseTensor): + tensor_data_names = [ + "packed_weight", + ] + tensor_attribute_names = [ + "bit_width", + "lut_block_size", + "scale_block_size", + "shape", + "dtype", + ] + + def __new__( + cls, packed_weight, bit_width, lut_block_size, scale_block_size, shape, dtype + ): + kwargs = { + "device": packed_weight.device, + "dtype": dtype, + "requires_grad": False, + } + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__( + self, + packed_weight: torch.Tensor, + bit_width: int, + lut_block_size: List[int], + scale_block_size: Optional[List[int]], + shape: torch.Size, + dtype: torch.dtype, + ): + self.packed_weight = packed_weight + self.bit_width = bit_width + self.lut_block_size = lut_block_size + self.scale_block_size = scale_block_size + + @classmethod + def from_unpacked( + cls, + int_data: torch.Tensor, + luts: torch.Tensor, + scales: Optional[torch.Tensor], + bit_width: int, + lut_block_size: List[int], + scale_block_size: Optional[List[int]], + original_shape: torch.Size, + bias: Optional[torch.Tensor] = None, + ): + lut_group_size = block_shape_to_group_size(lut_block_size, int_data.shape) + + if scale_block_size is not None and scales is not None: + # Scales are present, calculate group size + scale_group_size = block_shape_to_group_size( + scale_block_size, int_data.shape + ) + scales_arg = scales + else: + # Scales are not present, provide safe defaults + scale_group_size = -1 + scales_arg = torch.empty(0, dtype=luts.dtype, device=luts.device) + + pack_op = get_pack_op(bit_width) + packed_weight = pack_op( + int_data, luts, scale_group_size, lut_group_size, scales_arg, bias + ) + return cls( + packed_weight, + bit_width, + lut_block_size, + scale_block_size, + original_shape, + int_data.dtype, + ) + + @classmethod + def from_codebook_quantized_tensor( + cls, + tensor: CodebookQuantizedTensor, + *, + bias: Optional[torch.Tensor] = None, + ): + """ + Factory method to create a packed tensor from a CodebookQuantizedTensor. + + This method takes the general components of a codebook-quantized tensor + (codes, codebook, etc.) and uses a specialized 'pack_op' to fuse them + into a single, efficient tensor format suitable for high-performance + inference kernels. + """ + lut_block_size = tensor.block_size + lut_group_size = block_shape_to_group_size(lut_block_size, tensor.shape) + + # CoreML quantization scheme does not use scales, so they are disabled. + scale_group_size = -1 + scales = None + + bit_width = _DTYPE_TO_BIT_WIDTH[tensor.code_dtype] + # Retrieve the appropriate packing C++/CUDA kernel for the given bit width. + pack_op = get_pack_op(bit_width) + + # Ensure the codebook (Look-Up Table) is in float32, as this is the + # data type expected by the underlying packing kernel. + codebook = tensor.codebook.to(torch.float32) + + # --- Explanation for .squeeze() --- + # The input `tensor.codebook` is often stored in a 4D format, such as + # [1, num_groups, 256, 1], for compatibility with generic operators like + # the dequantize function. However, the specialized `pack_op` expects a + # more compact 2D LUT of shape [num_groups, 256]. + # The .squeeze() operation removes the unnecessary singleton (size 1) + # dimensions to achieve this required 2D format. + codebook = codebook.squeeze() + + # Call the packing operator to create the final fused tensor. + packed_weight = pack_op( + tensor.codes, codebook, scale_group_size, lut_group_size, scales, bias, None + ) + + # Return a new instance of this class containing the final packed weight + # and its associated quantization metadata. + return cls( + packed_weight, bit_width, lut_block_size, None, tensor.shape, tensor.dtype + ) + + +implements = CodebookQuantizedPackedTensor.implements + + +@implements([F.linear]) +def _(func, types, args, kwargs): + """ + Override for `torch.nn.functional.linear` specifically for the + GroupwiseLutQuantizedTensor. This calls the fused C++ kernel. + """ + input_tensor, weight_tensor, _ = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + linear_op = get_linear_op(weight_tensor.bit_width) + lut_group_size = block_shape_to_group_size( + weight_tensor.lut_block_size, weight_tensor.shape + ) + original_shape = input_tensor.shape + k = weight_tensor.shape[1] + if input_tensor.dim() > 2: + input_tensor = input_tensor.reshape(-1, k) + + n = weight_tensor.shape[0] + output = linear_op( + input_tensor, weight_tensor.packed_weight, -1, lut_group_size, n, k + ) + + if len(original_shape) > 2: + output_shape = original_shape[:-1] + (n,) + return output.reshape(output_shape) + return output + + +@implements([aten.detach.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) diff --git a/torchao/prototype/quantization/codebook_utils/__init__.py b/torchao/prototype/quantization/codebook_utils/__init__.py new file mode 100644 index 0000000000..509ae88839 --- /dev/null +++ b/torchao/prototype/quantization/codebook_utils/__init__.py @@ -0,0 +1,17 @@ +from .codebook_utils import ( + block_shape_to_group_size, + dequantize_dispatch, + group_size_to_block_shapes, + load_quantized_data, + quantize_dispatch, + save_quantized_data, +) + +__all__ = [ + "quantize_dispatch", + "dequantize_dispatch", + "save_quantized_data", + "load_quantized_data", + "block_shape_to_group_size", + "group_size_to_block_shapes", +] diff --git a/torchao/prototype/quantization/codebook_utils/codebook_utils.py b/torchao/prototype/quantization/codebook_utils/codebook_utils.py new file mode 100644 index 0000000000..d80292f5c9 --- /dev/null +++ b/torchao/prototype/quantization/codebook_utils/codebook_utils.py @@ -0,0 +1,501 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# core ml support scale.. +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from torchao.prototype.quantization.codebook.codebook_ops import ( + choose_qparams_codebook, + dequantize_codebook, + quantize_codebook, +) +from torchao.prototype.quantization.codebook_coreml.codebook_ops import ( + choose_qparams_and_quantize_codebook_coreml, +) +from torchao.prototype.quantization.codebook_coreml.codebook_ops import ( + dequantize_codebook as dequantize_codebook_coreml, +) +from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH + + +def block_shape_to_group_size(block_shape, tensor_shape): + """Calculates the total number of elements in a group from a block_shape.""" + n_group, k_group = block_shape + n_dim, k_dim = tensor_shape + + if n_group == -1: + n_group = n_dim + if k_group == -1: + k_group = k_dim + + return n_group * k_group + + +def group_size_to_block_shapes( + lut_group_size: int, + tensor_shape: Tuple[int, int], +) -> Tuple[List[int], Optional[List[int]]]: + """ + Translates legacy integer-based group sizes into the new block_shape list format. + + This function encodes the implicit assumptions of the old system: + - LUTs were always grouped by rows. + - Scales were always grouped by columns. + + Args: + lut_group_size (int): The total number of elements that shared a single LUT. + tensor_shape (Tuple[int, int]): The shape of the weight tensor (N, K). + This is required to calculate the number of rows for the LUT group. + + Returns: + A tuple containing: + - lut_block_shape (List[int]): The new block shape for LUTs (e.g., [N, -1]). + - scale_block_shape (Optional[List[int]]): The new block shape for scales + (e.g., [-1, K]), or None. + """ + n_rows, k_cols = tensor_shape + + # --- 1. Translate LUT Group Size --- + if lut_group_size % k_cols != 0: + raise ValueError( + f"lut_group_size ({lut_group_size}) must be divisible by the number " + f"of columns ({k_cols}) for legacy row-grouping." + ) + rows_per_lut = lut_group_size // k_cols + lut_block_shape = [rows_per_lut, -1] + + return lut_block_shape + + +@torch.no_grad() +def _quantize_row_wise_group_with_scales( + input_tensor: torch.Tensor, + rows_per_group: int, + scale_block_shape: List[int], + code_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantizes a 2D tensor using row-wise grouping, with a unique LUT and + set of scales for each group. + + Returns a tuple of (codes, luts, scales) with structured shapes. + """ + assert input_tensor.ndim == 2, "This function expects a 2D tensor." + n_rows, k_cols = input_tensor.shape + assert n_rows % rows_per_group == 0, ( + f"Tensor rows ({n_rows}) must be divisible by rows_per_group ({rows_per_group})." + ) + + num_groups = n_rows // rows_per_group + list_of_luts, list_of_codes, list_of_scales = [], [], [] + + for i in range(num_groups): + start_row = i * rows_per_group + end_row = start_row + rows_per_group + tensor_slice = input_tensor[start_row:end_row, :] + + # This performs scalar quantization (block_size=(1, 1)) on the slice + codebook, scales = choose_qparams_codebook( + tensor_slice, + block_size=(1, 1), + scale_block_size=scale_block_shape[-1], + code_dtype=code_dtype, + ) + + codes = quantize_codebook( + tensor_slice, + codebook, + scales, + code_dtype=code_dtype, + ) + + # Append results without flattening + # Squeeze codebook from (codebook_size, 1, 1) to (codebook_size,) + list_of_luts.append(codebook.squeeze()) + list_of_scales.append(scales) + list_of_codes.append(codes) + + # Concatenate along the row dimension (dim=0) to preserve structure + final_codes = torch.cat(list_of_codes, dim=0) + final_scales = torch.cat(list_of_scales, dim=0) + + # Stack LUTs to create a (num_groups, codebook_size) tensor + final_luts = torch.stack(list_of_luts, dim=0) + final_scales = final_scales.flatten() + return final_codes, final_luts, final_scales + + +@torch.no_grad() +def _dequantize_row_wise_group_with_scales( + codes: torch.Tensor, + luts: torch.Tensor, + scales: torch.Tensor, + rows_per_group: int, + scale_group_size: int, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Dequantizes a 2D tensor that was quantized with `quantize_per_row_group_with_scales`. + + Args: + codes (torch.Tensor): The quantized data codes. + Shape: (total_rows, total_cols) + luts (torch.Tensor): The lookup tables (codebooks) for each group. + Shape: (num_groups, codebook_size) + scales (torch.Tensor): The scale factors for each row. + Shape: (total_rows,) + rows_per_group (int): The number of rows in each quantization group. + output_dtype (torch.dtype): The desired data type for the output tensor. + + Returns: + torch.Tensor: The dequantized tensor. + Shape: (total_rows, total_cols) + """ + assert codes.ndim == 2, "This function expects a 2D codes tensor." + n_rows, k_cols = codes.shape + assert n_rows % rows_per_group == 0, ( + f"Tensor rows ({n_rows}) must be divisible by rows_per_group ({rows_per_group})." + ) + + # Calculate the number of row groups. + # e.g., if n_rows=128 and rows_per_group=4, num_groups=32 + num_groups = n_rows // rows_per_group + assert luts.shape[0] == num_groups, ( + "Mismatch between number of LUTs and row groups." + ) + + # calculate the number of scale blocks per row. + num_scale_blocks = k_cols // scale_group_size + # Reshape the flattened scales back to their original 3D structure. + # Shape: (n_rows, num_scale_blocks, 1) + reshaped_scales = scales.view(n_rows, num_scale_blocks, 1) + + # Pre-allocate the output tensor for efficiency to avoid creating new tensors in the loop. + # Shape: (total_rows, total_cols) + dequantized_tensor = torch.empty_like(codes, dtype=output_dtype) + + # Iterate over each group of rows to dequantize them chunk by chunk. + for i in range(num_groups): + # Calculate the start and end row indices for the current group slice. + start_row = i * rows_per_group + end_row = start_row + rows_per_group + + # Get the slice of codes for the current group. + # Shape: (rows_per_group, total_cols), e.g., (4, 64) + codes_slice = codes[start_row:end_row, :] + # Get the lookup table (codebook) for the current group. + # The LUT is 1D, shape: (codebook_size,), e.g., (2,) for 1-bit quantization. + # Reshape it to the (k, b1, b2) format required by dequantize_codebook. + # For scalar quantization, block sizes b1 and b2 are 1. + # Reshaped Shape: (codebook_size, 1, 1), e.g., (2, 1, 1) + current_lut = luts[i].view(-1, 1, 1) + + # Get the slice of scales corresponding to the rows in this group. + scales_slice = reshaped_scales[start_row:end_row, :, :] + + # Dequantize the slice using the dedicated function. + dequant_slice = dequantize_codebook( + codes=codes_slice, + codebook=current_lut, + scales=scales_slice, + output_dtype=output_dtype, + ) + # The returned `dequant_slice` has its original shape restored. + # Shape: (rows_per_group, total_cols), e.g., (4, 64) + + # Place the dequantized slice into the correct position in the final tensor. + dequantized_tensor[start_row:end_row, :] = dequant_slice + + return dequantized_tensor + + +@torch.no_grad +def quantize_flexible_grouping( + input_tensor: torch.Tensor, + lut_block_shape: List[int], + code_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor, None]: + """ + Quantizes a tensor using either row-wise or column-wise grouping. + + Args: + input_tensor (torch.Tensor): The 2D tensor to be quantized. + Shape: (n_rows, k_cols) + lut_block_shape (List[int]): Defines the grouping strategy. + - To group by columns: `[-1, k_group]`. + - To group by rows: `[n_group, -1]`. + code_dtype (torch.dtype): The dtype for the codes (e.g., torch.uint4). + + Returns: + A tuple containing the quantized codes, the lookup tables, and None. + - final_codes (torch.Tensor): Quantized data of shape (n_rows, k_cols). + - final_luts (torch.Tensor): The codebook of lookup tables. + Shape: (num_groups, 2**nbits), where num_groups depends on the strategy. + - None: Placeholder for scales, which are not computed. + """ + assert input_tensor.ndim == 2, "This function expects a 2D tensor." + assert len(lut_block_shape) == 2, ( + "lut_block_shape must have two elements for a 2D tensor." + ) + n_rows, k_cols = input_tensor.shape + n_group, k_group = lut_block_shape + + # STRATEGY 1: Group by ROWS (e.g., block_size = [2, -1]) + if n_group != -1 and k_group == -1: + assert n_rows % n_group == 0, ( + f"Tensor rows ({n_rows}) must be divisible by row group size ({n_group})." + ) + list_of_luts, list_of_codes = [], [] + for i in range(0, n_rows, n_group): + tensor_slice = input_tensor[i : i + n_group, :] + lut, codes = choose_qparams_and_quantize_codebook_coreml( + input_tensor=tensor_slice, + code_dtype=code_dtype, + block_size=[-1, -1], + ) + list_of_luts.append(lut) + list_of_codes.append(codes) + + # Concatenate and remove singleton dimensions + final_luts = torch.cat(list_of_luts, dim=0).squeeze() + final_codes = torch.cat(list_of_codes, dim=0) + return final_codes, final_luts, None + + # STRATEGY 2: Group by COLUMNS (e.g., block_size = [-1, 64]) + elif n_group == -1: + if k_group != -1: + assert k_cols % k_group == 0, ( + f"Tensor cols ({k_cols}) must be divisible by col group size ({k_group})." + ) + luts, codes = choose_qparams_and_quantize_codebook_coreml( + input_tensor=input_tensor, + code_dtype=code_dtype, + block_size=lut_block_shape, + ) + # Remove singleton dimensions + final_luts = luts.squeeze() + final_codes = codes + return final_codes, final_luts, None + + # Unsupported strategy + else: + raise NotImplementedError( + f"lut_block_shape pattern '{lut_block_shape}' is not supported." + ) + + +@torch.no_grad +def dequantize_with_flexible_grouping( + codes: torch.Tensor, + luts: torch.Tensor, + lut_block_shape: List[int], + code_dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + assert codes.ndim == 2, "This function expects a 2D codes tensor." + n_rows, k_cols = codes.shape + n_group, k_group = lut_block_shape + + # STRATEGY 1: Grouping was by COLUMNS (e.g., block_shape = [-1, 64]) + if n_group == -1: + return dequantize_codebook_coreml( + codes=codes, + codebook=luts, + code_dtype=code_dtype, + block_size=lut_block_shape, + output_dtype=output_dtype, + ) + + # STRATEGY 2: Grouping was by ROWS (e.g., block_shape = [2, -1]) + elif n_group != -1 and k_group == -1: + assert n_rows % n_group == 0, ( + f"Tensor rows ({n_rows}) must be divisible by row group size ({n_group})." + ) + num_groups = n_rows // n_group + dequantized_tensor = torch.empty_like(codes, dtype=output_dtype) + + for i in range(num_groups): + start_row, end_row = i * n_group, (i + 1) * n_group + + # Get the chunk of codes and the single LUT for that chunk + codes_slice = codes[start_row:end_row, :] + current_lut = luts[i] + + # To dequantize a chunk with a *single* LUT, we tell the primitive + # that the block_size should cover all columns (k_cols). + dequant_slice = dequantize_codebook_coreml( + codes=codes_slice, + # The primitive expects a 2D LUT of shape (num_luts, ...). + # Since we have one LUT, we must add a dimension. + codebook=current_lut.unsqueeze(0), + code_dtype=code_dtype, + block_size=[-1, k_cols], + output_dtype=output_dtype, + ) + dequantized_tensor[start_row:end_row, :] = dequant_slice + return dequantized_tensor + + else: + raise NotImplementedError( + f"lut_block_shape pattern '{lut_block_shape}' is not supported." + ) + + +def quantize_dispatch( + input_tensor: torch.Tensor, + lut_block_shape: List[int], + code_dtype: torch.dtype, + scale_block_shape: Optional[List[int]] = None, # Make this optional + backend: str = "auto", +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Single entry point for quantization that dispatches to the correct backend. + + This function uses lut_block_shape to determine the quantization strategy, + allowing for flexible grouping by either rows or columns. + + Args: + input_tensor (torch.Tensor): The 2D tensor to be quantized (N, K). + lut_block_shape (List[int]): Defines the grouping for the look-up table. + - To group by N rows: use `[N, -1]`. + - To group by K columns: use `[-1, K]`. + code_dtype (torch.dtype): The target dtype for the codes (e.g., torch.uint4). + scale_block_shape (Optional[List[int]]): Defines grouping for scale factors, + used only by the 'scale' backend. E.g., `[-1, 64]`. If provided, + the 'scale' backend is used in "auto" mode. Defaults to None. + backend (str): The quantization backend to use. Can be "auto", "coreml", + or "scale". "auto" chooses based on whether `scale_block_shape` is provided. + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: A tuple + containing the (codes, luts, scales). Scales will be None for the + 'coreml' backend. + """ + # Determine which backend to use based on if scale_block_shape is provided. + if backend == "auto": + backend = "scale" if scale_block_shape is not None else "coreml" + + # Dispatch to the appropriate backend implementation + if backend == "scale": + if scale_block_shape is None: + raise ValueError( + "'scale' backend requires a `scale_block_shape` to be set." + ) + + # The 'scale' backend only supports row-grouping for the LUT. + # We derive the rows_per_group from the lut_block_shape parameter. + n_group, k_group = lut_block_shape + if n_group == -1 or k_group != -1: + raise ValueError( + "The 'scale' backend currently only supports row-grouping for LUTs. " + "Please use a `lut_block_shape` of `[N, -1]`." + ) + rows_per_lut_group = n_group + + codes, luts, scales = _quantize_row_wise_group_with_scales( + input_tensor, + rows_per_lut_group, + scale_block_shape, + code_dtype, + ) + + elif backend == "coreml": + codes, luts, scales = quantize_flexible_grouping( + input_tensor, lut_block_shape, code_dtype + ) + + else: + raise ValueError(f"Unknown backend: {backend}") + + luts = luts.to(torch.float32) + return codes, luts, scales + + +def dequantize_dispatch( + codes: torch.Tensor, + luts: torch.Tensor, + scales: Optional[torch.Tensor], + lut_block_shape: List[int], + scale_block_shape: Optional[List[int]] = None, + backend: str = "auto", + code_dtype: torch.dtype = torch.int4, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Single entry point for dequantization that dispatches to the correct backend. + (Updated to use flexible block shapes). + """ + if backend == "auto": + # Use presence of scales to determine backend + backend = "scale" if scales is not None else "coreml" + + if backend == "scale": + # For backward compatibility, derive old integer args from new block shapes + if scale_block_shape is None: + raise ValueError("'scale' backend requires a `scale_block_shape`.") + + n_group, k_group = lut_block_shape + if k_group != -1: + raise ValueError( + "Scale dequant backend only supports row-grouped LUTs ([N, -1])." + ) + rows_per_lut_group = n_group + + scale_n_group, scale_k_group = scale_block_shape + if scale_n_group != 1: + raise ValueError( + "Scale dequant backend only supports col-grouped scales ([1, K])." + ) + scale_group_size = scale_k_group + + return _dequantize_row_wise_group_with_scales( + codes, + luts, + scales, + rows_per_lut_group, + scale_group_size, + output_dtype=output_dtype, + ) + + elif backend == "coreml": + # Perform grouping along rows, reshape the [Rows per group, 2**nbits] LUTs + # to [1, Rows per group, 2**nbits, 1] for the dequantize primitive. + num_luts = luts.shape[0] + lut_size = luts.shape[1] + luts_4d = luts.reshape(num_luts, 1, lut_size, 1) + return dequantize_codebook_coreml( + codes, + luts_4d, + _DTYPE_TO_BIT_WIDTH[code_dtype], + lut_block_shape, + output_dtype=output_dtype, + ) + + else: + raise ValueError(f"Unknown backend: {backend}") + + +def save_quantized_data(data: Dict[str, Any], filepath: str): + """ + Saves the dictionary of quantized tensors to a file. + """ + # Create the directory if it doesn't exist + os.makedirs(os.path.dirname(filepath), exist_ok=True) + torch.save(data, filepath) + print(f"Saved quantization results to '{filepath}'") + + +def load_quantized_data(filepath: str) -> Optional[Dict[str, Any]]: + """ + Loads the dictionary of quantized tensors from a file if it exists. + """ + if not os.path.exists(filepath): + return None + data = torch.load(filepath) + print(f"Loaded quantization results from cache: '{filepath}'") + return data diff --git a/torchao/prototype/quantization/embedding/__init__.py b/torchao/prototype/quantization/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/quantization/embedding/api.py b/torchao/prototype/quantization/embedding/api.py new file mode 100644 index 0000000000..a5712782c2 --- /dev/null +++ b/torchao/prototype/quantization/embedding/api.py @@ -0,0 +1,420 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import sys +from typing import Callable, List, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +from torchao.quantization.granularity import Granularity, PerAxis, PerGroup +from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + MappingType, + quantize_, +) +from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH + + +class QuantizedEmbedding(nn.Module): + def __init__( + self, + bit_width, + ): + super().__init__() + self.bit_width = bit_width + + def quantize_and_pack_weights(self, weights, group_size, mapping_type): + num_embeddings, embedding_dim = weights.shape + + embedding = torch.nn.Embedding(num_embeddings, embedding_dim) + embedding.weight = weights + quantize_( + embedding, + IntxWeightOnlyConfig( + weight_dtype=getattr(torch, f"int{self.bit_width}"), + granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0), + mapping_type=mapping_type, + ), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + weight_qvals = embedding.weight.qdata + weight_scales = embedding.weight.scale + weight_zeros = embedding.weight.zero_point + + assert weight_zeros is not None + weight_scales = weight_scales.reshape(num_embeddings, -1) + weight_zeros = weight_zeros.reshape(num_embeddings, -1).to(torch.int8) + self.register_buffer( + "packed_weight_qvals", + getattr(torch.ops.torchao, f"_pack_embedding_{self.bit_width}bit")( + weight_qvals.to(torch.int8) + ), + ) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.register_buffer("weight_scales", weight_scales) + self.register_buffer("weight_zeros", weight_zeros) + + def forward(self, x): + shape = x.shape + return getattr(torch.ops.torchao, f"_embedding_{self.bit_width}bit")( + self.packed_weight_qvals, + self.num_embeddings, + self.embedding_dim, + self.weight_scales, + # embedding op requires weight_zeros be passed, even if they are all 0 + self.weight_zeros, + x.reshape(-1), + ).reshape(*shape, -1) + + +class QuantizedEmbeddingFallback(nn.Module): + def __init__( + self, + bit_width, + ): + super().__init__() + self.bit_width = bit_width + + def quantize_and_pack_weights(self, weights, group_size, mapping_type): + self.embedding = torch.nn.Embedding(*weights.shape) + self.embedding.weight = weights + quantize_( + self.embedding, + IntxWeightOnlyConfig( + weight_dtype=getattr(torch, f"int{self.bit_width}"), + granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0), + mapping_type=mapping_type, + ), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + def forward(self, x): + return self.embedding(x) + + +class QuantizedTiedEmbedding(nn.Module): + def __init__(self, bit_width, unembedding_packed_weights, group_size, n, k): + super().__init__() + self.bit_width = bit_width + self.register_buffer("unembedding_packed_weights", unembedding_packed_weights) + self.n = n + self.k = k + if group_size == -1: + self.group_size = k + else: + self.group_size = group_size + self.shared_embedding_op = getattr( + torch.ops.torchao, f"_shared_embedding_{bit_width}bit" + ) + + def forward(self, x): + shape = x.shape + return self.shared_embedding_op( + self.unembedding_packed_weights, + self.group_size, + self.n, + self.k, + x.reshape(-1), + ).reshape(*shape, -1) + + +def _replace_embedding_with_quantized_embedding( + module: nn.Module, + kwargs={}, + fqn: str = "", +): + group_size = kwargs.get("group_size", None) + bit_width = kwargs.get("bit_width", None) + use_fallback = kwargs.get("use_fallback", None) + mapping_type = kwargs.get("mapping_type", None) + + assert not isinstance(module, nn.Embedding) + for name, child in module.named_children(): + child_fqn = f"{fqn}.{name}" if fqn != "" else name + + if not isinstance(child, nn.Embedding): + _replace_embedding_with_quantized_embedding(child, kwargs, child_fqn) + else: + assert child.weight.device == torch.device("cpu"), "Only CPU is supported" + assert child.weight.dtype == torch.float32, "Only float32 is supported" + + if use_fallback: + qembedding = QuantizedEmbeddingFallback(bit_width) + setattr(module, name, qembedding) + getattr(module, name).quantize_and_pack_weights( + child.weight, + group_size, + mapping_type, + ) + else: + assert _is_kernel_library_loaded(), ( + "torchao kernel library is not loaded" + ) + qembedding = QuantizedEmbedding(bit_width) + setattr(module, name, qembedding) + getattr(module, name).quantize_and_pack_weights( + child.weight, + group_size, + mapping_type, + ) + + +class EmbeddingQuantizer: + def __init__( + self, + weight_dtype: torch.dtype = torch.int4, + granularity: Granularity = PerAxis(0), + mapping_type: MappingType = MappingType.ASYMMETRIC, + use_fallback: bool = False, + ): + assert weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)] + bit_width = _DTYPE_TO_BIT_WIDTH[weight_dtype] + + if isinstance(granularity, PerGroup): + group_size = granularity.group_size + elif isinstance(granularity, PerAxis): + assert granularity.axis == 0 + group_size = -1 + else: + raise ValueError(f"Unsupported granularity: {granularity}") + + self.bit_width = bit_width + self.group_size = group_size + self.use_fallback = use_fallback + self.mapping_type = mapping_type + + def quantize(self, model: nn.Module) -> nn.Module: + _replace_embedding_with_quantized_embedding( + model, + kwargs={ + "group_size": self.group_size, + "bit_width": self.bit_width, + "use_fallback": self.use_fallback, + "mapping_type": self.mapping_type, + }, + ) + return model + + +def _get_fqns_with_filter( + module: nn.Module, + filter_fn: Callable[Tuple[str, nn.Module], bool], + fqn: str, + fqns: List[str], +): + for name, child in module.named_children(): + child_fqn = f"{fqn}.{name}" if fqn != "" else name + if filter_fn(child, child_fqn): + fqns.append(child_fqn) + else: + _get_fqns_with_filter(child, filter_fn, child_fqn, fqns) + + +def get_fqns_with_filter( + module: nn.Module, filter_fn: Callable[Tuple[str, nn.Module], bool] +) -> List[str]: + fqns = [] + _get_fqns_with_filter(module, filter_fn, "", fqns) + return fqns + + +class QuantizedLinear(nn.Module): + def __init__(self, packed_weight, n, k, group_size, bit_width, bias): + super().__init__() + self.register_buffer("packed_weight", packed_weight) + self.n = n + self.k = k + self.group_size = group_size + self.bit_width = bit_width + self.bias = bias + + def _forward_2d(self, x): + assert x.dim() == 2 + m, k = x.shape + assert k == self.k + return getattr( + torch.ops.torchao, f"_linear_8bit_act_{self.bit_width}bit_weight" + )(x, self.packed_weight, self.group_size, self.n, self.k) + + def forward(self, x): + if x.dim() == 2: + res = self._forward_2d(x) + else: + assert x.dim() >= 3 + lead_shape = x.shape[0:-2] + m, k = x.shape[-2], x.shape[-1] + assert k == self.k + res = self._forward_2d(x.reshape(-1, k)) + res = res.reshape(*lead_shape, m, self.n) + + if self.bias is not None: + res = res + self.bias + return res + + +def get_parent_by_fqn(root: nn.Module, fqn: str): + parts = fqn.split(".") + if len(parts) == 1: + # e.g. "fqn" → parent is root, child is "fqn" + return root, parts[0] + + parent_fqn = ".".join(parts[:-1]) + child_name = parts[-1] + parent = dict(root.named_modules()).get(parent_fqn, None) + if parent is None: + raise KeyError(f"Parent module {parent_fqn} not found in model") + return parent, child_name + + +class TiedEmbeddingQuantizer: + def __init__( + self, + weight_dtype: torch.dtype = torch.int4, + granularity: Granularity = PerAxis(0), + mapping_type: MappingType = MappingType.ASYMMETRIC, + ): + self.weight_dtype = weight_dtype + self.granularity = granularity + self.mapping_type = mapping_type + + def quantize( + self, + model: nn.Module, + embedding_to_unembedding: Optional[Mapping[str, str]] = None, + ): + embedding_fqns = get_fqns_with_filter( + model, lambda m, fqn: isinstance(m, nn.Embedding) + ) + linear_fqns = get_fqns_with_filter( + model, lambda m, fqn: isinstance(m, nn.Linear) + ) + state_dict = model.state_dict() + + # If embedding_to_unembedding is not provided, automatically detect shared embeddings and unembeddings + if embedding_to_unembedding is None: + embedding_to_unembedding = {} + for embedding_fqn in embedding_fqns: + embedding_w = state_dict[embedding_fqn + ".weight"] + for linear_fqn in linear_fqns: + linear_w = state_dict[linear_fqn + ".weight"] + if embedding_w.shape == linear_w.shape and torch.allclose( + embedding_w, linear_w + ): + print( + f"Found shared embedding {embedding_fqn} and unembedding {linear_fqn}" + ) + if embedding_fqn not in embedding_to_unembedding: + embedding_to_unembedding[embedding_fqn] = linear_fqn + else: + raise ValueError( + f"Found multiple candidate unembeddings ({embedding_to_unembedding[embedding_fqn]}, {linear_fqn}) for embedding {embedding_fqn}. This is not supported yet. Please explicitly define the input embedding_to_unembedding." + ) + + # Construct reverse mapping + unembedding_to_embedding = {} + for v, k in embedding_to_unembedding.items(): + if k not in unembedding_to_embedding: + unembedding_to_embedding[k] = v + else: + raise ValueError( + f"Found multiple candidate embeddings ({unembedding_to_embedding[k]}, {v}) for unembedding {k}. This is not supported yet." + ) + + # Check that embeddings are shared, embeddings are embeddings, and unembeddings are linear ops + for embedding_fqn, unembedding_fqn in embedding_to_unembedding.items(): + assert embedding_fqn in embedding_fqns, ( + f"Embedding {embedding_fqn} is not found in model" + ) + assert unembedding_fqn in linear_fqns, ( + f"Unembedding {unembedding_fqn} is not found in model" + ) + assert torch.allclose( + state_dict[embedding_fqn + ".weight"], + state_dict[unembedding_fqn + ".weight"], + ), ( + f"Embedding {embedding_fqn} does not share weights with unembedding {unembedding_fqn}" + ) + + # Quantize unembeddings + quantize_( + model, + Int8DynamicActivationIntxWeightConfig( + weight_dtype=self.weight_dtype, + weight_granularity=self.granularity, + weight_mapping_type=self.mapping_type, + # Only universal layout is supported for shared embedding + intx_packing_format="opaque_torchao_lowbit", + ), + filter_fn=lambda m, fqn: isinstance(m, nn.Linear) + and fqn in list(embedding_to_unembedding.values()), + ) + + embedding_fqn_to_quantized_unembedding = {} + for fqn, t in model.state_dict().items(): + if ( + fqn.endswith(".weight") + and fqn[: -len(".weight")] in unembedding_to_embedding + ): + embedding_fqn = unembedding_to_embedding[fqn[: -len(".weight")]] + embedding_fqn_to_quantized_unembedding[embedding_fqn] = t + + for embedding_fqn, unembedding_fqn in embedding_to_unembedding.items(): + weight = embedding_fqn_to_quantized_unembedding[embedding_fqn] + n, k = weight.shape + group_size = weight.block_size[1] + packed_weight = weight.packed_weights + bit_width = weight.bit_width + + # Set embedding + parent, child_name = get_parent_by_fqn(model, embedding_fqn) + child = getattr(parent, child_name) + assert n == child.num_embeddings, ( + "num_embeddings must match n in shared_unembedding" + ) + assert k == child.embedding_dim, ( + "embedding_dim must match k in shared_unembedding" + ) + setattr( + parent, + child_name, + QuantizedTiedEmbedding( + bit_width, + packed_weight, + group_size, + n, + k, + ), + ) + + # Set unembedding + parent, child_name = get_parent_by_fqn(model, unembedding_fqn) + child = getattr(parent, child_name) + if weight.packed_weights_has_bias: + assert child.bias is None + setattr( + parent, + child_name, + QuantizedLinear(packed_weight, n, k, group_size, bit_width, child.bias), + ) diff --git a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py index c1272fceb6..f26083b90d 100644 --- a/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py +++ b/torchao/prototype/quantization/gguf/gguf_quantized_tensor.py @@ -14,10 +14,7 @@ _dequantize_gguf, _quantize_gguf, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor _QK_K = 256 aten = torch.ops.aten @@ -267,6 +264,5 @@ def _(func, types, args, kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([GGUFQuantizedTensor]) +# Allow a model with GGUFQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([GGUFQuantizedTensor]) diff --git a/torchao/prototype/quantization/int8_lut_tensor/__init__.py b/torchao/prototype/quantization/int8_lut_tensor/__init__.py new file mode 100644 index 0000000000..dd53868182 --- /dev/null +++ b/torchao/prototype/quantization/int8_lut_tensor/__init__.py @@ -0,0 +1,5 @@ +from .int8_lut_tensor import Int8LutTensor + +__all__ = [ + "Int8LutTensor", +] diff --git a/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py b/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py new file mode 100644 index 0000000000..a4feee13aa --- /dev/null +++ b/torchao/prototype/quantization/int8_lut_tensor/int8_lut_tensor.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional + +import torch + +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_BIT_WIDTH, + _DTYPE_TO_QVALUE_BOUNDS, +) +from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + _is_kernel_library_loaded, +) +from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import ( + IntxUnpackedToInt8Tensor, + IntxUnpackedToInt8TensorActivationQuantization, +) +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + + +class Int8LutTensor(TorchAOBaseTensor): + """ + Tensor subclass that does int8 dynamic activation quantization with lookup table quantization + """ + + tensor_data_names = ["packed_weights"] + tensor_attribute_names = [ + "bit_width", + "block_size", + "shape", + "dtype", + "packed_weights_has_bias", + ] + + def __new__( + cls, + packed_weights, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_bias, + ): + kwargs = {} + kwargs["device"] = packed_weights.device + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weights, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_bias, + ): + super().__init__() + assert packed_weights.device == torch.device("cpu") + self.packed_weights = packed_weights + self.bit_width = bit_width + self.block_size = block_size + self.packed_weights_has_bias = packed_weights_has_bias + + def _quantization_type(self): + return f"bit_width={self.bit_width}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}" + + def to(self, *args, **kwargs): + raise NotImplementedError("to() is not implemented for IntxOpaqueTensor") + + @classmethod + def _get_lut_params(cls, tensor: IntxUnpackedToInt8Tensor): + assert isinstance(tensor, IntxUnpackedToInt8Tensor) + assert tensor.target_dtype in [torch.int1, torch.int2, torch.int3, torch.int4] + + qdata = tensor.qdata + scale = tensor.scale + zero_point = tensor.zero_point + + if tensor._has_float_zero_point(): + # Stretched tensors from PARQ should have -0.5 has zero_point + assert torch.all(zero_point == -0.5) + is_stretched_tensor = True + else: + assert torch.all(zero_point == 0) + is_stretched_tensor = False + + quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[tensor.target_dtype] + lut_indices = qdata - quant_min + lut = torch.arange(quant_min, quant_max + 1) + + # Construct LUT as 2 * ([q_min, q_max] - 0.5) + if is_stretched_tensor: + lut = 2 * lut + 1 + scale = 0.5 * scale + + # LUT must be 2D and int8 + lut = lut.reshape(1, -1).to(torch.int8) + + # Scale must be 1D and float32 + scale = scale.reshape(-1).to(torch.float32) + + return lut, lut_indices, scale + + @classmethod + def from_intx_unpacked_to_int8_tensor( + cls, + tensor: IntxUnpackedToInt8Tensor, + *, + bias: Optional[torch.Tensor] = None, + ): + """ + Constructs a Int8LutTensor from an IntxUnpackedToInt8Tensor. + If bias is passed, bias is packed into the tensor. + """ + + assert _is_kernel_library_loaded(), "TorchAO kernel library is not loaded" + assert ( + tensor.activation_quantization + == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN + ), ( + "IntxUnpackedToInt8Tensor must have INT8_ASYM_PER_TOKEN activation quantization" + ) + + assert len(tensor.block_size) == 2 + assert tensor.block_size[0] == 1 + scale_group_size = tensor.block_size[1] + + packed_weights_has_bias = bias is not None + if packed_weights_has_bias: + n, k = tensor.shape + assert bias.shape == (n,) + bias = bias.to(torch.float32) + + lut, lut_indices, scale = cls._get_lut_params(tensor) + bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype] + packed_weights = getattr( + torch.ops.torchao, f"_pack_8bit_act_{bit_width}bit_weight_with_lut" + )( + lut_indices, + lut, + scale, + scale_group_size, + bias, + None, + ) + + block_size = [b for b in tensor.block_size] + shape = tensor.shape + bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype] + return cls( + packed_weights, + bit_width, + block_size, + shape, + tensor.dtype, + packed_weights_has_bias, + ) + + +implements = Int8LutTensor.implements + + +def _linear_impl_2d( + input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor +): + assert isinstance(weight_tensor, Int8LutTensor) + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + assert weight_tensor.block_size[0] == 1 + group_size = weight_tensor.block_size[1] + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + + packed_weights = weight_tensor.packed_weights + bit_width = weight_tensor.bit_width + + if weight_tensor.dtype != torch.float32: + input_tensor = input_tensor.to(torch.float32) + + res = getattr( + torch.ops.torchao, + f"_linear_8bit_act_{bit_width}bit_weight", + )( + input_tensor, + packed_weights, + group_size, + n, + k, + ) + if weight_tensor.dtype != torch.float32: + res = res.to(weight_tensor.dtype) + + return res + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + # TODO: why was this added https://github.com/pytorch/ao/pull/2043 + if input_tensor.numel() == 0: + return input_tensor + + if input_tensor.dim() == 1: + k = input_tensor.shape[0] + input_tensor = input_tensor.reshape(1, k) + res = _linear_impl_2d(input_tensor, weight_tensor, bias) + res = res.reshape(-1) + elif input_tensor.dim() == 2: + res = _linear_impl_2d(input_tensor, weight_tensor, bias) + else: + assert input_tensor.dim() >= 3 + lead_shape = input_tensor.shape[0:-2] + m, k = input_tensor.shape[-2], input_tensor.shape[-1] + n, k_ = weight_tensor.shape + assert k_ == k + res = _linear_impl_2d(input_tensor.reshape(-1, k), weight_tensor, bias) + res = res.reshape(*lead_shape, m, n) + + if bias is not None: + assert not weight_tensor.packed_weights_has_bias + res = res + bias + + return res + + +# Allow a model with Int8LutTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int8LutTensor]) diff --git a/torchao/prototype/quantization/mixed_precision/scripts/fit.py b/torchao/prototype/quantization/mixed_precision/scripts/fit.py index d8e6be4550..bf663cb1c4 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/fit.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/fit.py @@ -84,7 +84,7 @@ def main(max_seqlen, checkpoint, nsamples, max_iter, num_layers): # have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M model = transformers.AutoModelForCausalLM.from_pretrained( - checkpoint, torch_dtype=torch.bfloat16 + checkpoint, dtype=torch.bfloat16 ) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) model = model.to(device) diff --git a/torchao/prototype/quantization/mixed_precision/scripts/hessian_grad.py b/torchao/prototype/quantization/mixed_precision/scripts/hessian_grad.py index 1e7b403e3d..df811829a3 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/hessian_grad.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/hessian_grad.py @@ -130,7 +130,7 @@ def main(layer_id, checkpoint, max_seqlen, max_iter, nsamples): with sdpa_kernel(SDPBackend.MATH): # have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M model = transformers.AutoModelForCausalLM.from_pretrained( - checkpoint, torch_dtype=torch.bfloat16 + checkpoint, dtype=torch.bfloat16 ) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) model = model.cuda() diff --git a/torchao/prototype/quantization/mixed_precision/scripts/hessian_vhp.py b/torchao/prototype/quantization/mixed_precision/scripts/hessian_vhp.py index faf46b01eb..2d0a2fb735 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/hessian_vhp.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/hessian_vhp.py @@ -100,7 +100,7 @@ def f(*new_params): with sdpa_kernel(SDPBackend.MATH): # have been tested models Llama-3-8B, Llama-2-7B, Mistral-7B, and stories110M model = transformers.AutoModelForCausalLM.from_pretrained( - checkpoint, torch_dtype=torch.bfloat16 + checkpoint, dtype=torch.bfloat16 ) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) model = model.to(device) diff --git a/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py b/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py index 016b6c9eef..2174e7683a 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py @@ -101,11 +101,11 @@ def apply_intN_weight_only_quant_sym(weight): assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]" if n == 8: raise AssertionError( - "Someone needs to refactor this code to handle int8_weight_only again" + "Someone needs to refactor this code to handle Int8WeightOnlyConfig again" ) elif n == 4: raise AssertionError( - "Someone needs to refactor this code to handle int4_weight_only again" + "Someone needs to refactor this code to handle Int4WeightOnlyConfig again" ) else: if symmetric: diff --git a/torchao/prototype/quantization/mixed_precision/scripts/utils.py b/torchao/prototype/quantization/mixed_precision/scripts/utils.py index 5a47664200..b1e0cbca8f 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/utils.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/utils.py @@ -105,9 +105,9 @@ def cal_model_size(model, fqn_to_config): def load_model(repo_id, device): tokenizer = AutoTokenizer.from_pretrained(repo_id) - model = AutoModelForCausalLM.from_pretrained( - repo_id, torch_dtype=torch.bfloat16 - ).to(device=device) + model = AutoModelForCausalLM.from_pretrained(repo_id, dtype=torch.bfloat16).to( + device=device + ) return model, tokenizer diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 6b438ca787..1eaaacd1db 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -29,7 +29,7 @@ def quantize_int8_rowwise( probability of rounding up is equal to x - ⌊x⌋, which indicates how close the value is to the next integer value. Thus, stochastic rounding also approximates the floating point value exactly. - Currently this function differs from AQT's `int8_weight_only()` in the following way: + Currently this function differs from AQT's `Int8WeightOnlyConfig()` in the following way: 1. Precision: AQT keeps original dtype when doing quantization, while this function upcasts input to FP32 before quantization. Output scale maintains the original input dtype. 2. Calculate scale: AQT uses `input.abs().amax() / 127.5`, while `input.abs().amax() / 127` is diff --git a/torchao/prototype/safetensors/__init__.py b/torchao/prototype/safetensors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py new file mode 100644 index 0000000000..19943e4b4a --- /dev/null +++ b/torchao/prototype/safetensors/safetensors_support.py @@ -0,0 +1,167 @@ +import json +import logging +from typing import Any, Dict + +import torch + +from torchao.prototype.safetensors.safetensors_utils import ( + Float8TensorAttributeJSONEncoder, + object_from_dict, +) +from torchao.quantization import Float8Tensor + +logger: logging.Logger = logging.getLogger(__name__) + + +def unflatten_tensor_state_dict( + tensors_data_dict: Dict[str, Any], + metadata: Dict[str, Any], +): + """ + Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata dictionary + The naming of metadata is so that it is consistent with safetensors naming to avoid confusion + This function is used after loading in previously saved model state dict (using safetensors.save_file) to reconstruct tensor subclass structure + + For example, given a previously flattened tensors_data_dict and metadata: + tensors_data_dict = { + '0.weight:qdata': torch.Tensor(...), + '0.weight:scale': torch.Tensor(...), + '0.bias:_data': torch.Tensor(...), + } + metadata = { + '0.weight': { + '_type': 'Float8Tensor', + '_data': { + 'block_size': [1,32], + ... + } + } + '0.bias': { + '_type': 'torch.Tensor', + } + 'tensor_names': ['0.weight', '0.bias'] + } + + We recover the structure of the original state dict: + tensor_dict = { + '0.weight': Float8Tensor( + qdata=torch.Tensor(...), + scale=torch.Tensor(...), + block_size=[1,32], + ...), + '0.bias': torch.Tensor(...), + } + + Args: + tensors_data_dict: a dictionary from "tensor_name:tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance + metadata: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance + + Returns: + Dictionary of reconstructed tensor subclasses + """ + combined_data = {**tensors_data_dict, **metadata} + + if "tensor_names" not in metadata: + raise ValueError("No tensors found") + + tensor_names = json.loads(metadata["tensor_names"]) + result = {} + + for tensor_name in tensor_names: + tensor_tensors = {} + for key, value in combined_data.items(): + if key.startswith(f"{tensor_name}:"): + # Remove the prefix + tensor_tensors[key[len(tensor_name) + 1 :]] = value + + tensor_metadata = json.loads(metadata.get(tensor_name)) + tensor_type = tensor_metadata.get("_type") + + if tensor_type == Float8Tensor.__name__: + tensor_metadata["_data"].update(tensor_tensors) + result[tensor_name] = object_from_dict(tensor_metadata) + elif tensor_type == torch.Tensor.__name__: + result[tensor_name] = tensor_tensors["_data"] + else: + raise ValueError(f"Unsupported tensor type: {tensor_type}") + + return result + + +def flatten_tensor_state_dict( + tensors_dict: Dict[str, Dict[str, torch.Tensor]], +): + """ + Flattens a dictionary of tensor subclasses so that it is compatible with safetensors.save_file + We disconstruct tensor subclass structure into torch.Tensor data and metadata dictionary + The naming of metadata is so that it is consistent with safetensors naming to avoid confusion + + For example, given something like: + tensor_dict = { + '0.weight': Float8Tensor( + qdata=torch.Tensor(...), + scale=torch.Tensor(...), + block_size=[1,32], + ...), + '0.bias': torch.Tensor(...), + } + + We flatten this to: + tensors_data = { + '0.weight:qdata': torch.Tensor(...), + '0.weight:scale': torch.Tensor(...), + '0.bias:_data': torch.Tensor(...), + } + metadata = { + '0.weight': { + '_type': 'Float8Tensor', + '_data': { + 'block_size': [1,32], + ... + } + } + '0.bias': { + '_type': 'torch.Tensor', + } + 'tensor_names': ['0.weight', '0.bias'] + } + + Args: + tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names + + Returns: + A tuple of (tensors_data, metadata) where + tensors_data: Dict[str, torch.Tensor] contains the tensor data + metadata: Dict[str, str] contains accompanying metadata from tensor subclass + This structure is compatible with safetensors.save_file + """ + + metadata = {} + tensors_data_dict = {} + + for tensor_name, tensor in tensors_dict.items(): + if isinstance(tensor, Float8Tensor): + tensor_dict = {} + for tensor_data_name in tensor.tensor_data_names: + tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name) + + tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder) + elif type(tensor) is torch.Tensor: + tensor_dict = {"_data": tensor} + tensor_metadata = json.dumps({"_type": torch.Tensor.__name__}) + else: + raise ValueError(f"Unsupported tensor type: {type(tensor)}") + + # Clone tensors to avoid memory sharing issues + prefixed_tensors_dict = { + f"{tensor_name}:{key}": ( + value.detach().clone() if isinstance(value, torch.Tensor) else value + ) + for key, value in tensor_dict.items() + } + + metadata[tensor_name] = tensor_metadata + tensors_data_dict.update(prefixed_tensors_dict) + + metadata["tensor_names"] = json.dumps(list(tensors_dict.keys())) + return tensors_data_dict, metadata diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py new file mode 100644 index 0000000000..eb0258a505 --- /dev/null +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -0,0 +1,196 @@ +import dataclasses +import enum +import json +from typing import Any, Dict + +import torch + +import torchao +from torchao.quantization import Float8Tensor +from torchao.quantization.quantize_.common import KernelPreference +from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs + +ALLOWED_CLASSES = { + "Float8Tensor": Float8Tensor, + "Float8MMConfig": torchao.float8.inference.Float8MMConfig, + "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs, + "PerRow": torchao.quantization.PerRow, + "PerTensor": torchao.quantization.PerTensor, + "KernelPreference": KernelPreference, +} + +ALLOWED_TENSORS = ["Float8Tensor", "Tensor"] + +__all__ = [ + "Float8TensorAttributeJSONEncoder", + "object_from_dict", + "is_metadata_torchao", +] + + +class Float8TensorAttributeJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Float8Tensor): + tensor_attr_dict = {} + all_tensor_attributes = ( + o.optional_tensor_attribute_names + o.tensor_attribute_names + ) + + for tensor_attribute_name in all_tensor_attributes: + attribute = getattr(o, tensor_attribute_name) + encoded_attribute = self.encode_value(attribute) + tensor_attr_dict[tensor_attribute_name] = encoded_attribute + + return {"_type": o.__class__.__name__, "_data": tensor_attr_dict} + + if hasattr(o, "_fields") and hasattr( + o, "_asdict" + ): # Check for NamedTuple characteristics + asdict_data = o._asdict() + # Process each field to handle nested objects + processed_data = {k: self.encode_value(v) for k, v in asdict_data.items()} + + return { + "_type": o.__class__.__name__, + "_data": processed_data, + } + + if dataclasses.is_dataclass(o) and not isinstance(o, type): + data_dict = {} + # Process each field to handle nested objects + for f in dataclasses.fields(o): + data_dict[f.name] = self.encode_value(getattr(o, f.name)) + + return { + "_type": o.__class__.__name__, + "_data": data_dict, + } + + if isinstance(o, torch.dtype): + return {"_type": "torch.dtype", "_data": str(o).split(".")[-1]} + + if isinstance(o, enum.Enum): + # Store the full class name for enums to ensure uniqueness + return {"_type": f"{o.__class__.__name__}", "_data": o.name} + + if isinstance(o, list): + return [self.encode_value(item) for item in o] + + if isinstance(o, dict): + return {k: self.encode_value(v) for k, v in o.items()} + + # Default case + return super().default(o) + + def encode_value(self, value): + """Helper method to recursively encode a value""" + # Try to use default for custom type + try: + # This will handle all our special cases and raise TypeError + # if it can't handle the type + result = self.default(value) + return result + except TypeError: + pass + + # Default case - return as is + # (This will be processed by standard JSON encoder later) + return value + + +def object_from_dict(data: Dict[str, Any]): + if not isinstance(data, dict): + raise TypeError(f"Expected dictionary, got {type(data)}") + + if "_type" not in data or "_data" not in data: + raise ValueError("Input dictionary missing required '_type' or '_data' fields") + + type_path = data["_type"] + obj_data = data["_data"] + + if type_path == "torch.dtype": + return getattr(torch, obj_data) + + cls = ALLOWED_CLASSES.get(type_path) + + # If we couldn't find the class in any allowed module, raise an error + if cls is None: + allowed_modules_str = ", ".join(ALLOWED_CLASSES) + raise ValueError( + f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}" + ) + + # Handle the case where obj_data is not a dictionary + if not isinstance(obj_data, dict): + if issubclass(cls, enum.Enum): + # For enums, convert string to enum value + return getattr(cls, obj_data) + else: + # For other primitive types, create an instance with the value + try: + return cls(obj_data) + except: + return obj_data + + processed_data = {} + + for key, value in obj_data.items(): + if isinstance(value, dict) and "_type" in value and "_data" in value: + # Recursively handle nested configs + processed_data[key] = object_from_dict(value) + elif isinstance(value, list): + # Handle lists or tuples of possible configs + processed_data[key] = [ + object_from_dict(item) + if isinstance(item, dict) and "_type" in item and "_data" in item + else item + for item in value + ] + elif isinstance(value, tuple): + raise NotImplementedError( + "Tuples will be serialized as List in JSON, so we recommend to use " + f"Lists instead to avoid surprises. got: {value}" + ) + elif isinstance(value, dict): + # Handle dicts of possible configs + processed_data[key] = { + k: object_from_dict(v) + if isinstance(v, dict) and "_type" in v and "_data" in v + else v + for k, v in value.items() + } + else: + processed_data[key] = value + + # Create and return the instance + try: + return cls(**processed_data) + except Exception as e: + raise ValueError(f"Failed to create instance of {cls.__name__}: {e}") + + +def is_metadata_torchao(metadata: Dict[str, Any]): + if not metadata or "tensor_names" not in metadata: + return False + try: + all_tensor_names = json.loads(metadata["tensor_names"]) + except (TypeError, json.JSONDecodeError, UnicodeDecodeError): + return False + + if not all_tensor_names or not isinstance(all_tensor_names, list): + return False + + for tensor_name in all_tensor_names: + if tensor_name not in metadata or not isinstance(metadata[tensor_name], str): + return False + try: + tensor_dict = json.loads(metadata[tensor_name]) + except (TypeError, json.JSONDecodeError, UnicodeDecodeError): + return False + + # returns None if _type not in tensor_dict + tensor_type = tensor_dict.get("_type") + if tensor_type not in ALLOWED_TENSORS: + return False + + return True diff --git a/torchao/prototype/smoothquant/README.md b/torchao/prototype/smoothquant/README.md index c268a83504..00e819c438 100644 --- a/torchao/prototype/smoothquant/README.md +++ b/torchao/prototype/smoothquant/README.md @@ -1,98 +1,82 @@ -# SmothQuant quantization -This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438). +# SmoothQuant quantization -In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization. Activations are smoothed and quantized to int8 at runtime. Quantization is done either dynamically or statically. If activations are dynamically quantized, qparams (i.e., scales) are found at runtime while qparams are found during quantization for static quantization. For dynamic quantization, activations are quantized per token. And for static quantization, activations are quantized per tensor. Generally, dynamic quantization produces better accuracy while static quantization has better latency. In both cases, weights and activations are symmetrically quantized. +This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438) with TorchAO Quantization APIs. + +$$ +Smoothing factor: s_{j} = \frac{max(|X_{j})^\alpha}{max(|W_{j}|) ^(1-\alpha)}, \ j=1, 2, \dots, C_{i} +$$ + +In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization. Activations are smoothed and quantized to int8 at runtime. Quantization is done either dynamically or statically. For dynamic quantization, activations are quantized per token. And for static quantization, activations are quantized per tensor. ## Quick start + Run the example code with + ```bash -python example.py -m MODLE_ID --device= --quant-mode= +python example.py --model --device # An example -python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic -``` -To use the `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance. -```bash -TORCHINDUCTOR_FREEZING=1 python example.py -m MODLE_ID --device= --quant-mode= --compile +python example.py --model meta-llama/Llama-2-7b-chat-hf ``` -To save a quantized model for reuse, specify `--model-save-path` -```bash -python example.py -m MODLE_ID --device= --quant-mode= --model-save-path ./quantized_model.pt -``` -And load it by `--model-load-path` + +To save a quantized model for reuse, specify `--model_save_path` + ```bash -python example.py -m MODLE_ID --device= --quant-mode= --model-load-path ./quantized_model.pt +python example.py --model --model_save_path ./model_smoothquant.pt ``` - ## Usage of API -The following APIs are provided: -- insert_smooth_quant_observer_ -- SmoothQuantConfig -- save_smooth_quant_recipe (advanced) -- load_smooth_quant_recipe (advanced) -`insert_smooth_quant_observer_` inserts observers into the model to be quantized. For example: -```python -insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic") -``` -After insertion, run the model for calibration on a certain dataset or (advanced) load a recipe. +`SmoothQuantConfig` configures applying SmoothQuant to each linear layer of the model. Use it with `torchao.quantization.quantize_`. For example: -`SmoothQuantConfig` configures appliying SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example: ```python -from torchao.prototype.smoothquant import SmoothQuantObservedLinear -is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) -torchao.quantization.quantize_(model, SmoothQuantConfig(), is_observed_linear) -``` -`is_observed_linear` is a filter so that we only quantize observed linear layers. - -(Advanced) `save_smooth_quant_recipe` and `load_smooth_quant_recipe` saves or loads a recipe for a model. +from torchao.prototype.smoothquant import SmoothQuantConfig +from torchao.prototype.smoothquant.core import SmoothQuantStep +from torchao.quantization import quantize_ +from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig -A recipe contains smoothing factors and quantization parameters of weights and activation for all linear layers that are to be quantized. For advanced users, these parameters can be saved and modified somehow to produce better accuray, e.g., different alpha for different layers. Users can even leave some linear layers unquantized by deleting these layers in the recipe. Such modifications can be published as a recipe. By loading the recipe, it can be reused and calibration is no longer needed. +# Step 1: Prepare - insert observers +quant_config = SmoothQuantConfig( + base_config=Int8DynamicActivationInt8WeightConfig(), + step=SmoothQuantStep.PREPARE, + alpha=0.5, +) +quantize_(model, quant_config) -To save a recipe, users should insert observers and run calibration first. For example, -```python -insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic") -for data in dataset_for_calibration: +# Step 2: Calibration +for data in calibration_dataset: model(data) -save_smooth_quant_recipe(model, "./smooth_quant_recipe.json") -``` -To load a recipe, users should insert observers first. For example, -```python -insert_smooth_quant_observer_(model) -load_smooth_quant_recipe(model, "./smooth_quant_recipe.json") + +# Step 3: Convert +quant_config.step = SmoothQuantStep.CONVERT +quantize_(model, quant_config) ``` -## Benchmark -Running the example with `torch.compile` on a NVIDIA A10G GPU. -### meta-llama/Llama-2-7b-hf -Perplexity -| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* | -|-|-|-|-|-| -| Dynamic | 8.1872 | 7.4257 | 7.2518 | 7.5509 | -| Static | 43.8051 | 11.2984 | 7.5791 | 19.5050 | +## Benchmarks -Note*: Conventional quantization without SmoothQuant +All experiments use the `meta-llama/Llama-2-7b-chat-hf` model with max sequence length (SeqLen) 512 and calibration limit 128 on a 1xH100 80GB HBM2 instance. For comprehensive benchmarking, we compare three cases: 1. origin, 2. W8A8, 3. SmoothQuant (W8A8). -### meta-llama/Meta-Llama-3-8B -Perplexity -| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* | -|-|-|-|-|-| -| Dynamic | 21.2475 | 8.8288 | 9.6514 | 8.3574 | -| Static | 301.7118 | 18.0617 | 10.8343 | 278.9819 | +### Benchmark Results -Note*: Conventional quantization without SmoothQuant +Result shows SmoothQuant with W8A8 slightly increase perplexity, reducing latency 33.82%. Since tinygemm kernel only uses bfloat16 inputs, Tokens/sec decreases for float16 input. -### Test method -**Commands** -```bash -# dynamic quant -TORCHINDUCTOR_FREEZING=1 python example.py -m --device=cuda --quant-mode=dynamic --compile -# static quant -TORCHINDUCTOR_FREEZING=1 python example.py -m --device=cuda --quant-mode=static --compile -``` -Use `--alpha` to specify the alpha parameter. Add `--disable-smooth-quant` to run quantization without SmoothQuant. +| Precision dtype | Quantization | Perplexity | Tokens/sec | PPL Change | Speed Change | +|-----------|--------------|------------|------------|------------|--------------| +| bfloat16 | - | 6.93 | 667 | - | - | +| bfloat16* | - | 6.93 | 27 🐌 | - | - | +| bfloat16 | W8A8-dynamic | 7.35 | 1,967 | +6.07% | +33.89% | +| bfloat16 | W8A8-dynamic** | 7.03 | **1,972** | **+1.39%** | **+33.82%** | +| float16 | - | 6.93 | 625 | - | - | +| float16 | W8A8-dynamic | 7.29 | 523 | +5.21% | -19.42% | +| float16 | W8A8-dynamic** | 6.94 | 516 | **+0.21%** | -21.23% | +| bfloat16* | W8A8-dynamic** | 6.92 | 3 🐌 | -0.18% | -768.29% | + +> *Used with `torch.compile`, **Used with **SmoothQuant** + +### Key Findings + +- **Speed Improvement**: Most configurations show 35-40% speed improvement with both W8A8 and SmoothQuant-W8A8 +- **Quality Trade-off**: Slight perplexity increase (~1-1.4%) in most cases +- **Compilation Impact**: Using `--compile` flag significantly degrades performance (768% slower) +- **Best Configuration**: `bfloat16` without `--compile` provides optimal balance -**Environment** -- AWS g5.12xlarge instance -- torch==2.6.0.dev20241017+cu124 -- python==3.12.6 +> Note: Unlike AWQ, this benchmark isn't computed using the script in `vllm/benchmarks` or `lm_eval`. vLLM benchmark will be introduced in foreseeable future. See https://github.com/pytorch/ao/issues/2815 for more information. diff --git a/torchao/prototype/smoothquant/__init__.py b/torchao/prototype/smoothquant/__init__.py index 948a99c080..2ea8b5713a 100644 --- a/torchao/prototype/smoothquant/__init__.py +++ b/torchao/prototype/smoothquant/__init__.py @@ -1,15 +1,13 @@ -from .api import ( - SmoothQuantConfig, - insert_smooth_quant_observer_, - load_smooth_quant_recipe, - save_smooth_quant_recipe, +from .api import SmoothQuantConfig +from .core import ( + SmoothQuantObservedLinear, + SmoothQuantObserver, + SmoothQuantStep, ) -from .core import SmoothQuantObservedLinear __all__ = [ - "insert_smooth_quant_observer_", - "load_smooth_quant_recipe", - "save_smooth_quant_recipe", "SmoothQuantConfig", + "SmoothQuantStep", + "SmoothQuantObserver", "SmoothQuantObservedLinear", ] diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 9397b340b3..9f78c49fb8 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -5,227 +5,122 @@ # LICENSE file in the root directory of this source tree. import types from dataclasses import dataclass -from typing import Dict, Optional +from typing import Optional import torch -import torchao from torchao.core.config import AOBaseConfig -from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static -from torchao.prototype.smoothquant.core import ( - SmoothQuantObservedLinear, - SmoothQuantObserver, -) -from torchao.quantization import quantize_ -from torchao.quantization.linear_activation_quantized_tensor import ( - to_linear_activation_quantized, -) from torchao.quantization.linear_activation_scale import ( to_weight_tensor_with_linear_activation_scale_metadata, ) from torchao.quantization.quant_api import ( + _QUANTIZE_CONFIG_HANDLER, _linear_extra_repr, - _replace_with_custom_fn_if_matches_filter, ) -from torchao.quantization.quant_primitives import MappingType from torchao.quantization.transform_module import ( register_quantize_module_handler, ) -from torchao.quantization.utils import _get_per_token_block_size -from torchao.quantization.weight_tensor_linear_activation_quantization import ( - to_weight_tensor_with_linear_activation_quantization_metadata, -) - - -def insert_smooth_quant_observer_( - model: torch.nn.Module, alpha: Optional[float] = 0.5, quant_mode: str = "dynamic" -): - """ - Inserts SmoothQuantObserver into Linear layers of a given model. - - Args: - model: The model to be modified (in place). Ensure model is on the desired device for calibration - alpha: The alpha value to determine smoothing factor. Factor = 1 if alpha is None, which means - falling back to conventional quantization. - quant_mode: dynamic or static quantization of activation - """ - _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) - - quant_min, quant_max = -127, 127 - eps = torch.finfo(torch.float32).eps - - def replace_with_observer(layer): - # creates observer and replaces linear layers with observed linear layers - observer = SmoothQuantObserver( - layer.weight, - alpha, - quant_mode, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - ) - return SmoothQuantObservedLinear.from_float(layer, observer) - - _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear) - - -def save_smooth_quant_recipe( - model: torch.nn.Module, save_path: str -) -> Dict[str, torch.Tensor]: - """ - Save smoothing_factors, act_scales, and wei_scales for each SmoothQuantObservedLinear layer in the model. - """ - result = {} - - def recurse(module: torch.nn.Module, name: str = ""): - for child_name, child in module.named_children(): - full_name = f"{name}.{child_name}" if name else child_name - - # Apply the analysis function to this layer - if isinstance(child, SmoothQuantObservedLinear): - smoothing_factor, act_scales, wei_scales = child.obs.calculate_qparams() - result[full_name + ".smoothing_factor"] = smoothing_factor - result[full_name + ".act_scales"] = act_scales - result[full_name + ".wei_scales"] = wei_scales - - # Recurse into child modules - recurse(child, full_name) - - recurse(model) - - torch.save(result, save_path) - - -def load_smooth_quant_recipe( - model: torch.nn.Module, recipe_path: str, device=None -) -> torch.nn.Module: - recipe = torch.load(recipe_path, weights_only=True) - - def recurse(module: torch.nn.Module, name: str = ""): - if isinstance(module, SmoothQuantObservedLinear): - smoothing_factor = recipe.get(name + ".smoothing_factor", None) - act_scales = recipe.get(name + ".act_scales", None) - wei_scales = recipe.get(name + ".wei_scales", None) - if device is not None: - module.to(device=device) - # act_scales is None for dynamic quantization - if any(x is None for x in (smoothing_factor, wei_scales)): - return module - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - wrapper = torch.nn.Sequential(module) - quantize_( - wrapper, - SmoothQuantConfig(smoothing_factor, act_scales, wei_scales), - is_observed_linear, - ) - return wrapper[0] - - mod_new = module - - for child_name, child in module.named_children(): - full_name = f"{name}.{child_name}" if name else child_name - setattr(mod_new, child_name, recurse(child, full_name)) - return mod_new - - recurse(model) - - -class _ActQuantizer: - def __init__(self, target_dtype, quant_min=-127): - self.target_dtype = target_dtype - self.quant_min = quant_min - - def dynamic_quantize(self, input): - return to_affine_quantized_intx( - input, - MappingType.SYMMETRIC, - _get_per_token_block_size(input), - self.target_dtype, - self.quant_min, - ) +from torchao.utils import DummyModule - def static_quantize(self, input, scale, zero_point): - return to_affine_quantized_intx_static( - input, - scale, - zero_point, - list(input.shape), - self.target_dtype, - self.quant_min, - ) +from .core import ( + SmoothQuantObservedLinear, + SmoothQuantObserver, + SmoothQuantStep, +) @dataclass class SmoothQuantConfig(AOBaseConfig): """ - Configuration for quantizing linear layers when passed into quantize_() + Configuration for SmoothQuant quantization when passed into quantize_() Args: - smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None. - act_scales: The activation scales for the layer. Acquired from the layer's observer if None. - wei_scales: The weight scales for the layer. Acquired from the layer's observer if None. - set_inductor_config: if True, adjusts `torchinductor` settings to recommended values. + base_config: Base quantization configuration that SmoothQuant is applied on top of + step (SmoothQuantStep): The step for SmoothQuant process + PREPARE: insert SmoothQuant Observers to linear layers + CONVERT: convert the observed linear modules to quantized modules + PREPARE_FOR_LOADING: convert the floating point model to a dummy smoothquant quantized model, so we can + load the quantized weights through copy_ later + alpha: The alpha value to determine smoothing factor. Factor = 1 if alpha is None, which means + Fall back to conventional quantization if None """ - smoothing_factor: Optional[torch.Tensor] = None - act_scales: Optional[torch.Tensor] = None - wei_scales: Optional[torch.Tensor] = None - set_inductor_config: bool = True + base_config: AOBaseConfig + step: SmoothQuantStep + alpha: Optional[float] = 0.5 + + def __post_init__(self): + self.step = self.step.lower() if isinstance(self.step, str) else self.step.value + all_step_values = [s.value for s in SmoothQuantStep] + if self.step not in all_step_values: + raise ValueError(f"{self.step} is not one of {all_step_values}") @register_quantize_module_handler(SmoothQuantConfig) def _smooth_quant_transform( module: torch.nn.Module, config: SmoothQuantConfig, -): - smoothing_factor = config.smoothing_factor - act_scales = config.act_scales - wei_scales = config.wei_scales - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - observed_linear = module - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - observed_linear.bias is not None, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.bias = observed_linear.bias +) -> torch.nn.Module: + step = config.step + base_config = config.base_config - target_dtype = torch.int8 - # act_scales is None for dynamic quantization thus not checked - if any(x is None for x in (smoothing_factor, wei_scales)): - factor, x_scale, w_scales = observed_linear.obs.calculate_qparams() - weight = observed_linear.obs.weight * factor - else: - factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales - weight = observed_linear.weight * factor - weight = weight.to(observed_linear.weight.dtype) - block_size = (1, weight.size(1)) - wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64) - qw = to_affine_quantized_intx_static( - weight, - w_scales, - wei_zero_points, - block_size, - target_dtype, - ) + if step == SmoothQuantStep.PREPARE: + observer = SmoothQuantObserver( + weight=module.weight, + alpha=config.alpha, + ) + return SmoothQuantObservedLinear.from_float(module, observer) - if x_scale is None: - # dynamic quant - qw = to_linear_activation_quantized( - qw, _ActQuantizer(target_dtype).dynamic_quantize + if step == SmoothQuantStep.PREPARE_FOR_LOADING: + # loading from pre-quantized checkpoint + observer = SmoothQuantObserver( + weight=module.weight, + alpha=config.alpha, ) + observed_linear = SmoothQuantObservedLinear.from_float(module, observer) + example_input = torch.randn( + (1, module.weight.shape[1]), + device=module.weight.device, + dtype=module.weight.dtype, + ) + observed_linear(example_input) + + elif step == SmoothQuantStep.CONVERT: + if not isinstance(module, SmoothQuantObservedLinear): + print( + f"convert: module is not SmoothQuantObservedLinear, skipping: {type(module)}" + ) + return module + observed_linear = module else: - # static quant - x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64) - qw = to_weight_tensor_with_linear_activation_quantization_metadata( - qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point + raise ValueError(f"Unexpected step: {step}") + + # Compute smoothed weight parameters + smoothing_factor = observed_linear.obs.calculate_qparams() + weight = observed_linear.weight * smoothing_factor + + # Create new linear layer + with torch.device("meta"): + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + observed_linear.bias is not None, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, ) + linear.bias = observed_linear.bias - qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype)) + # Quantize weights + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)] + dummy_mod = DummyModule(weight) + quant_mod = base_config_handler(dummy_mod, base_config) + qw = quant_mod.weight + + # Add smoothing factor metadata + qw = to_weight_tensor_with_linear_activation_scale_metadata( + qw, smoothing_factor.to(qw.dtype) + ) linear.weight = torch.nn.Parameter(qw, requires_grad=False) - linear.extra_repr = types.MethodType(_linear_extra_repr, module) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) + return linear diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 3e6c6ea5d5..83f1e78275 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -3,15 +3,17 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from enum import Enum from typing import Optional import torch import torch.nn.functional as F -from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerAxis -from torchao.quantization.quant_primitives import ( - MappingType, -) + +class SmoothQuantStep(str, Enum): + PREPARE = "prepare" + CONVERT = "convert" + PREPARE_FOR_LOADING = "prepare_for_loading" class SmoothQuantObserver(torch.nn.Module): @@ -19,113 +21,48 @@ def __init__( self, weight: torch.Tensor, alpha: Optional[float] = 0.5, - quant_mode: str = "static", # or dynamic - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, ): """ - A custom observer for SmoothQuant + A custom observer for smoothing factor, main concept of SmoothQuant. Args: weight: The weight tensor to be observed. alpha: The alpha value to determine smoothing factor, normally between 0 and 1. - Fall back to conventional quantization if alpha is None. - quant_mode: The mode of activation quantization, either static or dynamic - quant_min: The minimum quantized value - quant_max: The maximum quantized value - eps: The minimum scale to avoid dividing by zero. """ super().__init__() assert weight.ndim == 2 self.weight = weight - self.inputs = [] - self.device = self.weight.device self.alpha = alpha - assert quant_mode in ["static", "dynamic"] - self.quant_mode = quant_mode - self.quant_min = quant_min - self.quant_max = quant_max - self.eps = eps - # act.shape = [mb, ic] (reshape if needed), wei.shape = [oc, ic] - # *_ic_obs are used to determine smoothing_factor - # wei_oc_obs is used to find qparams for quantization - self.act_ic_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(-1), - eps=eps, - ) - self.wei_ic_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(-1), - eps=eps, - ) - self.wei_oc_obs = AffineQuantizedMinMaxObserver( - MappingType.SYMMETRIC, - torch.int8, - PerAxis(0), - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - ) - self.wei_ic_obs(self.weight) + self.inputs = [] + self.device = weight.device @torch.no_grad() def forward(self, input: torch.Tensor): - self.act_ic_obs(input.to("cpu")) + self.inputs.append(input.to("cpu")) return input def calculate_qparams(self): - # 1 Get min/max per IC from observers - wei_min_per_ic = self.wei_ic_obs.min_val - wei_max_per_ic = self.wei_ic_obs.max_val - act_min_per_ic = self.act_ic_obs.min_val - act_max_per_ic = self.act_ic_obs.max_val - x_abs_max_per_ic = ( - torch.max(torch.abs(act_min_per_ic), torch.abs(act_max_per_ic)) + self.eps - ) - w_abs_max_per_ic = ( - torch.max(torch.abs(wei_min_per_ic), torch.abs(wei_max_per_ic)) + self.eps + assert self.inputs and len(self.inputs) > 0, ( + "calibrate observer first by running model on exemplar data" ) - # 2 calculate the smoothing factor + inputs = [inp.to(self.device) for inp in self.inputs] + acc = torch.cat(inputs, dim=0) + # Reshape if needed: [batch, seq, features] -> [batch*seq, features] + if acc.ndim > 2: + acc = acc.view(-1, acc.shape[-1]) + + # Calculate per-channel max values + x_abs_max = torch.max(torch.abs(acc), dim=0)[0] + w_abs_max = torch.max(torch.abs(self.weight), dim=0)[0] + + # Calculate smoothing factor if self.alpha is None: - # fall back to conventional quantization if alpha is None - smoothing_factor = torch.ones_like( - x_abs_max_per_ic, - dtype=x_abs_max_per_ic.dtype, - device=x_abs_max_per_ic.device, - ) - else: - smoothing_factor = torch.pow(x_abs_max_per_ic, self.alpha) / torch.pow( - w_abs_max_per_ic.to(x_abs_max_per_ic.device), 1 - self.alpha - ) - # 3 apply smoothing factor to activations and find scales for static quantization - act_scales = None - if self.quant_mode == "static": - act_min_per_ic_new = act_min_per_ic / smoothing_factor.reshape( - act_min_per_ic.shape - ) - act_max_per_ic_new = act_max_per_ic / smoothing_factor.reshape( - act_max_per_ic.shape - ) - min_val_per_tensor = torch.min(act_min_per_ic_new) - max_val_per_tensor = torch.max(act_max_per_ic_new) - min_val_neg = torch.min( - min_val_per_tensor, torch.zeros_like(min_val_per_tensor) - ) - max_val_pos = torch.max( - max_val_per_tensor, torch.zeros_like(max_val_per_tensor) - ) - max_val_pos = torch.max(-min_val_neg, max_val_pos) - act_scale = max_val_pos / (float(self.quant_max - self.quant_min) / 2) - act_scales = act_scale.to(self.device) - # 4 update weight and find scales - self.wei_oc_obs(self.weight * smoothing_factor.to(self.device)) - wei_scales, _ = self.wei_oc_obs.calculate_qparams() - # 5 return results - return smoothing_factor.to(self.device), act_scales, wei_scales.to(self.device) + return torch.ones_like(x_abs_max) + + eps = torch.finfo(torch.float32).eps + return torch.pow(x_abs_max + eps, self.alpha) / torch.pow( + w_abs_max + eps, 1 - self.alpha + ) class SmoothQuantObservedLinear(torch.nn.Linear): @@ -133,30 +70,31 @@ def __init__( self, in_features: int, out_features: int, - bias: bool, obs: SmoothQuantObserver, + is_bias: bool = False, device=None, dtype=None, ): - super().__init__(in_features, out_features, bias, device, dtype) - assert isinstance(obs, SmoothQuantObserver) + super().__init__( + in_features, out_features, bias=is_bias, device=device, dtype=dtype + ) self.obs = obs def forward(self, input: torch.Tensor): input = self.obs(input) - output = F.linear(input, self.weight, self.bias) - return output + return F.linear(input, self.weight) @classmethod def from_float(cls, float_linear: torch.nn.Linear, obs: SmoothQuantObserver): - observed_linear = cls( - float_linear.in_features, - float_linear.out_features, - float_linear.bias is not None, - obs, - device=float_linear.weight.device, - dtype=float_linear.weight.dtype, - ) + with torch.device("meta"): + observed_linear = cls( + float_linear.in_features, + float_linear.out_features, + obs, + is_bias=float_linear.bias is not None, + device=float_linear.weight.device, + dtype=float_linear.weight.dtype, + ) observed_linear.weight = float_linear.weight observed_linear.bias = float_linear.bias return observed_linear diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index de1e4ed93e..8602b57e20 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -4,185 +4,263 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import argparse -import os import time -from typing import Optional import torch from datasets import load_dataset -from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig +from torchao.prototype.awq.example import get_calib_dataset from torchao.prototype.smoothquant import ( SmoothQuantConfig, - SmoothQuantObservedLinear, - insert_smooth_quant_observer_, ) +from torchao.prototype.smoothquant.core import SmoothQuantStep from torchao.quantization import quantize_ +from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig -def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation") - samples = [] - n_tokens = n_samples * block_size - n_run = n_tokens - for data in dataset: - line = data["text"] - line = line.strip() - line_encoded = tokenizer.encode(line) - if len(line_encoded) > 512: - continue - sample = torch.tensor([line_encoded]) - if sample.numel() == 0: - continue - samples.append(sample) - n_run -= len(line_encoded) - if n_run <= n_samples: - break - - cat_samples = torch.cat(samples, dim=1) - return [ - cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples) - ] - - -def wiki2_eval( - model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda" -): - model.eval() - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "right" - tokenizer.add_eos_token = False - - print("Loading dataset") - t0 = time.time() +# TODO: Build benchmark within vLLM ecosystem with more quantization APIs +# See https://github.com/pytorch/ao/issues/2815 for more details +def benchmark(model, tokenizer, max_seq_length=512, tasks=["PPL"], device="cuda"): + """Benchmark model with perplexity calculation on WikiText-2""" + # Load WikiText-2 test set dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt") - print(f"Time to load dataset: {time.time() - t0:.02f} seconds") - - encodings["input_ids"] = encodings["input_ids"].to(device) - - print("Running evaluation") - lls, t = [], [] - for i in tqdm( - range(0, encodings["input_ids"].size(1), stride), disable=not verbose - ): - begin_loc = max(i + stride - sequence_length, 0) - end_loc = min(i + stride, encodings["input_ids"].size(1)) - trg_len = end_loc - i - input_ids = encodings["input_ids"][:, begin_loc:end_loc] - target_ids = input_ids.clone() - target_ids[:, :-trg_len] = -100 # ignore context - - t1 = time.time() - with torch.no_grad(): - log_likelihood = model(input_ids, labels=target_ids).loss * trg_len - if device == "cuda": - torch.cuda.synchronize() - t2 = time.time() - t.append((t2 - t1)) - lls.append(log_likelihood) - - del input_ids, target_ids - - ppl = float(torch.exp(torch.stack(lls).sum() / end_loc)) - pred_time = sum(t) / len(t) - if verbose: - print("perplexity", ppl) - print("time", str(pred_time) + " sec/it") - - return {"perplexity": ppl, "prediction_time": pred_time} - - -def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): + + # Prepare text data and truncate if necessary + text = "\n\n".join(dataset["text"]) + # Get model's maximum sequence length + model_max_length = getattr(tokenizer, "model_max_length", max_seq_length) + if model_max_length > 1000000: # Default large value, use our max_seq_length + model_max_length = max_seq_length + + encodings = tokenizer( + text, return_tensors="pt", truncation=True, max_length=model_max_length + ) + + # Calculate perplexity model.eval() - model.config.use_cache = False - if tasks is None: - tasks = ["PPL"] - results = {} - if "PPL" in tasks: - results["perplexity"] = wiki2_eval( - model, tokenizer, 512, verbose=True, device=device - ) - return results - - -def wikitext2_ppl( + nlls = [] + + with torch.no_grad(): + seq_len = encodings.input_ids.size(1) + prev_end_loc = 0 + + for begin_loc in range(0, seq_len, max_seq_length): + end_loc = min(begin_loc + max_seq_length, seq_len) + trg_len = end_loc - prev_end_loc + + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + # Measure inference time + start_time = time.time() + outputs = model(input_ids, labels=target_ids) + inference_time = time.time() - start_time + + neg_log_likelihood = outputs.loss * trg_len + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + + ppl = torch.exp(torch.stack(nlls).sum() / end_loc) + + return { + "perplexity": ppl.item(), + "tokens_per_sec": input_ids.size(1) / inference_time, + } + + +def quantize_and_eval( model_id: str, - alpha: Optional[float], - quant_mode: str, - calibration_size: int, + alpha: float, + tasks: list[str], + max_seq_length: int, + calibration_limit: int, device: str, - precision: torch.dtype, - sequence_length: int, - compile: bool, - model_load_path: str, model_save_path: str, + model_save_hf_hub_path: str, ): print(f"Loading model on {device}...") torch.manual_seed(34) t0 = time.time() tokenizer = AutoTokenizer.from_pretrained(model_id) - if model_load_path is not None and os.path.exists(model_load_path): - print(f"Loading quantized model from {model_load_path}") - t0 = time.time() - model = torch.load(model_load_path, weights_only=False).to(device) - print(f"Time to load quantized model: {time.time() - t0:.02f} seconds") - else: - model = ( - AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=precision) - .eval() - .to(device) - ) - print(f"Time to load model: {time.time() - t0:.02f} seconds") - print("running calibration") - t0 = time.time() - # insert observers to find average magnitude and calculate scales - insert_smooth_quant_observer_(model, alpha, quant_mode) - calibration_data = get_calib_dataset( - tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length - ) - for batch in calibration_data: - model(batch.to(device)) - batch.to("cpu") - print(f"time for calibration: {time.time() - t0:.02f} seconds") - - is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - print(f"running SmoothQuant with {quant_mode} quantization") - t0 = time.time() - quantize_(model, SmoothQuantConfig(), is_observed_linear) - print(f"time for quantization: {time.time() - t0:.02f} seconds") - if model_save_path is not None: - print(f"Saving quantized model to {model_save_path}") - t0 = time.time() - torch.save(model, model_save_path) - print(f"Time to save quantized model: {time.time() - t0:.02f} seconds") - if compile: - model = torch.compile(model, dynamic=True) - - return benchmark(model, tokenizer, sequence_length, tasks=["PPL"], device=device) + model = ( + AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16) + .eval() + .to(device) + ) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + # Step 1: Prepare - insert observers + print("running SmoothQuant prepare and calibrate") + t0 = time.time() + quant_config = SmoothQuantConfig( + base_config=Int8DynamicActivationInt8WeightConfig(), + step=SmoothQuantStep.PREPARE, + alpha=alpha, + ) + quantize_(model, quant_config) -if __name__ == "__main__": + # Step 2: Calibration + calibration_data = get_calib_dataset( + tokenizer=tokenizer, n_samples=calibration_limit, block_size=max_seq_length + ) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + + print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds") + + # Step 3: Convert to quantized model + print("running SmoothQuant convert") + t0 = time.time() + quant_config.step = SmoothQuantStep.CONVERT + quantize_(model, quant_config) + print(f"time for convert: {time.time() - t0:.02f} seconds") + + # Set up config for loading + quant_config.step = SmoothQuantStep.PREPARE_FOR_LOADING + model.config.quantization_config = TorchAoConfig(quant_config) + + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + + if model_save_hf_hub_path is not None: + print("pushing model to hub:", model_save_hf_hub_path) + model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) + tokenizer.push_to_hub(model_save_hf_hub_path) + + print("Benchmarking SmoothQuant model...") + return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device) + + +def compare_models( + model_id: str, + alpha: float, + tasks: list[str], + max_seq_length: int, + calibration_limit: int, + device: str, + model_save_path: str, + model_save_hf_hub_path: str, +): + """Compare perplexity and speed for behchmarking SmoothQuant""" + + # Case 1: Base model without quantization + print("Benchmarking base model...") + torch.manual_seed(34) + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = ( + AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16) + .eval() + .to(device) + ) + base_results = benchmark( + model, tokenizer, max_seq_length, tasks=tasks, device=device + ) + + # Case 2: W8A8-dynamic without SmoothQuant + print("Benchmarking W8A8-dynamic without SmoothQuant...") + torch.manual_seed(34) + w8a8_model = ( + AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16) + .eval() + .to(device) + ) + quantize_(w8a8_model, Int8DynamicActivationInt8WeightConfig()) + w8a8_results = benchmark( + w8a8_model, tokenizer, max_seq_length, tasks=tasks, device=device + ) + + # Case 3: SmoothQuant + W8A8-dynamic + print("Benchmarking SmoothQuant with W8A8-dynamic...") + smoothquant_results = quantize_and_eval( + model_id, + alpha, + tasks, + max_seq_length, + calibration_limit, + device, + model_save_path, + model_save_hf_hub_path, + ) + + # Calculate changes and display results + w8a8_ppl_change = ( + (w8a8_results["perplexity"] - base_results["perplexity"]) + / base_results["perplexity"] + * 100 + ) + w8a8_speed_change = ( + (w8a8_results["tokens_per_sec"] - base_results["tokens_per_sec"]) + / base_results["tokens_per_sec"] + * 100 + ) + + smoothquant_ppl_change = ( + (smoothquant_results["perplexity"] - base_results["perplexity"]) + / base_results["perplexity"] + * 100 + ) + smoothquant_speed_change = ( + (smoothquant_results["tokens_per_sec"] - base_results["tokens_per_sec"]) + / base_results["tokens_per_sec"] + * 100 + ) + + # Print results + print( + f"\nBase: PPL={base_results['perplexity']:.2f}, Speed={base_results['tokens_per_sec']:.2f} tokens/sec" + ) + print( + f"w8a8-Dynamic: PPL={w8a8_results['perplexity']:.2f}, Speed={w8a8_results['tokens_per_sec']:.2f} tokens/sec" + ) + print( + f"SmoothQuant+w8a8: PPL={smoothquant_results['perplexity']:.2f}, Speed={smoothquant_results['tokens_per_sec']:.2f} tokens/sec" + ) + print(f"w8a8 Changes: PPL {w8a8_ppl_change:+.2f}%, Speed {w8a8_speed_change:+.2f}%") + print( + f"SmoothQuant Changes: PPL {smoothquant_ppl_change:+.2f}%, Speed {smoothquant_speed_change:+.2f}%" + ) + + return { + "base_model": base_results, + "w8a8_model": w8a8_results, + "smoothquant_model": smoothquant_results, + "w8a8_ppl_change_percent": w8a8_ppl_change, + "w8a8_speed_improvement_percent": w8a8_speed_change, + "smoothquant_ppl_change_percent": smoothquant_ppl_change, + "smoothquant_speed_improvement_percent": smoothquant_speed_change, + } + + +def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( - description="Evaluate a model with the specified parameters." + description="Evaluate a model with SmoothQuant quantization." ) - # Optional arguments with default values parser.add_argument( - "--model-id", "-m", type=str, help="Repository ID of the model." + "--model", type=str, required=True, help="Model ID from Huggingface hub." ) parser.add_argument( "--alpha", type=float, default=0.5, - help="The alpha hyperparameter for SmoothQuant.", + help="The alpha hyperparameter for SmoothQuant. Default is 0.5.", ) parser.add_argument( - "--quant-mode", type=str, help="Quantization mode, either static or dynamic." + "--tasks", + nargs="+", + type=str, + help="Task to benchmark model on.", + default=["PPL"], ) parser.add_argument( - "--calibration-samples", + "--calibration_limit", type=int, default=10, help="Number of samples to use for calibration. Default is 10.", @@ -194,54 +272,38 @@ def wikitext2_ppl( help="Device to run the evaluation on. Default is 'cuda'.", ) parser.add_argument( - "--precision", - type=str, - default="bfloat16", - help="Precision type. Default is 'bfloat16'.", - ) - parser.add_argument( - "--seq_len", + "--max_seq_length", type=int, default=512, - help="Length of examples to calibrate and evaluate model on. Default is 512", + help="Maximum sequence length. Default is 512", ) parser.add_argument( - "--compile", - action="store_true", - help="Flag to indicate if compilation is required.", - ) - parser.add_argument( - "--model-load-path", + "--model_save_path", type=str, default=None, - help="Path to load quantized model. If this is provided, " - "the model will be loaded from this path instead of quantizing the model.", + help="Path to store the quantized model.", ) parser.add_argument( - "--model-save-path", + "--model_save_hf_hub_path", type=str, default=None, - help="Path to store quantized model.", - ) - parser.add_argument( - "--disable-smooth-quant", - action="store_true", - help="Run conventional dynamic or static quantization for testing or debugging.", + help="Huggingface hub path to store the quantized model and tokenizer.", ) + return parser + + +if __name__ == "__main__": + parser = create_parser() args = parser.parse_args() - # Convert precision argument to torch dtype - precision_dtype = getattr(torch, args.precision, torch.bfloat16) - ppl = wikitext2_ppl( - args.model_id, - None if args.disable_smooth_quant else args.alpha, - args.quant_mode, - args.calibration_samples, + result = compare_models( + args.model, + args.alpha, + args.tasks, + args.max_seq_length, + args.calibration_limit, args.device, - args.precision, - args.seq_len, - args.compile, - args.model_load_path, args.model_save_path, + args.model_save_hf_hub_path, ) diff --git a/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py b/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py index c61a00b8e1..df9ed7cf5e 100644 --- a/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py +++ b/torchao/prototype/sparsity/pruner/lstm_saliency_pruner.py @@ -43,7 +43,7 @@ def update_mask(self, module, tensor_name, **kwargs): ) # take norm over all but first dim dims = tuple(range(1, weights.dim())) - saliency = weights.norm(dim=dims, p=1) + saliency = torch.linalg.vector_norm(weights, dim=dims, ord=1) # handle weights in 4 groups split_size = len(mask) // 4 diff --git a/torchao/prototype/sparsity/pruner/saliency_pruner.py b/torchao/prototype/sparsity/pruner/saliency_pruner.py index 0c0af152fa..4619773313 100644 --- a/torchao/prototype/sparsity/pruner/saliency_pruner.py +++ b/torchao/prototype/sparsity/pruner/saliency_pruner.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import torch + from .base_structured_sparsifier import BaseStructuredSparsifier @@ -11,7 +13,7 @@ class SaliencyPruner(BaseStructuredSparsifier): Prune rows based on the saliency (L1 norm) of each row. This pruner works on N-Dimensional weight tensors. - For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. + For each row, we will calculate the saliency, which is the sum the L1 norm of all weights in that row. We expect that the resulting saliency vector has the same shape as our mask. We then pick elements to remove until we reach the target sparsity_level. """ @@ -26,7 +28,9 @@ def update_mask(self, module, tensor_name, **kwargs): raise Exception( "Structured pruning can only be applied to a 2+dim weight tensor!" ) - saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + saliency = -torch.linalg.vector_norm( + weights, dim=tuple(range(1, weights.dim())), ord=1 + ) assert saliency.shape == mask.shape num_to_pick = int(len(mask) * kwargs["sparsity_level"]) diff --git a/torchao/prototype/spinquant/hadamard_utils.py b/torchao/prototype/spinquant/hadamard_utils.py index e1c779c563..1a88664c79 100644 --- a/torchao/prototype/spinquant/hadamard_utils.py +++ b/torchao/prototype/spinquant/hadamard_utils.py @@ -11,7 +11,6 @@ import torch -from torchao.ops import lib from torchao.prototype.spinquant._hadamard_matrices import ( get_had12, get_had20, @@ -26,7 +25,6 @@ get_had156, get_had172, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 try: from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform @@ -50,21 +48,14 @@ def matmul_hadU(X, hadK, K): def register_custom_op_impl(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.custom_op(f"{name}", mutates_args=())(func) - else: - lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor") - return torch.library.impl(f"{name}", "cuda")(func) + return torch.library.custom_op(f"{name}", mutates_args=())(func) return decorator def register_custom_op_abstract(name): def decorator(func): - if TORCH_VERSION_AT_LEAST_2_4: - return torch.library.register_fake(f"{name}")(func) - else: - return torch.library.impl_abstract(f"{name}")(func) + return torch.library.register_fake(f"{name}")(func) return decorator @@ -175,7 +166,7 @@ def get_hadK(n, transpose=False): hadK = get_had12().T if transpose else get_had12() else: assert is_pow2(n) - + hadK = torch.FloatTensor([[1]]) K = 1 return hadK, K @@ -222,7 +213,7 @@ def matmul_hadU_fast(X, hadK, K): def random_hadamard_matrix(size, device, seed=0): # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation" - gen = torch.Generator() + gen = torch.Generator(device=device) gen.manual_seed(seed) Q = torch.randint(low=0, high=2, size=(size,), generator=gen).to(torch.float64) Q = Q * 2 - 1 @@ -246,6 +237,10 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None): assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!" W = module.weight.data + if output and module.bias is not None: + B = module.bias.data + bias_dtype_orig = B.dtype + B = B.float() dtype_orig = W.dtype W = W.float() @@ -253,9 +248,13 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None): if output: had_K, K = get_hadK(out_features) W = matmul_hadU(W.t(), had_K.to(W.device), K).t() + if output and module.bias is not None: + B = matmul_hadU(B, had_K.to(B.device), K) else: had_K, K = get_hadK(in_features) W = matmul_hadU(W, had_K.to(W.device), K) + if output and module.bias is not None: + B = matmul_hadU(B, had_K.to(B.device), K) else: if R2 is not None: hadK = R2.to(torch.float64) @@ -269,8 +268,15 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None): temp = W.reshape(-1, shape[-1] // had_dim, had_dim) temp = temp.to(torch.float64) @ hadK W = temp.reshape(shape) + if output and module.bias is not None: + shape = B.shape + temp = B.reshape(-1, had_dim) + temp = temp.to(torch.float64) @ hadK + B = temp.reshape(shape) if output: W = W.t() module.weight.data = W.to(dtype=dtype_orig) + if output and module.bias is not None: + module.bias.data = B.to(dtype=bias_dtype_orig) diff --git a/torchao/prototype/spinquant/spinquant.py b/torchao/prototype/spinquant/spinquant.py index 3c5733615a..b64534c602 100644 --- a/torchao/prototype/spinquant/spinquant.py +++ b/torchao/prototype/spinquant/spinquant.py @@ -48,6 +48,7 @@ def apply_spinquant( use_r2=False, use_r4=True, pretrained_rotation_path=None, + qkv_split=False, ): """ Apply SpinQuant to a Transformer model: https://arxiv.org/abs/2405.16406 @@ -57,9 +58,9 @@ def apply_spinquant( which appears to show best results in many cases (see https://github.com/pytorch/ao/pull/983). Note that the R3 rotation matrix and Cayley optimization for R1/R2 are currently not implemented. - """ - assert isinstance(model, Transformer), "Only Transformer models are supported" + qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv + """ original_device = next(model.parameters()).device device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device=device) @@ -75,18 +76,21 @@ def apply_spinquant( assert Path(pretrained_rotation_path).suffix == ".bin", "Expected a .bin file." if use_r1: - fuse_layernorm_into_linear(model) - apply_spinquant_r1(model, device, pretrained_rotation_path) + fuse_layernorm_into_linear(model, qkv_split) + apply_spinquant_r1(model, device, pretrained_rotation_path, qkv_split) if use_r2: - apply_spinquant_r2(model, device, pretrained_rotation_path) + apply_spinquant_r2(model, device, pretrained_rotation_path, qkv_split) if use_r4: apply_spinquant_r4(model, device) model.to(device=original_device) -def apply_spinquant_r1(model, device, pretrained_rotation_path=None): - """Apply the SpinQuant R1 rotation matrix to the model.""" +def apply_spinquant_r1(model, device, pretrained_rotation_path=None, qkv_split=False): + """ + Apply the SpinQuant R1 rotation matrix to the model. + qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv + """ if pretrained_rotation_path is not None: R1 = torch.load(pretrained_rotation_path)["R1"].to(device).to(torch.float64) @@ -97,11 +101,14 @@ def apply_spinquant_r1(model, device, pretrained_rotation_path=None): else: R1 = random_hadamard_matrix(model.config.dim, device) - _rotate_model_r1(model, R1) + _rotate_model_r1(model, R1, qkv_split=qkv_split) -def apply_spinquant_r2(model, device, pretrained_rotation_path=None): - """Apply the SpinQuant R2 rotation matrices to the model.""" +def apply_spinquant_r2(model, device, pretrained_rotation_path=None, qkv_split=False): + """ + Apply the SpinQuant R2 rotation matrices to the model. + qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv + """ R2s = [] # note that unlike R1, there are multiple R2 matrices (one per layer) head_dim = model.config.head_dim @@ -118,7 +125,7 @@ def apply_spinquant_r2(model, device, pretrained_rotation_path=None): R2 = random_hadamard_matrix(head_dim, device) R2s.append(R2) - _rotate_model_r2(model, R2s) + _rotate_model_r2(model, R2s, qkv_split=qkv_split) def apply_spinquant_r4(model, device): @@ -154,19 +161,19 @@ def _fuse_layernorm_into_linear( @torch.no_grad() -def _rotate_model_r1(model, R1): +def _rotate_model_r1(model, R1, qkv_split=False): _rotate_embeddings(model, R1) _rotate_head(model, R1) for layer in model.layers: - _rotate_attention_inputs(layer, R1) + _rotate_attention_inputs(layer, R1, qkv_split=qkv_split) _rotate_attention_output(layer, R1) _rotate_mlp_input(layer, R1) _rotate_mlp_output(layer, R1) @torch.no_grad() -def _rotate_model_r2(model, R2s): +def _rotate_model_r2(model, R2s, qkv_split=False): """Rotate the W_v and W_o weights of the multi-head self-attention modules.""" head_dim = model.config.head_dim @@ -180,25 +187,28 @@ def _rotate_model_r2(model, R2s): # Rotate W_o apply_exact_had_to_linear(attn.wo, had_dim=head_dim, output=False, R2=R2) - # Extract W_v - kv_size = model.config.n_local_heads * head_dim - wq, wk, wv = attn.wqkv.weight.data.split( - [model.config.dim, kv_size, kv_size], dim=0 - ) - out_features, in_features = wv.shape - wv_mod = nn.Linear( - in_features, - out_features, - bias=attn.wqkv.bias is not None, - device=wv.device, - dtype=wv.dtype, - ) - wv_mod.weight.data = wv + if qkv_split: + apply_exact_had_to_linear(attn.wv, had_dim=head_dim, output=True, R2=R2) + else: + # Extract W_v + kv_size = model.config.n_local_heads * head_dim + wq, wk, wv = attn.wqkv.weight.data.split( + [model.config.dim, kv_size, kv_size], dim=0 + ) + out_features, in_features = wv.shape + wv_mod = nn.Linear( + in_features, + out_features, + bias=attn.wqkv.bias is not None, + device=wv.device, + dtype=wv.dtype, + ) + wv_mod.weight.data = wv - # Rotate W_v - apply_exact_had_to_linear(wv_mod, had_dim=head_dim, output=True, R2=R2) + # Rotate W_v + apply_exact_had_to_linear(wv_mod, had_dim=head_dim, output=True, R2=R2) - attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0) + attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0) @torch.no_grad() @@ -226,12 +236,14 @@ def _add_activation_wrappers_r4(model): @torch.no_grad() -def fuse_layernorm_into_linear(model): +def fuse_layernorm_into_linear(model, qkv_split=False): """ Fuse RMSNorm weights into the subsequent linear layers. This is done in the paper specifically to make pre-norm LLMs like LLaMa rotation-invariant when quantization is not present. + + qkv_split should be set to True if attention modules have separate tensors wq, wk, wv instead of wqkv """ # Embedding fusion (from SpinQuant repo: utils/fuse_norm_utils.py:43) # I currently don't understand why this is necessary, so I contacted the @@ -244,7 +256,13 @@ def fuse_layernorm_into_linear(model): _fuse_layernorm_into_linear( layer.ffn_norm, [layer.feed_forward.w1, layer.feed_forward.w3] ) - _fuse_layernorm_into_linear(layer.attention_norm, [layer.attention.wqkv]) + if qkv_split: + _fuse_layernorm_into_linear( + layer.attention_norm, + [layer.attention.wq, layer.attention.wk, layer.attention.wv], + ) + else: + _fuse_layernorm_into_linear(layer.attention_norm, [layer.attention.wqkv]) _fuse_layernorm_into_linear(model.norm, [model.output]) @@ -270,8 +288,13 @@ def _rotate_attention_output(layer, R1): mod.bias.data = torch.matmul(R1.T, b).to(dtype=mod.weight.dtype) -def _rotate_attention_inputs(layer, R1): - _rotate_mod_weight_right(layer.attention.wqkv, R1) +def _rotate_attention_inputs(layer, R1, qkv_split=False): + if qkv_split: + _rotate_mod_weight_right(layer.attention.wq, R1) + _rotate_mod_weight_right(layer.attention.wk, R1) + _rotate_mod_weight_right(layer.attention.wv, R1) + else: + _rotate_mod_weight_right(layer.attention.wqkv, R1) def _rotate_head(model, R1): diff --git a/torchao/prototype/tensor_conversion/__init__.py b/torchao/prototype/tensor_conversion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/tensor_conversion/api.py b/torchao/prototype/tensor_conversion/api.py new file mode 100644 index 0000000000..6533e5de2d --- /dev/null +++ b/torchao/prototype/tensor_conversion/api.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +# TODO: move the function to torchao.utils +from torchao.dtypes.utils import is_device +from torchao.quantization import ( + Int4PreshuffledTensor, + Int4Tensor, + IntxUnpackedToInt8Tensor, +) +from torchao.utils import TorchAOBaseTensor, _is_fbgemm_genai_gpu_available + + +def _convert_linear_weight_to_int8_lut_tensor(module): + from torchao.prototype.quantization.int8_lut_tensor import Int8LutTensor + + assert isinstance(module, nn.Linear) + weight = module.weight + new_weight = Int8LutTensor.from_intx_unpacked_to_int8_tensor( + weight, bias=module.bias + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.bias = None + + +def _convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format): + from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + IntxOpaqueTensor, + ) + + assert isinstance(module, nn.Linear) or isinstance(module, nn.Embedding) + weight = module.weight + new_weight = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor( + weight, + bias=module.bias if hasattr(module, "bias") else None, + intx_packing_format=intx_packing_format, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + if hasattr(module, "bias"): + module.bias = None + + +def _find_tied_module_names_for_embedding(embedding_weight, model): + assert isinstance(embedding_weight, IntxUnpackedToInt8Tensor) + tied_names = [] + for name, module in model.named_modules(): + is_linear = isinstance(module, nn.Linear) + is_embedding = isinstance(module, nn.Embedding) + if not (is_linear or is_embedding): + continue + + weight = module.weight + if not isinstance(weight, IntxUnpackedToInt8Tensor): + continue + + # We only have tied kernels for dynamically quantized linears + if is_linear and weight.activation_quantization != "int8_asym_per_token": + continue + + # We only have tied kernels for linear layers with no bias + if is_linear and module.bias is not None: + continue + + are_tied = ( + (embedding_weight.shape == weight.shape) + and (embedding_weight.block_size == weight.block_size) + and (embedding_weight.dtype == weight.dtype) + and (embedding_weight.qdata == weight.qdata).all() + and (embedding_weight.scale == weight.scale).all() + and (embedding_weight.zero_point == weight.zero_point).all() + ) + + if are_tied: + tied_names.append(name) + + return tied_names + + +def _find_tied_params(model): + from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import ( + IntxOpaqueTensor, + ) + + module_name_to_tied_param = {} + for name, module in model.named_modules(): + if not isinstance(module, nn.Embedding): + continue + + weight = module.weight + if not isinstance(weight, IntxUnpackedToInt8Tensor): + continue + + tied_module_names = _find_tied_module_names_for_embedding(weight, model) + if not tied_module_names: + continue + + if name in module_name_to_tied_param: + tied_param = module_name_to_tied_param[name] + else: + # Construct a new tied param + # IntxOpaqueTensor requires activation_quantization = int8_asym_per_token + prev = weight.activation_quantization + weight.activation_quantization = "int8_asym_per_token" + tied_param = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor( + weight, + bias=None, + intx_packing_format="opaque_torchao_lowbit", + ) + weight.activation_quantization = prev + tied_param = nn.Parameter(tied_param, requires_grad=False) + module_name_to_tied_param[name] = tied_param + + for t in tied_module_names: + if t not in module_name_to_tied_param: + module_name_to_tied_param[t] = tied_param + + return module_name_to_tied_param + + +def _convert_model_for_aarch64( + model, *, tensor_type="auto", intx_packing_format="opaque_torchao_auto" +): + module_name_to_tied_param = _find_tied_params(model) + + # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor + for name, module in model.named_modules(): + if name in module_name_to_tied_param: + module.weight = module_name_to_tied_param[name] + continue + + if isinstance(module, nn.Embedding): + print("Skipping converting nn.Embedding {name} because it is not tied") + continue + + if not isinstance(module, nn.Linear): + continue + + weight = module.weight + if not isinstance(weight, IntxUnpackedToInt8Tensor): + print( + f"Skipping converting {name} to IntxOpaqueTensor because its weight is not an IntxUnpackedToInt8Tensor" + ) + continue + + if tensor_type == "int8_lut_tensor": + _convert_linear_weight_to_int8_lut_tensor(module) + elif tensor_type == "intx_opaque_tensor": + _convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format) + elif tensor_type == "auto": + if weight._has_float_zero_point() and isinstance(module, nn.Linear): + _convert_linear_weight_to_int8_lut_tensor(module) + else: + _convert_module_weight_to_intx_opaque_tensor( + module, intx_packing_format + ) + else: + raise ValueError(f"Unexpected tensor_type={tensor_type}") + + return model + + +def convert_to_packed_tensor_based_on_current_hardware(tensor: TorchAOBaseTensor): + """Convert a plain / unpacked torchao tensor to a packed one based on hardware + + Goal is to have an optimized performance on current hardware, while also allow + us to + (1). distribute a single unpacked / plain format that can be used in multiple hardwares + (2). support the vLLM use case, where we need to slice the weights for distributed + inference. Since slice is not always supported in packed weight, we would like to first + load plain / unpacked weight, slice it and then convert to packed weight to get the best + inference speed + """ + if ( + isinstance(tensor, Int4Tensor) + and is_device("cuda", tensor.device) + and _is_fbgemm_genai_gpu_available() + ): + return Int4PreshuffledTensor.from_int4_tensor(tensor) + return tensor diff --git a/torchao/prototype/tests/test_spinquant.py b/torchao/prototype/tests/test_spinquant.py new file mode 100644 index 0000000000..f9dce4d9d6 --- /dev/null +++ b/torchao/prototype/tests/test_spinquant.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torch.nn as nn + +from torchao.prototype.spinquant.hadamard_utils import apply_exact_had_to_linear + + +class TestSpinQuant(unittest.TestCase): + def test_rotate_in_and_out(self): + """Perform rotation to output of linear layer and inverse rotation to input of next layer; test that the output is the same.""" + with torch.no_grad(): + layer1 = nn.Linear(256, 256, bias=True) + layer2 = nn.Linear(256, 256, bias=True) + model = nn.Sequential(layer1, layer2) + input = torch.rand(256) + output = model(input) + apply_exact_had_to_linear(layer1, output=True) + apply_exact_had_to_linear(layer2, output=False) + new_output = model(input) + torch.testing.assert_allclose(output, new_output) diff --git a/torchao/quantization/GPTQ/GPTQ.py b/torchao/quantization/GPTQ/GPTQ.py index fe55ed19db..f20a5a4965 100644 --- a/torchao/quantization/GPTQ/GPTQ.py +++ b/torchao/quantization/GPTQ/GPTQ.py @@ -295,7 +295,7 @@ def __torch_function__( SQNR(DQ, DQ_from_qtensor), ) - qparams2 = cls.get_qparams_func(W) + qparams2 = cls.get_qparams_func(W, W.dtype) Q2 = cls.quantize_func(W, qparams2) DQ2 = cls.dequantize_func(Q2, qparams2).to(W.dtype) old_q_out = ( @@ -444,7 +444,9 @@ def faster_quant(cls, H, W, device): group_end = min(group_start + group_size, columns) if group_start % group_size == 0: # needed for when group_size == columns so only calculate qparams once - cur_qparams = cls.get_qparams_func(W[:, group_start:group_end]) + cur_qparams = cls.get_qparams_func( + W[:, group_start:group_end], orig_dtype + ) all_qparams.append(cur_qparams) for index in range(group_start, group_end): # within each group @@ -679,10 +681,11 @@ def __init__( else: self.zero_point_domain = ZeroPointDomain.FLOAT - self.get_qparams_func = lambda w: get_groupwise_affine_qparams( + self.get_qparams_func = lambda w, precision: get_groupwise_affine_qparams( w, n_bit, group_size, + dtype=precision, zero_point_domain=self.zero_point_domain, ) self.quantize_func = ( diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 83caffdc09..f53a6085c1 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -101,21 +101,21 @@ When used as in the example above, when the `autoquant` api is called alongside When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow. -Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods. +Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization._AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods. ```python import pickle import torchao.quantization # After the first forward pass (when quantization was done) -from torchao.quantization.autoquant import AUTOQUANT_CACHE +from torchao.quantization.autoquant import _AUTOQUANT_CACHE with open("quantization-cache.pkl", "wb") as f: - pickle.dump(AUTOQUANT_CACHE, f) + pickle.dump(_AUTOQUANT_CACHE, f) # On load -from torchao.quantization.autoquant import AUTOQUANT_CACHE +from torchao.quantization.autoquant import _AUTOQUANT_CACHE with open("quantization-cache.pkl", "rb") as f: - AUTOQUANT_CACHE.update(pickle.load(f)) + _AUTOQUANT_CACHE.update(pickle.load(f)) ``` ## Quantization Techniques @@ -125,18 +125,13 @@ be applied individually. While there are a large variety of quantization apis, t #### A16W4 WeightOnly Quantization ```python -# for torch 2.4+ from torchao.quantization import quantize_, Int4WeightOnlyConfig group_size = 32 # you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through -# use_hqq flag for `Int4WeightOnlyConfig` quantization +# by setting int4_choose_qparams_algorithm to "hqq" for `Int4WeightOnlyConfig` quantization use_hqq = False -quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)) - -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors -change_linear_weights_to_int4_woqtensors(model) +quantize_(model, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")) ``` Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model. @@ -144,28 +139,18 @@ Note: The quantization error incurred by applying int4 quantization to your mode #### A16W8 Int8 WeightOnly Quantization ```python -# for torch 2.4+ from torchao.quantization import quantize_, Int8WeightOnlyConfig quantize_(model, Int8WeightOnlyConfig()) - -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors -change_linear_weights_to_int8_woqtensors(model) ``` #### A8W8 Int8 Dynamic Quantization ```python -# for torch 2.4+ from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig quantize_(model, Int8DynamicActivationInt8WeightConfig()) - -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors -change_linear_weights_to_int8_dqtensors(model) ``` -### A16W8 Float8 WeightOnly Quantization +#### A16W8 Float8 WeightOnly Quantization ```python # for torch 2.5+ @@ -205,6 +190,34 @@ quantize_(model, FPXWeightOnlyConfig(3, 2)) You can find more information [here](../dtypes/floatx/README.md). It should be noted where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, performance, fp6 works primarily with the fp16 dtype. +``` + +KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows: + +```python +from torchao.quantization.quant_api import ( + Int8DynamicActivationIntxWeightConfig, + quantize_, +) +from torchao.quantization.granularity import PerGroup, PerAxis +from torchao.quantization.quant_primitives import MappingType +from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler + +my_model = Model() + +quantize_( + my_model, + Int8DynamicActivationIntxWeightConfig( + weight_scale_dtype=torch.float32, + weight_granularity=PerGroup(32), # PerAxis is also supported + weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR, # MappingType.SYMMETRIC can also be used but increases error + layout=layout, + weight_dtype=torch.int4, + intx_packing_format="opaque_aten_kleidiai", + ), +) +``` + ## Affine Quantization Details Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization. @@ -266,16 +279,10 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune') # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) group_size = 32 # only works for torch 2.4+ -quantize_(m, Int4WeightOnlyConfig(group_size=group_size)) -## If different zero_point_domain needed -# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT)) +quantize_(m, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d")) +# can also specify different packing format +# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="plain")) -# temporary workaround for tensor subclass + torch.compile -# NOTE: this is only need for torch version < 2.5+ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -from torchao.utils import unwrap_tensor_subclass -if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(m) # compile the model to improve performance m = torch.compile(m, mode='max-autotune') @@ -363,7 +370,7 @@ Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. F | | w4a8-g128 | 187.62 | 640.32 | 4.82 | 3.41 | ### Gemlite Triton -Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `gemlite_uintx_weight_only`. An example can be found in `torchao/_models/llama/generate.py`. +Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `GemliteUIntXWeightOnlyConfig`. An example can be found in `torchao/_models/llama/generate.py`. Note: we test on gemlite 0.4.1, but should be able to use any version after that, we'd recommend to use the latest release to get the most recent performance improvements. diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index d9aba0bcc5..d57b8790c7 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -19,6 +19,7 @@ MultiTensorInputRecorder, ) from .granularity import ( + Granularity, PerAxis, PerGroup, PerRow, @@ -43,9 +44,9 @@ ) from .quant_api import ( CutlassInt4PackedLayout, - FbgemmConfig, Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Float8MMConfig, Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, @@ -87,6 +88,17 @@ dequantize_affine, quantize_affine, ) +from .quantize_.workflows import ( + Float8Tensor, + Int4MarlinSparseTensor, + Int4OpaqueTensor, + Int4PlainInt32Tensor, + Int4PreshuffledTensor, + Int4Tensor, + Int4TilePackedTo4dTensor, + IntxOpaqueTensor, + IntxUnpackedToInt8Tensor, +) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, SmoothFakeDynQuantMixin, @@ -137,6 +149,7 @@ "Int8DynamicActivationInt8WeightConfig", "Int8DynamicActivationIntxWeightConfig", "Int4WeightOnlyConfig", + "Float8DynamicActivationInt4WeightConfig", "Int8WeightOnlyConfig", "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", @@ -148,7 +161,16 @@ "GemliteUIntXWeightOnlyConfig", "AOPerModuleConfig", "ModuleFqnToConfig", - "FbgemmConfig", + # tensor subclasses + "Int4Tensor", + "Int4PlainInt32Tensor", + "Int4PreshuffledTensor", + "Int4MarlinSparseTensor", + "IntxOpaqueTensor", + "IntxUnpackedToInt8Tensor", + "Int4TilePackedTo4dTensor", + "Float8Tensor", + "Int4OpaqueTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", @@ -176,6 +198,7 @@ "MappingType", "ZeroPointDomain", "TorchAODType", + "Granularity", "PerTensor", "PerAxis", "PerGroup", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 6f0aac947a..eb19a00923 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -27,15 +27,14 @@ ZeroPointDomain, ) from torchao.quantization.utils import ( + _quantize_activation_per_token_absmax, compute_error, - quantize_activation_per_token_absmax, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, is_sm_at_least_89, is_sm_at_least_90, + torch_version_at_least, ) from .granularity import ( @@ -63,15 +62,15 @@ aten = torch.ops.aten -AUTOQUANT_CACHE = {} +_AUTOQUANT_CACHE = {} -def check_cache(cls, shapes_and_dtype): - return AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None) +def _check_cache(cls, shapes_and_dtype): + return _AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None) -def update_cache(cls, shapes_and_dtype, res): - AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res +def _update_cache(cls, shapes_and_dtype, res): + _AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res # TODO: Document the methods @@ -145,12 +144,12 @@ def log_shape(act_mat, w_autoquant, bias): shapes_and_dtype, 0 ) for q_cls in w_autoquant.qtensor_class_list: - if check_cache(q_cls, shapes_and_dtype) is None: - update_cache(q_cls, shapes_and_dtype, None) + if _check_cache(q_cls, shapes_and_dtype) is None: + _update_cache(q_cls, shapes_and_dtype, None) def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype - if check_cache(q_cls, shapes_and_dtype) is None: + if _check_cache(q_cls, shapes_and_dtype) is None: with torch.no_grad(): act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device) bias = ( @@ -183,7 +182,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}" ) res = torch.inf - update_cache(q_cls, shapes_and_dtype, res) + _update_cache(q_cls, shapes_and_dtype, res) @torch.no_grad() def to_quantized(self, error_on_unseen, **kwargs): @@ -223,13 +222,13 @@ def count_shapes(self, do_print=True): total_seen = 0 shape_count = count_shapes(self, do_print=False) for shapes_and_dtype, times_seen in self.logged_data.items(): - if check_cache(q_cls, shapes_and_dtype) is None: + if _check_cache(q_cls, shapes_and_dtype) is None: # only print shapes once if print_shape_once: print_shape_once = False count_shapes(self, do_print=True) - time_for_best_shape = check_cache(best_cls, shapes_and_dtype) + time_for_best_shape = _check_cache(best_cls, shapes_and_dtype) time_for_best_shape = ( torch.inf if time_for_best_shape is None @@ -238,7 +237,7 @@ def count_shapes(self, do_print=True): self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape) ran_new_benchmarks = True torch._dynamo.reset() - cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen + cur_time += _check_cache(q_cls, shapes_and_dtype) * times_seen total_seen += times_seen cur_time = cur_time / total_seen # print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done @@ -329,6 +328,8 @@ def do_autoquant_bench(op, *args, **kwargs): """ runs benchmark op(*args, **kwargs) avoiding torch.compile overhead """ + from torch._inductor.runtime.benchmarking import benchmarker + rep = kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -343,22 +344,15 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AT_LEAST_2_5: - from torch._inductor.runtime.benchmarking import benchmarker + if torch_version_at_least("2.9.0.dev"): + from statistics import median res = benchmarker.benchmark_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - elif TORCH_VERSION_AT_LEAST_2_3: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - - res = do_bench_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="all" ) + res = median(res) else: - from torch._inductor.utils import do_bench - - res = do_bench( + res = benchmarker.benchmark_gpu( lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" ) return res @@ -498,7 +492,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): # SAM best is between .8 and 1, SDXL also performs best in this range INTERPOLATION_CONSTANT = mode[1] w_qtensor = cls.from_float(weight) - x_vals_int8, x_scales = quantize_activation_per_token_absmax( + x_vals_int8, x_scales = _quantize_activation_per_token_absmax( act_mat.reshape(-1, act_mat.shape[-1]) ) quantized_matmul = ( @@ -1269,6 +1263,8 @@ def autoquant( model(*example_input2) model.finalize_autoquant() """ + torch._C._log_api_usage_once("torchao.quantization.autoquant") + if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() @@ -1346,12 +1342,11 @@ def finalize_autoquant(): return model -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) - torch.serialization.add_safe_globals( - [ - _to_float16, - _to_bfloat16, - _identity, - ] - ) +torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) +torch.serialization.add_safe_globals( + [ + _to_float16, + _to_bfloat16, + _identity, + ] +) diff --git a/torchao/quantization/dynamic_quant.py b/torchao/quantization/dynamic_quant.py index 61c6b0dc07..5c6ee9c8f9 100644 --- a/torchao/quantization/dynamic_quant.py +++ b/torchao/quantization/dynamic_quant.py @@ -8,8 +8,8 @@ import torch.nn as nn from .utils import ( + _quant_int8_dynamic_per_token_linear, dynamically_quantize_per_channel, - quant_int8_dynamic_per_token_linear, ) __all__ = ["DynamicallyPerAxisQuantizedLinear"] @@ -44,7 +44,7 @@ def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ - Y = quant_int8_dynamic_per_token_linear( + Y = _quant_int8_dynamic_per_token_linear( X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype ) return Y diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index aa946c064f..abc6c794e9 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -8,10 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "LinearActivationQuantizedTensor", @@ -136,11 +133,13 @@ def _same_metadata( @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) + input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None) + weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None) + bias = kwargs.get("bias", args[2] if len(args) > 2 else None) + + assert input_tensor is not None, "input tensor must not be None" + assert weight_tensor is not None, "weight tensor must not be None" + if isinstance(weight_tensor, LinearActivationQuantizedTensor): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) @@ -219,6 +218,11 @@ def _(func, types, args, kwargs): for tensor_name in self_tensors: getattr(self, tensor_name).copy_(getattr(src, tensor_name)) return + elif type(self) is torch.Tensor and type(src) is LinearActivationQuantizedTensor: + new_src = src.to(dtype=self.dtype, device=self.device) + self.copy_(new_src) + return + raise ValueError( f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}" ) @@ -288,8 +292,7 @@ def _(func, types, args, kwargs): ) -to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float +to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float # Converts a float tensor to LinearActivationQuantizedTensor for dynamic activation quantization -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([LinearActivationQuantizedTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([LinearActivationQuantizedTensor]) diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index 6c433844a6..500228cf3c 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -6,10 +6,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "WeightTensorWithLinearActivationScaleMetadata", @@ -33,8 +30,8 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor): scale (torch.Tensor): The scale tensor to be applied to activation. """ - original_weight_tensor: torch.Tensor - scale: torch.Tensor + tensor_data_names = ["original_weight_tensor", "scale"] + tensor_attribute_names = [] def __new__( cls, @@ -57,21 +54,8 @@ def __init__( self.original_weight_tensor = original_weight_tensor self.scale = scale - def __repr__(self): - return f"WeightTensorWithLinearActivationScaleMetadata({self.original_weight_tensor}, scale={self.scale}" - - def __tensor_flatten__(self): - tensor_data = ["original_weight_tensor", "scale"] - return tensor_data, [] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - return cls( - tensor_data_dict["original_weight_tensor"], - tensor_data_dict["scale"], - ) + def _quantization_type(self): + return f"{self.__class__}" @staticmethod def _quantized_linear_op( @@ -93,20 +77,6 @@ def from_float( ): return cls(input_float, scale) - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.original_weight_tensor), - fn(self.scale), - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.original_weight_tensor.to(device), - self.scale.to(device), - ) - implements = WeightTensorWithLinearActivationScaleMetadata.implements @@ -126,28 +96,13 @@ def _(func, types, args, kwargs): ) -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) +@implements(aten.slice.Tensor) def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -@implements(aten._to_copy.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + self = args[0] + new = self.__class__( + func(self.original_weight_tensor, *args[1:], **kwargs), self.scale ) + return return_and_correct_aliasing(func, args, kwargs, new) @implements(aten.t.default) @@ -161,8 +116,5 @@ def _(func, types, args, kwargs): WeightTensorWithLinearActivationScaleMetadata.from_float ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals( - [WeightTensorWithLinearActivationScaleMetadata] - ) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([WeightTensorWithLinearActivationScaleMetadata]) diff --git a/torchao/quantization/linear_activation_weight_observed_tensor.py b/torchao/quantization/linear_activation_weight_observed_tensor.py index 029b89e54b..d17bc382db 100644 --- a/torchao/quantization/linear_activation_weight_observed_tensor.py +++ b/torchao/quantization/linear_activation_weight_observed_tensor.py @@ -9,10 +9,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.observer import AffineQuantizedObserverBase -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "LinearActivationWeightObservedTensor", @@ -153,6 +150,5 @@ def _(func, types, args, kwargs): ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor]) diff --git a/torchao/quantization/linear_quant_modules.py b/torchao/quantization/linear_quant_modules.py index 73e95036f1..de6755a55d 100644 --- a/torchao/quantization/linear_quant_modules.py +++ b/torchao/quantization/linear_quant_modules.py @@ -16,10 +16,7 @@ import torch.nn.functional as F from torchao.dtypes.utils import is_device -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - find_multiple, -) +from torchao.utils import find_multiple from .quant_primitives import ( MappingType, @@ -60,7 +57,7 @@ def linear_forward_int4( ): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + if is_device(x.device.type, "cpu"): c = torch.ops.aten._weight_int4pack_mm_for_cpu( x.to(precision), weight_int4pack, @@ -299,10 +296,7 @@ def _create_quantized_state_dict( self.precision, # dtype for scales_and_zeros ) # TODO: just get the device from mod.weight.device? - if ( - is_device(w_int4x8.device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): + if is_device(w_int4x8.device.type, "cpu"): weight_int4pack = ( torch.ops.aten._convert_weight_to_int4pack_for_cpu( w_int4x8.to(self.device), self.inner_k_tiles diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index e103f0a59e..d12ffaf520 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -10,11 +10,10 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.quantization.quant_primitives import _fake_quantize_affine from .granularity import ( Granularity, - PerAxis, PerRow, PerTensor, ) @@ -24,6 +23,7 @@ _get_reduction_params, choose_qparams_affine_with_min_max, ) +from .utils import get_block_size logger = logging.getLogger(__name__) @@ -63,26 +63,6 @@ def _with_args(cls_or_self, *args, **kwargs): return r -def get_block_size( - input_shape: Tuple[int, ...], granularity: Granularity -) -> Tuple[int, ...]: - """Get the block size based on the input shape and granularity type. - - Args: - input_shape: The input tensor shape possibly more than 2 dimensions - granularity: The granularity type of the quantization - """ - if isinstance(granularity, PerTensor): - return input_shape - elif isinstance(granularity, PerAxis): - block_size = list(input_shape) - block_size[granularity.axis] = 1 - return tuple(block_size) - elif isinstance(granularity, PerRow): - return (1,) * (len(input_shape) - 1) + (input_shape[-1],) - raise ValueError(f"Unsupported Granularity: {granularity}") - - ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: @@ -193,6 +173,184 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([PerRow, PerTensor]) +class AffineQuantizedFixedQParamObserver(AffineQuantizedObserverBase): + """ + Observer that allows manual setting of fixed quantization parameters. + """ + + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + scale: Optional[torch.Tensor] = None, + zero_point: Optional[torch.Tensor] = None, + ): + super().__init__( + mapping_type, + target_dtype, + granularity, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + if not scale: + scale = torch.Tensor([1]) + if not zero_point: + zero_point = torch.zeros_like(scale) + self.register_buffer("scale", scale.to(dtype=scale_dtype)) + self.register_buffer("zero_point", zero_point.to(dtype=zero_point_dtype)) + + def set_qparams(self, scale, zero_point=None): + if not zero_point: + zero_point = torch.zeros_like(scale) + self.scale = scale.to(dtype=self.scale_dtype) + self.zero_point = zero_point.to(dtype=self.zero_point_dtype) + + def forward(self, input): + return input + + def calculate_qparams(self): + return self.scale, self.zero_point + + +class AffineQuantizedMSEObserver(AffineQuantizedObserverBase): + """ + Minimize quantization loss caused by outlier via linear search. More details can be found at https://arxiv.org/pdf/2209.13325 + """ + + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + steps: int = 100, + run_once: bool = False, + ): + super().__init__( + mapping_type, + target_dtype, + granularity, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + self.steps = steps + self.calibrated = False + self.run_once = run_once + + def mse(self, pred, expect, block_size): + loss = (pred - expect).abs().pow(2) + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, loss.size() + ) + loss = loss.view(shape_for_reduction) + return torch.mean(loss, dim=reduction_dims, keepdim=False) + + def loss_fn(self, x, new_min, new_max): + block_size = get_block_size(x.shape, self.granularity) + scale, zero_point = choose_qparams_affine_with_min_max( + new_min, + new_max, + self.mapping_type, + [], + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + x_q = _fake_quantize_affine( + x, + block_size, + scale, + zero_point, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain, + ) + return self.mse(x_q, x, block_size) + + def line_search(self, input): + if input.numel() == 0: + return input + + input_detached = input.detach() + assert self.granularity is not None, "granularity is None" + block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + + range_val = torch.max(min_val.abs(), max_val) + optimal_loss = torch.zeros_like(min_val) + 1e9 + + # check which clip range could produce smallest loss + for i in range(1, self.steps + 1): + thres = range_val / self.steps * i + current_loss = self.loss_fn(input, -thres, thres) + min_val = torch.where(current_loss < optimal_loss, -thres, min_val) + max_val = torch.where(current_loss < optimal_loss, thres, max_val) + optimal_loss = torch.min(current_loss, optimal_loss) + + return min_val, max_val + + def forward(self, input): + if not (self.run_once and self.calibrated): + self.min_val, self.max_val = self.line_search(input) + self.calibrated = True + + return input + + def calculate_qparams(self): + assert hasattr(self, "min_val") and hasattr(self, "max_val"), ( + "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + ) + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([PerRow, PerTensor]) diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py index 20d51912f0..72a7da453e 100644 --- a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py @@ -1,9 +1,9 @@ from torchao.quantization.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - to_affine_fake_quantized, + _AffineFakeQuantizedTensor, + _to_affine_fake_quantized, ) __all__ = [ - "AffineFakeQuantizedTensor", - "to_affine_fake_quantized", + "_AffineFakeQuantizedTensor", + "_to_affine_fake_quantized", ] diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/prototype/qat/fake_quantizer.py index 3bbe1fb704..560a609ce2 100644 --- a/torchao/quantization/prototype/qat/fake_quantizer.py +++ b/torchao/quantization/prototype/qat/fake_quantizer.py @@ -1,5 +1,5 @@ from torchao.quantization.qat.fake_quantizer import ( - FakeQuantizer, + IntxFakeQuantizer as FakeQuantizer, ) __all__ = [ diff --git a/torchao/quantization/pt2e/__init__.py b/torchao/quantization/pt2e/__init__.py index 8b6a99337b..0b8f8c12ed 100644 --- a/torchao/quantization/pt2e/__init__.py +++ b/torchao/quantization/pt2e/__init__.py @@ -48,7 +48,6 @@ from .observer import ( AffineQuantizedObserverBase, FixedQParamsObserver, - Granularity, HistogramObserver, MappingType, MinMaxObserver, @@ -57,20 +56,13 @@ NoopObserver, ObserverBase, PartialWrapper, - PerAxis, - PerBlock, PerChannelMinMaxObserver, - PerGroup, - PerRow, - PerTensor, - PerToken, PlaceholderObserver, RecordingObserver, ReuseInputObserver, TorchAODType, UniformQuantizationObserverBase, ZeroPointDomain, - get_block_size, ) for _f in [ @@ -139,17 +131,9 @@ "compare_results", # should be merged with torchao/quantization/observer.py in the future "AffineQuantizedObserverBase", - "Granularity", "MappingType", - "PerAxis", - "PerBlock", - "PerGroup", - "PerRow", - "PerTensor", - "PerToken", "TorchAODType", "ZeroPointDomain", - "get_block_size", "default_fake_quant", "default_dynamic_fake_quant", ] diff --git a/torchao/quantization/pt2e/_affine_quantization.py b/torchao/quantization/pt2e/_affine_quantization.py index e02bee03ce..a863c8f00e 100644 --- a/torchao/quantization/pt2e/_affine_quantization.py +++ b/torchao/quantization/pt2e/_affine_quantization.py @@ -19,8 +19,8 @@ MappingType, TorchAODType, ZeroPointDomain, - get_block_size, ) +from torchao.quantization.utils import get_block_size ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: diff --git a/torchao/quantization/pt2e/_numeric_debugger.py b/torchao/quantization/pt2e/_numeric_debugger.py index de1e1eee84..df01d02f99 100644 --- a/torchao/quantization/pt2e/_numeric_debugger.py +++ b/torchao/quantization/pt2e/_numeric_debugger.py @@ -14,13 +14,9 @@ from torch.ao.ns.fx.utils import compute_sqnr from torch.export import ExportedProgram from torch.fx import GraphModule, Node +from torch.fx.traceback import NodeSource from torch.nn import functional as F -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.fx.traceback import NodeSource - from .graph_utils import bfs_trace_with_node_process NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" @@ -30,6 +26,21 @@ log = logging.getLogger(__name__) +@dataclass(frozen=True) +class NodeSourceDebugInfo: + """ + Contains node source information for locating the node in the original graph. + This replaces the numeric debug handle approach with direct node source info. + """ + + # The name of the node in the graph, e.g. "conv2d" + name: str + + # The unique id of the graph that the node belongs to. + graph_id: int + + +# This function is no longer used for torchao debug flow, but is kept here for backward compatibility. def generate_numeric_debug_handle(ep: ExportedProgram) -> None: """ Attach numeric_debug_handle_id for all nodes in the graph module of the given @@ -40,7 +51,7 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None: Here's an example of using debug handle quantize flow:: - ep = export_for_training(eager_model, example_inputs) + ep = torch.export.export(eager_model, example_inputs) generate_numeric_debug_handle(ep) m = ep.module() @@ -84,53 +95,48 @@ def _assign_debug_handle(node: torch.fx.Node) -> None: bfs_trace_with_node_process(ep, _assign_debug_handle) -def _get_greatest_ancestor_node_source(node: Node) -> Optional["NodeSource"]: - if (node_source := node.meta.get(FROM_NODE_KEY)) is None: - return None +def _extract_node_source_debug_info(node: Node) -> Optional[NodeSourceDebugInfo]: + """ + Extract node source debug info from a node, or return None if the node + does not need to be traced. - node_source = node_source[-1] + Returns NodeSourceDebugInfo containing the name and graph_id from the + node's greatest ancestor node source, or None if the node is not in + the original graph. + """ - while len(node_source.from_node) > 0: - node_source = node_source.from_node[-1] + def _get_greatest_ancestor_node_source(node: Node) -> "NodeSource": + node_source = node.meta.get(FROM_NODE_KEY)[-1] - return node_source + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + return node_source -def _generate_debug_handle_from_node(node: Node) -> Optional[int]: - """ - Generate a debug handle based on node's oldest ancestor node's name - and graph id, or return None if the node does not need to be traced. + def _is_node_in_original_graph(node: Node) -> bool: + if ( + FROM_NODE_KEY not in node.meta + or node.meta[FROM_NODE_KEY] is None + or node.meta[FROM_NODE_KEY][-1].pass_name + == "ExportedProgram.module().unlift()" + ): + # This node is not part of the ExportedProgram.module().graph, so it doesn't have a debug handle + return False - This is a temporary function for migrating node tracing infra from - using debug handle to node.meta["from_node"]. The infrastructure will - depend on node.meta["from_node"] directly in the future, without the need - of debug handle as intermediate variable. - """ + return True if node.op == "placeholder" or node.op == "output": - # placeholder and output nodes don't have debug handle + # placeholder and output nodes don't have debug info return None - if ( - FROM_NODE_KEY not in node.meta - or node.meta[FROM_NODE_KEY] is None - or node.meta[FROM_NODE_KEY][-1].pass_name == "ExportedProgram.module().unlift()" - ): - # This node is not part of the ExportedProgram.module().graph, so it doesn't have a debug handle + if not _is_node_in_original_graph(node): return None greatest_ancestor_node_source = _get_greatest_ancestor_node_source(node) - if greatest_ancestor_node_source is None: - # This node is not part of the ExportedProgram.module().graph, so it doesn't have a debug handle - return None - - if greatest_ancestor_node_source.pass_name == "ExportedProgram.module().unlift()": - # uplifted nodes don't have debug handle - return None - - return hash( - greatest_ancestor_node_source.name + str(greatest_ancestor_node_source.graph_id) + return NodeSourceDebugInfo( + name=greatest_ancestor_node_source.name, + graph_id=greatest_ancestor_node_source.graph_id, ) @@ -192,14 +198,14 @@ class OutputLogger(torch.nn.Module): def __init__( self, - debug_handle: int, + debug_info: NodeSourceDebugInfo, node_name: Optional[str] = None, nn_module_stack: Optional[object] = None, ) -> None: super().__init__() self.node_name = node_name self.nn_module_stack = nn_module_stack - self.debug_handle = debug_handle + self.debug_info = debug_info self.stats: list[object] = [] def forward(self, x: object) -> object: @@ -208,15 +214,17 @@ def forward(self, x: object) -> object: def __extra_repr__(self) -> str: return ( - f"debug_handle={self.debug_handle}, node_name={self.node_name}, " + f"debug_info={self.debug_info}, node_name={self.node_name}, " "nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})" ) -def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: +def _insert_logger( + model: GraphModule, node: Node, debug_info: NodeSourceDebugInfo +) -> Node: """For a given node, adds an OutputLogger that observes the output of that node, and all its users use the OutputLogger output instead. - The OutputLogger will contain the debug_handle which can be used to compare + The OutputLogger will contain the debug_info which can be used to compare graphs after transforms""" # to avoid circular dep @@ -229,7 +237,7 @@ def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: setattr( model, logger_name, - OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")), + OutputLogger(debug_info, node.name, node.meta.get("nn_module_stack")), ) logger_node = model.graph.call_module(logger_name, (node,), {}) @@ -250,17 +258,11 @@ def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: Returns: a model with output loggers for all unlifted nodes """ - if not TORCH_VERSION_AT_LEAST_2_6: - log.warning( - "prepare_for_propagation_comparison is only supported for PyTorch 2.6+" - ) - return model - # don't change the original model model = copy.deepcopy(model) for n in model.graph.nodes: - if (numeric_debug_handle := _generate_debug_handle_from_node(n)) is not None: - _insert_logger(model, n, numeric_debug_handle) + if (debug_info := _extract_node_source_debug_info(n)) is not None: + _insert_logger(model, n, debug_info) model.recompile() return model @@ -310,7 +312,7 @@ def __post_init__(self) -> None: @dataclass(frozen=True) class NodeAccuracySummary: - handle: int + debug_info: NodeSourceDebugInfo actual_node_name: str actual_module_stack: str ref_node_name: str @@ -334,21 +336,21 @@ def _module_stack_to_str(module_stack: object) -> str: def extract_results_from_loggers( model: GraphModule, -) -> dict[int, tuple[Optional[str], object, list[object]]]: - """For a given model, extract the tensors stats and related information for each debug handle. +) -> dict[NodeSourceDebugInfo, tuple[Optional[str], object, list[object]]]: + """For a given model, extract the tensors stats and related information for each debug info. The reason we have a list of object, instead of Tensor is because the output of node may not be a Tensor, it could be (nested) list, tuple or dict as well. Returns: - A dict is keyed by the debug_handle id and the values are a list of object recorded + A dict is keyed by the NodeSourceDebugInfo and the values are a list of object recorded in loggers """ - # Results maps debug handle to a tensor list for each model being compared. - handles: dict[int, tuple[Optional[str], object, list[object]]] = {} - for _name, module in model.named_children(): + # Results maps debug info to a tensor list for each model being compared. + handles: dict[NodeSourceDebugInfo, tuple[Optional[str], object, list[object]]] = {} + for _, module in model.named_children(): if isinstance(module, OutputLogger) and len(module.stats) > 0: - handles[module.debug_handle] = ( + handles[module.debug_info] = ( module.node_name, module.nn_module_stack, module.stats, @@ -358,29 +360,33 @@ def extract_results_from_loggers( def compare_results( - ref_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], - actual_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], -) -> dict[int, NodeAccuracySummary]: - """Given two dict mapping from `debug_handle_id` (int) to list of tensors - return a map from `debug_handle_id` to `NodeAccuracySummary` that contains + ref_results: dict[ + NodeSourceDebugInfo, tuple[Optional[str], object, list[torch.Tensor]] + ], + actual_results: dict[ + NodeSourceDebugInfo, tuple[Optional[str], object, list[torch.Tensor]] + ], +) -> dict[NodeSourceDebugInfo, NodeAccuracySummary]: + """Given two dict mapping from `NodeSourceDebugInfo` to list of tensors + return a map from `NodeSourceDebugInfo` to `NodeAccuracySummary` that contains comparison information like SQNR, MSE etc. Args: - ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id - actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id + ref_results (Dict[NodeSourceDebugInfo, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug info + actual_results (Dict[NodeSourceDebugInfo, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug info Returns: - Dict[int, NodeAccuracySummary] + Dict[NodeSourceDebugInfo, NodeAccuracySummary] """ comparisons = {} - for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items(): - if debug_handle not in actual_results: + for debug_info, (ref_name, ref_stack, ref_stats) in ref_results.items(): + if debug_info not in actual_results: log.debug( - "Cannot compare for handle %s because it wasn't found in the transformed model", - debug_handle, + "Cannot compare for debug info %s because it wasn't found in the transformed model", + debug_info, ) continue - actual_name, actual_stack, actual_stats = actual_results[debug_handle] + actual_name, actual_stack, actual_stats = actual_results[debug_info] try: results = [ QuantizationComparisonResult(actual=a, ref=b) @@ -388,13 +394,13 @@ def compare_results( ] except Exception as e: # Add extra information for an exception from QuantizationComparisonResult - # if the shapes didn't match, to include the handle and the node names. + # if the shapes didn't match, to include the debug info and the node names. raise ValueError( - f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}" + f"For debug_info={debug_info} from ref node {ref_name} and actual node {actual_name}" ) from e - comparisons[debug_handle] = NodeAccuracySummary( - handle=debug_handle, + comparisons[debug_info] = NodeAccuracySummary( + debug_info=debug_info, actual_node_name=actual_name or "", actual_module_stack=_module_stack_to_str(actual_stack), ref_node_name=ref_name or "", diff --git a/torchao/quantization/pt2e/constant_fold.py b/torchao/quantization/pt2e/constant_fold.py index 27f82e6757..365eb0a77a 100644 --- a/torchao/quantization/pt2e/constant_fold.py +++ b/torchao/quantization/pt2e/constant_fold.py @@ -12,8 +12,6 @@ from torch._inductor.freezing_utils import maybe_set_is_frozen_param from torch.utils._ordered_set import OrderedSet -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - aten = torch.ops.aten # We would like to split modules into two subgraphs for runtime weight updates to work correctly. @@ -162,13 +160,9 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.convert_element_type.no_fuse, + torch.ops.torchao.dequantize_affine, ] - if TORCH_VERSION_AT_LEAST_2_5: - DEQUANT_OPS += [ - torch.ops.torchao.dequantize_affine, - ] - if node.target in DEQUANT_OPS: # For the pattern fp32_weight -> q -> dq # We only folding fp32_weight -> q diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py index 99516ac4c3..7123b0488c 100644 --- a/torchao/quantization/pt2e/convert.py +++ b/torchao/quantization/pt2e/convert.py @@ -49,9 +49,7 @@ ) from torch.ao.quantization.fx.utils import ( _get_module, - assert_and_get_unique_device, collect_producer_nodes, - create_getattr_from_value, graph_module_from_producer_nodes, node_arg_is_weight, ) @@ -69,14 +67,13 @@ from torch.fx import GraphModule from torch.fx.graph import Argument, Graph, Node from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY +from torch.fx.traceback import NodeSource, NodeSourceAction from torch.nn.utils.parametrize import type_before_parametrizations from torchao.quantization.pt2e import FROM_NODE_KEY from torchao.quantization.pt2e.observer import _is_activation_post_process -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 - -if TORCH_VERSION_AT_LEAST_2_6: - from torch.fx.traceback import NodeSource, NodeSourceAction +from torchao.quantization.pt2e.utils import create_getattr_from_value +from torchao.utils import _assert_and_get_unique_device __all__ = [ "convert", @@ -132,6 +129,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed( modules: dict[str, torch.nn.Module], node_name_to_scope: dict[str, tuple[str, type]], node_name_to_qconfig: dict[str, QConfigAny], + model_device: Optional[torch.device] = None, ) -> None: """Replace activation_post_process module call node with quantize and dequantize node working with decomposed Tensor @@ -188,8 +186,6 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): def add_quantize_dequantize_node_info(qdq_node, original_node): # propagate from_node info from observer/fake_quant node to quantize/dequantize node - if not TORCH_VERSION_AT_LEAST_2_6: - return qdq_node.meta[FROM_NODE_KEY] = [ NodeSource( original_node, @@ -260,7 +256,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node): # sure that the default overload can be used. # TODO: maybe need more complex attr name here qparam_node = create_getattr_from_value( - model, graph, module_path + prefix + key, value_or_node + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, ) quantize_op_inputs.append(qparam_node) else: @@ -407,6 +407,7 @@ def _replace_observer_with_quantize_dequantize_node( modules: dict[str, torch.nn.Module], node_name_to_scope: dict[str, tuple[str, type]], node_name_to_qconfig: dict[str, QConfigAny], + model_device: Optional[torch.device] = None, ) -> None: """Replace activation_post_process module call node with quantize and dequantize node @@ -487,7 +488,11 @@ def _replace_observer_with_quantize_dequantize_node( # For scale and zero_point values we register them as buffers in the root module. # TODO: maybe need more complex attr name here qparam_node = create_getattr_from_value( - model, graph, module_path + prefix + key, value_or_node + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, ) quantize_op_inputs.append(qparam_node) else: @@ -785,6 +790,7 @@ def convert_weighted_module( backend_config: BackendConfig, is_decomposed: bool = False, is_reference: bool = False, + model_device: Optional[torch.device] = None, ) -> None: """Convert a weighted module to reference quantized module in the model If the QConfig of a QAT module is not set, the module will still be converted to @@ -873,7 +879,10 @@ def convert_weighted_module( is_ptq = weight_post_process is None if is_ptq: weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] - device = assert_and_get_unique_device(float_module) + if model_device is not None: + device = model_device + else: + device = _assert_and_get_unique_device(float_module) if device: weight_post_process.to(device) @@ -1076,6 +1085,7 @@ def convert( root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) qat_module_classes = get_qat_module_classes(backend_config) fused_module_classes = get_fused_module_classes(backend_config) + model_device = _assert_and_get_unique_device(model) for node in list(model.graph.nodes): if node.op == "placeholder": @@ -1123,6 +1133,7 @@ def convert( modules, node_name_to_scope, node_name_to_qconfig, + model_device, ) else: _replace_observer_with_quantize_dequantize_node( @@ -1131,6 +1142,7 @@ def convert( modules, node_name_to_scope, node_name_to_qconfig, + model_device, ) elif isinstance(mod, DeQuantStub): _replace_observer_or_dequant_stub_with_dequantize_node( @@ -1160,6 +1172,7 @@ def convert( backend_config, is_decomposed, is_reference, + model_device, ) # remove deadcode after converting observers to quant/dequant ops @@ -1271,9 +1284,6 @@ def _convert_to_reference_decomposed_fx( reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model) """ - torch._C._log_api_usage_once( - "quantization_api.quantize_fx._convert_to_reference_decomposed_fx" - ) return _convert_fx( graph_module, is_reference=True, diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index dd9f0e6c21..a0aef11541 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -27,12 +27,21 @@ _PER_TENSOR_QUANTIZE_OPS = [ quantized_decomposed.quantize_per_tensor.default, quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.torchao.quantize_affine_float8_non_decomposed.default, ] -_VIEW_OPS = [ +_VIEW_FUNCTION_OPS = [ aten.transpose.int, aten.permute.default, aten.view.default, + aten.reshape.default, +] + +_VIEW_METHOD_OPS = [ + "transpose", + "permute", + "view", + "reshape", ] """ @@ -62,7 +71,13 @@ def _get_pattern_output_dtype(match: Match): output_node = pattern_output_nodes[0] assert isinstance(output_node, torch.fx.Node) output_dtype = output_node.meta["val"].dtype - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float32, + torch.bfloat16, + torch.float8_e4m3fn, + ] return output_dtype @@ -121,7 +136,7 @@ def get_dequantize_per_tensor_activation_pattern( ): if is_fp8: dequantize_per_tensor_activation_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, KeywordArg("x"), KeywordArg("x_scale"), output_dtype=KeywordArg("x_dq_dtype"), @@ -327,20 +342,31 @@ def generate_pattern_with_unary(computation_call, unary_post_op): return computation_call -def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): - quantized_op_output_pattern_pt2e = CallFunction( - quantized_decomposed.quantize_per_tensor.default, - _may_generate_pattern_with_dtype_convert( - computation_call, - Arg(), - with_dtype_convert, - ), - KeywordArg("o_inv_scale"), - KeywordArg("o_zp"), - KeywordArg("o_qmin"), - KeywordArg("o_qmax"), - KeywordArg("o_dtype"), +def generate_pattern_with_output_quant( + computation_call, with_dtype_convert=False, is_fp8=False +): + may_generate_pattern_with_dtype_convert = _may_generate_pattern_with_dtype_convert( + computation_call, + Arg(), + with_dtype_convert, ) + if is_fp8: + quantized_op_output_pattern_pt2e = CallFunction( + torch.ops.torchao.quantize_affine_float8_non_decomposed.default, + may_generate_pattern_with_dtype_convert, + KeywordArg("o_inv_scale"), + float8_dtype=KeywordArg("o_dtype"), + ) + else: + quantized_op_output_pattern_pt2e = CallFunction( + quantized_decomposed.quantize_per_tensor.default, + may_generate_pattern_with_dtype_convert, + KeywordArg("o_inv_scale"), + KeywordArg("o_zp"), + KeywordArg("o_qmin"), + KeywordArg("o_qmax"), + KeywordArg("o_dtype"), + ) return quantized_op_output_pattern_pt2e @@ -447,7 +473,10 @@ def fn(match): (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( extra_input_of_binary_node.target - != quantized_decomposed.dequantize_per_tensor.default + not in [ + quantized_decomposed.dequantize_per_tensor.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + ] ) ): return False @@ -479,7 +508,7 @@ def fn(match): if "other" in match.kwargs else ( match.kwargs["accum"] - if (output_dtype in [torch.uint8, torch.int8]) + if (output_dtype in [torch.uint8, torch.int8, torch.float8_e4m3fn]) or (not extra_input_from_dequant) else match.kwargs["accum_after_dequant"] ) @@ -501,7 +530,7 @@ def _inner(match): if dequant_pattern_end_node.target not in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, - torch.ops.torchao.dequantize_affine_float8.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, prims.convert_element_type.default, aten.reshape.default, ]: @@ -531,7 +560,7 @@ def _inner(match): in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, - torch.ops.torchao.dequantize_affine_float8.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] and len(list(dequant_pattern_end_node.users)) > 1 ): @@ -598,7 +627,7 @@ def clone_to_new_node(graph, source_node, user_node): assert dequant_pattern_end_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, - torch.ops.torchao.dequantize_affine_float8.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, prims.convert_element_type.default, aten.reshape.default, ] @@ -611,7 +640,7 @@ def _find_first_node_in_dequant_pattern(_node): if _node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, - torch.ops.torchao.dequantize_affine_float8.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ]: # For a dequant pattern, we expect the start node is a dequantize_per_tensor node return _node @@ -628,7 +657,7 @@ def _find_first_node_in_dequant_pattern(_node): assert dequant_pattern_start_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, - torch.ops.torchao.dequantize_affine_float8.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] # Clone the dequant pattern for each user node @@ -731,10 +760,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): ) dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] - assert ( - dequant_per_channel.target # type: ignore[union-attr] - is quantized_decomposed.dequantize_per_channel.default - ) + assert dequant_per_channel.target in [ # type: ignore[union-attr] + quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + ] # Activation QParams qx, x_zp, x_scale = ( @@ -951,6 +980,7 @@ def _inner(match): assert dequant_node.target in [ quantized_decomposed.dequantize_per_tensor.default, quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] if len(list(dequant_node.users)) != 1: @@ -1007,6 +1037,7 @@ def _register_qlinear_weight_prepack_pass( dtype=torch.float32, input_dim_exceeds_two=False, input_contiguous=True, + is_fp8=False, ): @register_freezing_graph_pattern( pattern, @@ -1022,7 +1053,7 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): | dequant_per_tensor | - mm/addmm <- t <- dequant_per_channel <- int8_weight + mm/addmm <- t <- dequant <- int8_weight Insert weight prepack node and change the pattern to: int8 activation @@ -1054,28 +1085,28 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): t_node = linear_node.args[weight_index] if dtype == torch.float32: - dequant_per_channel = t_node.args[0] + dequant = t_node.args[0] else: weight_to_bf16_node = t_node.args[0] - dequant_per_channel = weight_to_bf16_node.args[0] - assert ( - dequant_per_channel.target - is quantized_decomposed.dequantize_per_channel.default - ) + dequant = weight_to_bf16_node.args[0] + assert dequant.target in [ + quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + ] # Activation QParams - qx, x_zp, x_scale = ( + qx, x_scale = ( kwargs["x"], - kwargs["x_zp"], kwargs["x_scale"], ) # Weight QParams - qw, w_scale, w_zp = ( + qw, w_scale = ( kwargs["q_weight"], kwargs["w_scale"], - kwargs["w_zp"], ) + x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None + w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None # Params bias = kwargs["b"] if "b" in kwargs else None @@ -1112,7 +1143,8 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): "", # post op algorithm ) Node = torch.fx.node.Node - if isinstance(x_scale, Node) and isinstance(x_zp, Node): + # fp8 not need zp + if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8): new_linear_node = graph.call_function( torch.ops.onednn.qlinear_pointwise.tensor, args=new_args ) @@ -1158,7 +1190,7 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): graph.erase_node(t_node) if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] - graph.erase_node(dequant_per_channel) + graph.erase_node(dequant) counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( @@ -1171,6 +1203,7 @@ def _generate_dequant_linear_node_pattern( dtype=torch.float32, input_dim_exceeds_two=False, is_tensor_overload=False, + is_fp8=False, ): assert dtype in [torch.float32, torch.bfloat16] t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -1180,7 +1213,9 @@ def _generate_dequant_linear_node_pattern( KeywordArg("b"), _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload, is_fp8 + ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1197,7 +1232,9 @@ def _generate_dequant_linear_node_pattern( aten.mm.default, _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload, is_fp8 + ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1217,6 +1254,7 @@ def _generate_dequant_bmm_node_pattern( dtype=torch.float32, with_bias=False, is_tensor_overload=False, + is_fp8=False, ): # When activation of linear dim exceed 2 and not contiguous t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -1227,7 +1265,9 @@ def _generate_dequant_bmm_node_pattern( CallFunction( aten.expand.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload, is_fp8 + ), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1259,20 +1299,32 @@ def _generate_qlinear_weight_prepack_patterns( input_contiguous=True, with_bias=False, is_tensor_overload=False, + is_fp8=False, ): + if is_fp8: + dequant_wgt_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + output_dtype=KeywordArg("w_dtype"), + ) + else: + dequant_wgt_pattern = dequantize_per_channel_weight_pattern if input_dim_exceeds_two and not input_contiguous: return _generate_dequant_bmm_node_pattern( - dequantize_per_channel_weight_pattern, + dequant_wgt_pattern, dtype, with_bias, is_tensor_overload, + is_fp8=is_fp8, ) else: return _generate_dequant_linear_node_pattern( - dequantize_per_channel_weight_pattern, + dequant_wgt_pattern, dtype, input_dim_exceeds_two, is_tensor_overload, + is_fp8=is_fp8, ) @@ -1442,15 +1494,23 @@ def _register_qlinear_weight_prepack(): # | OPT(add) | linear_weight_prepack_cases = itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) # Step 1: register patterns from mm and addmm - for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: + for ( + dtype, + input_dim_exceeds_two, + is_tensor_overload, + is_fp8, + ) in linear_weight_prepack_cases: + if is_fp8 and not is_tensor_overload: + continue weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( dtype, input_dim_exceeds_two, is_tensor_overload=is_tensor_overload, + is_fp8=is_fp8, ) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. @@ -1459,6 +1519,7 @@ def _register_qlinear_weight_prepack(): pass_number=1, dtype=dtype, input_dim_exceeds_two=input_dim_exceeds_two, + is_fp8=is_fp8, ) # Step 2: register patterns from bmm @@ -1467,8 +1528,8 @@ def _register_qlinear_weight_prepack(): # https://github.com/pytorch/pytorch/blob/ # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 # in this case, we can convert it back to qlinear - for dtype, with_bias, is_tensor_overload in itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + for dtype, with_bias, is_tensor_overload, is_fp8 in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ): bmm_pattern = _generate_qlinear_weight_prepack_patterns( dtype=dtype, @@ -1476,6 +1537,7 @@ def _register_qlinear_weight_prepack(): input_contiguous=False, with_bias=with_bias, is_tensor_overload=is_tensor_overload, + is_fp8=is_fp8, ) _register_qlinear_weight_prepack_pass( bmm_pattern, @@ -1485,6 +1547,7 @@ def _register_qlinear_weight_prepack(): dtype=dtype, input_dim_exceeds_two=True, input_contiguous=False, + is_fp8=is_fp8, ) @@ -1753,7 +1816,7 @@ def _with_outer_reshape(pattern): KeywordArg("out_shape_with_bias"), ) - # The following patterns are for torchao int8_dynamic_activation_int8_weight linear, + # The following patterns are for torchao Int8DynamicActivationInt8WeightConfig linear, # when both activation and weights are symmetrically quantized. # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used. # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias. @@ -2363,11 +2426,17 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs): b = kwargs["b"] if "b" in kwargs else None # Output QParams - o_inv_scale = ( - kwargs["o_inv_scale"] - if (output_dtype in [torch.uint8, torch.int8]) - else 1.0 - ) + if output_dtype == torch.float8_e4m3fn: + # For float8, torchao.quantize_affine_float8 requires tensor as scale + # Support scale node is full firstly + assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default + o_inv_scale = kwargs["o_inv_scale"].args[1] + else: + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype in [torch.uint8, torch.int8]) + else 1.0 + ) o_zero_point = ( kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0 ) @@ -2446,105 +2515,114 @@ def _register_qlinear_unary_fusion(): _gelu_fusion_2 as _gelu_fusion_tanh, ) - for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + combinations = itertools.product( + [torch.float32, torch.bfloat16], [False, True], [True, False] + ) + for original_pattern_output_dtype, x_scale_zp_are_tensors, is_fp8 in combinations: is_bf16 = original_pattern_output_dtype == torch.bfloat16 - for x_scale_zp_are_tensors in (False, True): - qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) - computation_op = ( - torch.ops.onednn.qlinear_pointwise.tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qlinear_pointwise.default - ) - # Priority 1 to match: QLinear Unary pattern with int8 output - linear_unary_replace_patterns = { - PostOpAttr( - "none", None, "none", [], "" - ): generate_pattern_with_output_quant( - qlinear_pattern, - ), - PostOpAttr( - "none", None, "relu", [], "" - ): generate_pattern_with_output_quant( - generate_pattern_with_unary(qlinear_pattern, aten.relu.default), - ), - PostOpAttr( - "none", None, "gelu", [], "none" - ): generate_pattern_with_output_quant( - _unary_fusion_pattern( - _gelu_fusion_erf, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 2 - ), - 2, - is_bf16, + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + qlinear_pattern, + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), - with_dtype_convert=is_bf16, + 2, + is_bf16, ), - PostOpAttr( - "none", None, "gelu", [], "tanh" - ): generate_pattern_with_output_quant( - _unary_fusion_pattern( - _gelu_fusion_tanh, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 - ), - 4, - is_bf16, + with_dtype_convert=is_bf16, + is_fp8=is_fp8, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 ), - with_dtype_convert=is_bf16, + 4, + is_bf16, ), - } + with_dtype_convert=is_bf16, + is_fp8=is_fp8, + ), + } - for unary_attr, patterns in linear_unary_replace_patterns.items(): - _register_qlinear_post_op_fusion_pass( - patterns, - 3, # pass_number - computation_op, - unary_attr, # unary_attr - ) + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + computation_op, + unary_attr, # unary_attr + ) - # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output - linear_unary_replace_float_out_patterns = { - PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - qlinear_pattern, aten.relu.default - ), - PostOpAttr( - "none", None, "gelu", [], "none" - ): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_erf, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 2 - ), - 2, - is_bf16, + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), - Arg(), + 2, is_bf16, ), - PostOpAttr( - "none", None, "gelu", [], "tanh" - ): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_tanh, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 - ), - 4, - is_bf16, + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 ), - Arg(), + 4, is_bf16, ), - } + Arg(), + is_bf16, + ), + } - for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): - _register_qlinear_post_op_fusion_pass( - patterns, - 4, # pass_number - computation_op, - unary_attr, # unary_attr - ) + for ( + unary_attr, + patterns, + ) in linear_unary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + computation_op, + unary_attr, # unary_attr + ) def _register_qlinear_binary_fusion(): @@ -2610,14 +2688,22 @@ def _register_qlinear_binary_fusion(): # totally 3 patterns (2 are identical) swap_binary_inputs_list = [False, True] int8_mixed_bf16_list = [False, True] + is_fp8_list = [False, True] combinations = itertools.product( unary_postop_list, int8_mixed_bf16_list, swap_binary_inputs_list, convert_dtype_after_binary_list, + is_fp8_list, ) qlinear_binary_replace_patterns = {} - for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + for ( + unary_op, + int8_mixed_bf16, + swap_inputs, + cvt_dtype_binary, + is_fp8, + ) in combinations: if not int8_mixed_bf16 and cvt_dtype_binary: # No convert node after binary node if dtypes are all fp32 continue @@ -2638,6 +2724,7 @@ def _register_qlinear_binary_fusion(): ), unary_postop_dict[unary_op], ), + is_fp8=is_fp8, ) } ) @@ -2764,241 +2851,6 @@ def _register_qlinear_binary_fusion(): ) -def _generate_dequant_fp8_linear_node_pattern(dtype, input_dim_exceeds_two): - # + - - - - | - - - - - - | - - - - + - # | dq_per_tensor dq_per_tensor | - # | | | | - # | OPT(to_bf16) OPT(to_bf16) | - # | | | | - # | OPT(reshape) permute | - # | \ / | - # | addmm/mm | - # | | | - # | OPT(quant_per_tensor) | - # | | | - # | OPT(reshape) | - assert dtype in [torch.float32, torch.bfloat16] - dequant_wgt_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8.default, - KeywordArg("q_weight"), - KeywordArg("w_scale"), - output_dtype=KeywordArg("w_dtype"), - ) - t_pattern = CallFunction( - aten.permute.default, - _may_generate_pattern_with_dtype_convert( - dequant_wgt_pattern, - KeywordArg("autocast_wgt_dtype"), - dtype == torch.bfloat16, - ), - KeywordArg("permute_axes"), - ) - dequantize_per_tensor_activation_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8.default, - KeywordArg("x"), - KeywordArg("x_scale"), - output_dtype=KeywordArg("x_dq_dtype"), - ) - - dequant_fp8_linear_bias_pattern = _may_generate_pattern_with_reshape( - CallFunction( - aten.addmm.default, - KeywordArg("b"), - _may_generate_pattern_with_reshape( - _may_generate_pattern_with_dtype_convert( - dequantize_per_tensor_activation_pattern, - KeywordArg("autocast_act_dtype"), - dtype == torch.bfloat16, - ), - KeywordArg("act_reshape_size"), - input_dim_exceeds_two, - ), - t_pattern, - ), - KeywordArg("output_reshape_size"), - input_dim_exceeds_two, - ) - dequant_fp8_linear_no_bias_pattern = _may_generate_pattern_with_reshape( - CallFunction( - aten.mm.default, - _may_generate_pattern_with_reshape( - _may_generate_pattern_with_dtype_convert( - dequantize_per_tensor_activation_pattern, - KeywordArg("autocast_act_dtype"), - dtype == torch.bfloat16, - ), - KeywordArg("act_reshape_size"), - input_dim_exceeds_two, - ), - t_pattern, - ), - KeywordArg("output_reshape_size"), - input_dim_exceeds_two, - ) - return dequant_fp8_linear_bias_pattern, dequant_fp8_linear_no_bias_pattern - - -def _is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two): - def _inner(match): - input_contiguous = True - # Check dequant pattern has only 1 user. - ( - linear_node, - _, - ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) - - input_index = 1 if linear_node.target is aten.addmm.default else 0 - assert dtype in [torch.float32, torch.bfloat16] - ( - dequant_node, - _, - _, - _, - ) = _get_linear_dq_node( - linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous - ) - assert dequant_node.target is torch.ops.torchao.dequantize_affine_float8.default - - # only support float8_e4m3 input - if dequant_node.meta["eager_input_vals"][0][0].dtype != torch.float8_e4m3fn: - return False - - if len(list(dequant_node.users)) != 1: - # Ensure the dequant pattern only has 1 user - # since we will delete the dequant pattern here - return False - - return True - - return _inner - - -def _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two): - @register_freezing_graph_pattern( - pattern, - extra_check=_is_valid_scaled_mm_pattern(dtype, input_dim_exceeds_two), - pass_number=1, - ) - def scaled_mm_fusion(match: Match, *args, **kwargs): - input_contiguous = True - assert dtype in [torch.float32, torch.bfloat16] - ( - linear_node, - output_reshape_node, - ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) - input_index = 1 if linear_node.target is aten.addmm.default else 0 - weight_index = input_index + 1 - - ( - dequant_node, - act_reshape_node, - activation_to_bf16_node, - act_expand_node, - ) = _get_linear_dq_node( - linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous - ) - - if input_dim_exceeds_two and not input_contiguous: - wgt_expand_node = linear_node.args[weight_index] - assert wgt_expand_node.target is aten.expand.default - t_node = wgt_expand_node.args[0] - else: - t_node = linear_node.args[weight_index] - - if dtype == torch.float32: - dequant_per_tensor = t_node.args[0] - else: - weight_to_bf16_node = t_node.args[0] - dequant_per_tensor = weight_to_bf16_node.args[0] - assert ( - dequant_per_tensor.target - is torch.ops.torchao.dequantize_affine_float8.default - ) - - # Activation QParams - qx, x_scale = ( - kwargs["x"], - kwargs["x_scale"], - ) - - # Weight QParams - qw, w_scale = ( - kwargs["q_weight"], - kwargs["w_scale"], - ) - - # Params - bias = kwargs["b"] if "b" in kwargs else None - - x_shape = qx.meta.get("tensor_meta").shape - if has_free_symbols(x_shape): - # For dynamic shape case, we can't get activation shape ahead of runtime. - x_shape = None - graph = match.graph - with graph.inserting_before(linear_node): - scaled_mm_input_node = qx - if input_dim_exceeds_two: - new_reshape_args: tuple[Any, ...] = (qx, act_reshape_node.args[1]) - new_act_reshape_node = graph.call_function( - torch.ops.aten.reshape.default, args=new_reshape_args - ) - scaled_mm_input_node = new_act_reshape_node - # Insert weight prepack node and the qlinear node - permute_weight_inputs = ( - qw, - t_node.args[1], - ) - permute_weight_op = torch.ops.aten.permute.default - permute_weight_node = graph.call_function( - permute_weight_op, args=permute_weight_inputs - ) - output_scale = torch.tensor(1.0) - new_args: tuple[Any, ...] = ( - scaled_mm_input_node, - permute_weight_node, - x_scale, - w_scale, - bias, - output_scale, # output_scale - dtype, # output_dtype - False, # use_fast_accum - ) - new_linear_node = graph.call_function( - torch.ops.aten._scaled_mm.default, args=new_args - ) - - linear_node.replace_all_uses_with(new_linear_node) - new_linear_node.meta.update(linear_node.meta) - - graph.erase_node(linear_node) - if input_dim_exceeds_two: - graph.erase_node(act_reshape_node) - if dtype == torch.bfloat16: - graph.erase_node(activation_to_bf16_node) - # Erase the dequant pattern - graph.erase_node(dequant_node) - # Erase the dequant per channel pattern - graph.erase_node(t_node) - if dtype == torch.bfloat16: - graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] - graph.erase_node(dequant_per_tensor) - - counters["inductor"]["scaled_mm_matcher_count"] += 1 - counters["inductor"]["scaled_mm_matcher_nodes"] += len(match.nodes) - - -def _register_scaled_mm(): - fp8_linear_weight_prepack_cases = itertools.product( - [torch.float32, torch.bfloat16], [False, True] - ) - for dtype, input_dim_exceeds_two in fp8_linear_weight_prepack_cases: - patterns = _generate_dequant_fp8_linear_node_pattern( - dtype, input_dim_exceeds_two - ) - for pattern in patterns: - _register_scaled_mm_pass(pattern, dtype, input_dim_exceeds_two) - - @functools.lru_cache(None) def _register_quantization_weight_pack_pass(): # Step 1: Dequant promotion for int8-mixed-fp32/bf16 @@ -3022,8 +2874,6 @@ def _register_quantization_weight_pack_pass(): _register_qlinear_unary_fusion() _register_qlinear_binary_fusion() - _register_scaled_mm() - def quant_lift_up(module_graph: torch.fx.graph.Graph): """ @@ -3061,7 +2911,23 @@ def quant_lift_up(module_graph: torch.fx.graph.Graph): """ def is_view_op(node): - return node.op == "call_function" and node.target in _VIEW_OPS + return (node.op == "call_function" and node.target in _VIEW_FUNCTION_OPS) or ( + node.op == "call_method" and node.target in _VIEW_METHOD_OPS + ) + + def quant_input_check(node): + if len(node.all_input_nodes) == 1: + return True + elif ( + node.target + == torch.ops.torchao.quantize_affine_float8_non_decomposed.default + ): + # check if scale created by torch.tensor + return ( + len(node.all_input_nodes) == 2 + and node.all_input_nodes[1].target == torch.tensor + ) + return False for node in module_graph.nodes: # Leslie: Here we verify that the quant node has exactly @@ -3072,23 +2938,23 @@ def is_view_op(node): if ( node.op == "call_function" and node.target in _PER_TENSOR_QUANTIZE_OPS - and len(node.all_input_nodes) == 1 + and quant_input_check(node) and is_view_op(node.all_input_nodes[0]) ): quant_node = node - input_node_of_quant = quant_node.args[0] + input_node_of_quant = quant_node.all_input_nodes[0] # Check the nodes along lift up path has only 1 user node # Propagate view like node to find where to insert the new quant node could_lift_up = True current_node = quant_node - input_node = current_node.args[0] + input_node = current_node.all_input_nodes[0] while is_view_op(input_node): if len(input_node.users) != 1: could_lift_up = False break current_node = input_node - input_node = current_node.args[0] + input_node = current_node.all_input_nodes[0] # Further check the input node of the first view node has only 1 user node if could_lift_up and len(input_node.users) == 1: diff --git a/torchao/quantization/pt2e/lowering.py b/torchao/quantization/pt2e/lowering.py index 76dad800cd..c0b4a3538b 100644 --- a/torchao/quantization/pt2e/lowering.py +++ b/torchao/quantization/pt2e/lowering.py @@ -55,7 +55,7 @@ def _node_replace(m): # type: ignore[no-untyped-def] m.recompile() lowered_model = ( - torch.export.export_for_training(model, example_inputs, strict=True) + torch.export.export(model, example_inputs, strict=True) .run_decompositions(_post_autograd_decomp_table()) .module() ) diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py index b781f5a07e..de906f2f61 100644 --- a/torchao/quantization/pt2e/observer.py +++ b/torchao/quantization/pt2e/observer.py @@ -27,6 +27,7 @@ from torch.fx import Node import torchao +from torchao.quantization import Granularity from torchao.quantization.pt2e.utils import ( calculate_qmin_qmax, check_min_max_valid, @@ -67,17 +68,9 @@ "ReuseInputObserver", "UniformQuantizationObserverBase", "AffineQuantizedObserverBase", - "Granularity", "MappingType", - "PerAxis", - "PerBlock", - "PerGroup", - "PerRow", - "PerTensor", - "PerToken", "TorchAODType", "ZeroPointDomain", - "get_block_size", ] @@ -1248,7 +1241,7 @@ def _combine_histograms( # If the orig hist only has one value (i.e., the min and max are the same) # we can just add it into new histogram if orig_min == orig_max: - bin_value = torch.sum(update_hist) + bin_value = torch.sum(orig_hist) transformed_orig_hist = ( torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max) # type: ignore[arg-type] * bin_value @@ -1622,7 +1615,6 @@ def calculate_qparams(self): We plan to merge the following with torchao repo after we move pt2e flow to torchao copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py """ -from dataclasses import dataclass from enum import Enum, auto @@ -1679,139 +1671,6 @@ class TorchAODType(Enum): INT7 = auto() -@dataclass(frozen=True) -class Granularity: - """ - Base class for representing the granularity of quantization. - - This class serves as a parent for specific granularity types used in - quantization operations, such as per-tensor or per-axis quantization. - """ - - -@dataclass(frozen=True) -class PerBlock(Granularity): - """ - Represents per-block granularity in quantization. See - :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for - `block_size` - - Attributes: - block_size (Tuple[int, ...]): The size of each quantization group - """ - - block_size: tuple[int, ...] - - -@dataclass(frozen=True) -class PerTensor(Granularity): - """ - Represents per-tensor granularity in quantization. - - This granularity type calculates the quantization parameters - based off the entire tensor. - - """ - - -@dataclass(frozen=True) -class PerAxis(Granularity): - """ - Represents per-axis granularity in quantization. - - This granularity type calculates different quantization parameters - along a specified axis of the tensor. - - For example if the input tensor is shape [8, 16] and axis=0, then - the quantization parameters are calculated for each row of the tensor. - Giving a total of 8 quantization parameters. - - Attributes: - axis (int): The axis along which reduction is performed. - """ - - axis: int - - -@dataclass(frozen=True) -class PerGroup(Granularity): - """ - Represents per-channel group granularity in quantization. - - This granularity type calculates different quantization parameters - for each group of elements. - - For example if the input tensor is shape [8, 16], and the group size is 4, then - the input tensor is reshaped to [64, 4] - quantization parameters are calculated for each group of 4 elements, - giving a total of 64 quantization parameters. - - Attributes: - group_size (int): The size of each quantization group - - """ - - group_size: int - - -class PerRow(Granularity): - """ - Represents row-wise granularity in quantization. - - This is a special case of per-axis quantization and is unique to Float8 matmuls - where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight - is quantized with a block_size of (1, weight.shape[1]). - """ - - -class PerToken(Granularity): - """ - Represents per-token granularity in quantization. - - This granularity type calculates a different set of quantization parameters - for each token, which is represented as the last dimension of the tensor. - - For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens - with 4 elements each, and we will calculate 6 sets of quantization parameters, - one for each token. - - If the input tensor has only two dimensions, e.g. [8, 16], then this is - equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. - """ - - -def get_block_size( - input_shape: tuple[int, ...], granularity: Granularity -) -> tuple[int, ...]: - """Get the block size based on the input shape and granularity type. - - Args: - input_shape: The input tensor shape possibly more than 2 dimensions - granularity: The granularity type of the quantization - """ - assert isinstance(granularity, Granularity), ( - "Please provide an instance of Granularity, not subclass of it" - ) - if isinstance(granularity, PerTensor): - return input_shape - elif isinstance(granularity, PerAxis): - block_size = list(input_shape) - block_size[granularity.axis] = 1 - return tuple(block_size) - elif isinstance(granularity, PerRow): - return (1,) * (len(input_shape) - 1) + (input_shape[-1],) - elif isinstance(granularity, PerGroup): - assert len(input_shape) == 2, ( - f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" - ) - return (1, granularity.group_size) - elif isinstance(granularity, PerToken): - block_size = [1] * len(input_shape) - block_size[-1] = input_shape[-1] - return tuple(block_size) - raise ValueError(f"Unsupported Granularity: {granularity}") - - class AffineQuantizedObserverBase(ABC, torch.nn.Module): """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) @@ -1877,13 +1736,6 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): observer_node: the observer node to convert """ - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - - if not TORCH_VERSION_AT_LEAST_2_5: - raise NotImplementedError( - "convert for AffineQuantization is not implemented for pytorch version earlier than 2.5, please upgrade your pytorch to 2.5+." - ) - from torchao.quantization.pt2e.utils import create_getattr_from_value with model.graph.inserting_before(observer_node): @@ -1915,10 +1767,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): else: scale, zero_point = self.calculate_qparams() scale_node = create_getattr_from_value( - model, model.graph, "_scale", scale + model, + model.graph, + "_scale", + scale, + scale.device if isinstance(scale, torch.Tensor) else None, ) zero_point_node = create_getattr_from_value( - model, model.graph, "_zero_point", zero_point + model, + model.graph, + "_zero_point", + zero_point, + zero_point.device if isinstance(zero_point, torch.Tensor) else None, ) q_node = model.graph.call_function( diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py index 97801f993c..fa9869c915 100644 --- a/torchao/quantization/pt2e/prepare.py +++ b/torchao/quantization/pt2e/prepare.py @@ -13,10 +13,7 @@ from torch._subclasses import FakeTensor from torch.ao.quantization import QConfigMapping from torch.ao.quantization.fx.custom_config import PrepareCustomConfig -from torch.ao.quantization.fx.prepare import ( - _insert_obs_or_fq, - _save_state, -) +from torch.ao.quantization.fx.prepare import _insert_obs_or_fq, _save_state from torch.ao.quantization.qconfig import QConfigAny from torch.fx import Graph, GraphModule, Node from torch.fx.node import Argument @@ -26,9 +23,7 @@ DerivedObserverOrFakeQuantize, ObserverOrFakeQuantize, ) -from torchao.quantization.pt2e.fake_quantize import ( - FixedQParamsFakeQuantize, -) +from torchao.quantization.pt2e.fake_quantize import FixedQParamsFakeQuantize from torchao.quantization.pt2e.observer import ( FixedQParamsObserver, PartialWrapper, @@ -42,7 +37,8 @@ QuantizationSpecBase, SharedQuantizationSpec, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +from torchao.utils import _assert_and_get_unique_device # TODO: make pt2e folder private? __all__ = [ @@ -208,8 +204,8 @@ def _get_edge_or_node_to_qspec( """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes""" edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {} for n in model.graph.nodes: - if hasattr(n, "meta") and "quantization_annotation" in n.meta: - qa = n.meta["quantization_annotation"] + if hasattr(n, "meta") and Q_ANNOTATION_KEY in n.meta: + qa = n.meta[Q_ANNOTATION_KEY] for input_to_n, qspec in qa.input_qspec_map.items(): input_edge = (input_to_n, n) edge_or_node_to_qspec[input_edge] = qspec @@ -324,7 +320,7 @@ def _get_edge_or_node_to_group_id( assert isinstance(input_edge, tuple) arg, n = input_edge - if n.meta["quantization_annotation"].allow_implicit_sharing: + if n.meta[Q_ANNOTATION_KEY].allow_implicit_sharing: # NOTE: the order is important here, we first share with other users and then share with previous # output because the reverse order could cause circular dependency # e.g node1 -> node2 @@ -413,6 +409,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( named_modules: dict[str, torch.nn.Module], obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ) -> Argument: """ Given a `node` and an `arg`, inserts an input observer between @@ -431,6 +428,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( named_modules, obs_or_fq_map, is_qat, + model_device, ) new_arg_to_return.append(new_inner_arg) return type(arg)(new_arg_to_return) @@ -483,6 +481,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( return maybe_obs_node assert isinstance(model.graph, Graph) + # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901 new_arg = _insert_obs_or_fq( arg, input_edge_obs_or_fq, model, named_modules, model.graph ) @@ -496,6 +495,7 @@ def _maybe_insert_input_observers_for_node( named_modules: dict[str, torch.nn.Module], obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ) -> None: """ If needed, inserts observers to the input args and kwargs of `node`. @@ -522,6 +522,7 @@ def _maybe_insert_input_observers_for_node( named_modules, obs_or_fq_map, is_qat, + model_device, ) new_args.append(new_arg) @@ -546,9 +547,11 @@ def _maybe_insert_output_observer_for_node( graph: Graph, obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ) -> Optional[Node]: if node in obs_or_fq_map: output_act_obs_or_fq = obs_or_fq_map[node] + # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901 new_output = _insert_obs_or_fq( node, output_act_obs_or_fq, model, named_modules, graph ) @@ -557,7 +560,6 @@ def _maybe_insert_output_observer_for_node( isinstance(node, Node) and isinstance(new_output, Node) and FROM_NODE_KEY in node.meta - and TORCH_VERSION_AT_LEAST_2_6 ): new_output.meta[FROM_NODE_KEY] = node.meta[FROM_NODE_KEY] return new_output @@ -569,11 +571,10 @@ def _maybe_insert_input_and_output_observers_for_node( model: torch.fx.GraphModule, obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ): this_node_quantization_annotation = ( - node.meta["quantization_annotation"] - if "quantization_annotation" in node.meta - else None + node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None ) if this_node_quantization_annotation is None: return @@ -586,6 +587,7 @@ def _maybe_insert_input_and_output_observers_for_node( named_modules, obs_or_fq_map, is_qat, + model_device, ) output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) @@ -594,7 +596,13 @@ def _maybe_insert_input_and_output_observers_for_node( # this returns the new observer node if it was needed maybe_output_obs_node = _maybe_insert_output_observer_for_node( - node, model, named_modules, model.graph, obs_or_fq_map, is_qat + node, + model, + named_modules, + model.graph, + obs_or_fq_map, + is_qat, + model_device, ) if maybe_output_obs_node is None: @@ -642,11 +650,16 @@ def prepare( ) if obs_or_fq_callback: obs_or_fq_callback(model, obs_or_fq_map) + model_device = _assert_and_get_unique_device(model) for node in nodes_before_observation: # TODO: simplify logic for inserting observers _maybe_insert_input_and_output_observers_for_node( - node, model, obs_or_fq_map, is_qat + node, + model, + obs_or_fq_map, + is_qat, + model_device, ) model = GraphModule(model, model.graph) diff --git a/torchao/quantization/pt2e/quantize_pt2e.py b/torchao/quantization/pt2e/quantize_pt2e.py index 5eb385b7de..8a7314359b 100644 --- a/torchao/quantization/pt2e/quantize_pt2e.py +++ b/torchao/quantization/pt2e/quantize_pt2e.py @@ -6,9 +6,9 @@ import torch -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import torch_version_at_least -if TORCH_VERSION_AT_LEAST_2_7: +if torch_version_at_least("2.7.0"): from .constant_fold import constant_fold from typing import Union @@ -46,7 +46,7 @@ def prepare_pt2e( """Prepare a model for post training quantization Args: - * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export` API. * `quantizer`: A backend specific quantizer that conveys how user want the model to be quantized. Tutorial for how to write a quantizer can be found here: https://pytorch.org/tutorials/prototype/pt2e_quantizer.html @@ -84,7 +84,7 @@ def calibrate(model, data_loader): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured # result shoud mostly stay the same - m = torch.export.export_for_training(m, *example_inputs).module() + m = torch.export.export(m, *example_inputs).module() # we get a model with aten ops # Step 2. quantization @@ -106,7 +106,7 @@ def calibrate(model, data_loader): return torch_prepare_pt2e(model, quantizer) - torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e") + torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_pt2e") original_graph_meta = model.meta node_name_to_scope = _get_node_name_to_scope(model) # TODO: check qconfig_mapping to make sure conv and bn are both configured @@ -169,7 +169,7 @@ def train_loop(model, train_data): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured # result shoud mostly stay the same - m = torch.export.export_for_training(m, *example_inputs).module() + m = torch.export.export(m, *example_inputs).module() # we get a model with aten ops # Step 2. quantization @@ -192,7 +192,7 @@ def train_loop(model, train_data): return torch_prepare_qat_pt2e(model, quantizer) - torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e") + torch._C._log_api_usage_once("torchao.quantization.pt2e.prepare_qat_pt2e") original_graph_meta = model.meta node_name_to_scope = _get_node_name_to_scope(model) model = quantizer.transform_for_annotation(model) @@ -217,14 +217,9 @@ def train_loop(model, train_data): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.torchao.quantize_affine, ] -# ops are only registered after 2.5 -if TORCH_VERSION_AT_LEAST_2_5: - _QUANT_OPS += [ - torch.ops.torchao.quantize_affine, - ] - def _quant_node_constraint(n: Node) -> bool: """If there is any pure ops between get_attr and quantize op they will be const propagated @@ -309,7 +304,7 @@ def convert_pt2e( return torch_convert_pt2e(model, use_reference_representation, fold_quantize) - torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e") + torch._C._log_api_usage_once("torchao.quantization.pt2e.convert_pt2e") if not isinstance(use_reference_representation, bool): raise ValueError( "Unexpected argument type for `use_reference_representation`, " @@ -325,7 +320,7 @@ def convert_pt2e( pm = PassManager([PortNodeMetaForQDQ()]) model = pm(model).graph_module - if fold_quantize and TORCH_VERSION_AT_LEAST_2_7: + if fold_quantize and torch_version_at_least("2.7.0"): constant_fold(model, _quant_node_constraint) if use_reference_representation: diff --git a/torchao/quantization/pt2e/quantizer/composable_quantizer.py b/torchao/quantization/pt2e/quantizer/composable_quantizer.py index 6602151e3f..2fd53117d7 100644 --- a/torchao/quantization/pt2e/quantizer/composable_quantizer.py +++ b/torchao/quantization/pt2e/quantizer/composable_quantizer.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY + from .quantizer import QuantizationAnnotation, Quantizer if TYPE_CHECKING: @@ -48,18 +50,17 @@ def _record_and_validate_annotations( self, gm: torch.fx.GraphModule, quantizer: Quantizer ) -> None: for n in gm.graph.nodes: - if "quantization_annotation" in n.meta: + if Q_ANNOTATION_KEY in n.meta: # check if the annotation has been changed by # comparing QuantizationAnnotation object id if n in self._graph_annotations and ( - id(self._graph_annotations[n]) - != id(n.meta["quantization_annotation"]) + id(self._graph_annotations[n]) != id(n.meta[Q_ANNOTATION_KEY]) ): raise RuntimeError( f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}" ) else: - self._graph_annotations[n] = n.meta["quantization_annotation"] + self._graph_annotations[n] = n.meta[Q_ANNOTATION_KEY] else: if n in self._graph_annotations: raise RuntimeError( diff --git a/torchao/quantization/pt2e/quantizer/duplicate_dq_pass.py b/torchao/quantization/pt2e/quantizer/duplicate_dq_pass.py index 2bf4e732c1..3e2d36e88e 100644 --- a/torchao/quantization/pt2e/quantizer/duplicate_dq_pass.py +++ b/torchao/quantization/pt2e/quantizer/duplicate_dq_pass.py @@ -12,13 +12,10 @@ from torch.fx.node import map_arg from torch.fx.passes.infra.pass_base import PassBase, PassResult -from torchao.quantization.pt2e.utils import ( - _filter_sym_size_users, -) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +from torchao.quantization.pt2e.utils import _filter_sym_size_users -from .utils import ( - is_valid_annotation, -) +from .utils import is_valid_annotation logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -41,7 +38,7 @@ def _maybe_duplicate_dq( gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node ): - annotation = user.meta.get("quantization_annotation", None) + annotation = user.meta.get(Q_ANNOTATION_KEY, None) if not is_valid_annotation(annotation): return with gm.graph.inserting_after(dq_node): diff --git a/torchao/quantization/pt2e/quantizer/embedding_quantizer.py b/torchao/quantization/pt2e/quantizer/embedding_quantizer.py index 40979f6fe8..fdd7ccebd1 100644 --- a/torchao/quantization/pt2e/quantizer/embedding_quantizer.py +++ b/torchao/quantization/pt2e/quantizer/embedding_quantizer.py @@ -21,6 +21,7 @@ QuantizationSpec, Quantizer, ) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY __all__ = [ "get_embedding_operators_config", @@ -87,7 +88,7 @@ def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None: raise ValueError( "Embedding config must have a valid weight quantization spec." ) - node.meta["quantization_annotation"] = QuantizationAnnotation( + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map={ node.args[0]: embedding_config.config.weight, } diff --git a/torchao/quantization/pt2e/quantizer/port_metadata_pass.py b/torchao/quantization/pt2e/quantizer/port_metadata_pass.py index b0d910e603..5e7e9344ee 100644 --- a/torchao/quantization/pt2e/quantizer/port_metadata_pass.py +++ b/torchao/quantization/pt2e/quantizer/port_metadata_pass.py @@ -12,18 +12,12 @@ from torch._export.error import InternalError from torch.fx.passes.infra.pass_base import PassBase, PassResult -from torchao.quantization.pt2e.utils import ( - _filter_sym_size_users, -) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +from torchao.quantization.pt2e.utils import _filter_sym_size_users from torchao.quantization.quant_primitives import quant_lib # noqa: F401 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -from .quantizer import ( - QuantizationSpecBase, -) -from .utils import ( - is_valid_annotation, -) +from .quantizer import QuantizationSpecBase +from .utils import is_valid_annotation logger = logging.getLogger(__name__) logger.setLevel(logging.ERROR) @@ -39,27 +33,23 @@ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.torchao.quantize_affine, ] _DEQUANTIZE_OPS = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.torchao.dequantize_affine, ] _CHOOSE_QPARAMS_OPS = [ torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.torchao.choose_qparams_affine, ] -# ops are only registered after 2.5 -if TORCH_VERSION_AT_LEAST_2_5: - _QUANTIZE_OPS += [torch.ops.torchao.quantize_affine] - _DEQUANTIZE_OPS += [torch.ops.torchao.dequantize_affine] - _CHOOSE_QPARAMS_OPS += [torch.ops.torchao.choose_qparams_affine] - - def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: from_meta = from_node.meta for meta_name in _METADATA_TO_PORT: @@ -68,7 +58,7 @@ def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: def _has_quant_annotation(node: torch.fx.Node) -> bool: - return "quantization_annotation" in node.meta + return Q_ANNOTATION_KEY in node.meta def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: @@ -281,10 +271,10 @@ class PortNodeMetaForQDQ(PassBase): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in graph_module.graph.nodes: - annotation = node.meta.get("quantization_annotation", None) + annotation = node.meta.get(Q_ANNOTATION_KEY, None) if is_valid_annotation(annotation): - input_qspec_map = node.meta["quantization_annotation"].input_qspec_map - output_qspec = node.meta["quantization_annotation"].output_qspec + input_qspec_map = node.meta[Q_ANNOTATION_KEY].input_qspec_map + output_qspec = node.meta[Q_ANNOTATION_KEY].output_qspec for input_node, qspec in input_qspec_map.items(): _port_metadata_for_input_quant_nodes(input_node, node, qspec) _port_metadata_for_output_quant_nodes(node, output_qspec) diff --git a/torchao/quantization/pt2e/quantizer/quantizer.py b/torchao/quantization/pt2e/quantizer/quantizer.py index 479a2a678f..1f0916f59c 100644 --- a/torchao/quantization/pt2e/quantizer/quantizer.py +++ b/torchao/quantization/pt2e/quantizer/quantizer.py @@ -30,6 +30,9 @@ ] +Q_ANNOTATION_KEY = "quantization_annotation" + + class QuantizationSpecBase(ABC): # noqa: B024 """Base class for different types of quantization specs that allows users to specify how to quantize a Tensor (input/output of a Node) in the model diff --git a/torchao/quantization/pt2e/quantizer/utils.py b/torchao/quantization/pt2e/quantizer/utils.py index f84ae44817..8f493a8521 100644 --- a/torchao/quantization/pt2e/quantizer/utils.py +++ b/torchao/quantization/pt2e/quantizer/utils.py @@ -13,6 +13,8 @@ import torch from torch.fx import Node +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY + from .quantizer import QuantizationAnnotation, QuantizationSpec @@ -103,21 +105,17 @@ def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): def annotate_input_qspec_map(node: Node, input_node: Node, qspec): - quantization_annotation = node.meta.get( - "quantization_annotation", QuantizationAnnotation() - ) + quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, QuantizationAnnotation()) if quantization_annotation.input_qspec_map is None: quantization_annotation.input_qspec_map = {} quantization_annotation.input_qspec_map[input_node] = qspec - node.meta["quantization_annotation"] = quantization_annotation + node.meta[Q_ANNOTATION_KEY] = quantization_annotation def annotate_output_qspec(node: Node, qspec): - quantization_annotation = node.meta.get( - "quantization_annotation", QuantizationAnnotation() - ) + quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, QuantizationAnnotation()) quantization_annotation.output_qspec = qspec - node.meta["quantization_annotation"] = quantization_annotation + node.meta[Q_ANNOTATION_KEY] = quantization_annotation def get_module_name_filter(module_name: str): diff --git a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py index 84a66447c1..656f4fbbeb 100644 --- a/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py +++ b/torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py @@ -1634,8 +1634,8 @@ def validate(self, model: torch.fx.GraphModule) -> None: _register_quantization_weight_pack_pass, quant_lift_up, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_8 +from torchao.utils import torch_version_at_least -if TORCH_VERSION_AT_LEAST_2_8: +if torch_version_at_least("2.8.0"): torch._inductor.config.pre_grad_custom_pass = quant_lift_up _register_quantization_weight_pack_pass() diff --git a/torchao/quantization/pt2e/reference_representation_rewrite.py b/torchao/quantization/pt2e/reference_representation_rewrite.py index 6526c6044f..8df9f5537d 100644 --- a/torchao/quantization/pt2e/reference_representation_rewrite.py +++ b/torchao/quantization/pt2e/reference_representation_rewrite.py @@ -8,13 +8,14 @@ import contextlib from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional import torch from torch._higher_order_ops.out_dtype import out_dtype from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torch.fx import GraphModule -from torch.fx.subgraph_rewriter import replace_pattern +from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch +from torch.fx.subgraph_rewriter import ReplacedPatterns, replace_pattern_with_filters from torchao.quantization.pt2e.export_utils import WrapperModule from torchao.quantization.pt2e.utils import ( @@ -23,12 +24,17 @@ _replace_literals_with_new_placeholders, remove_tensor_overload_for_qdq_ops, ) +from torchao.quantization.quant_primitives import MappingType +from torchao.quantization.utils import _get_per_token_block_size +from torchao.utils import _register_custom_op try: from torch._export.utils import _disable_aten_to_metadata_assertions except: _disable_aten_to_metadata_assertions = contextlib.nullcontext +quant_lib = torch.library.Library("torchao", "FRAGMENT") +register_custom_op = _register_custom_op(quant_lib) __all__ = [ "reference_representation_rewrite", @@ -203,6 +209,280 @@ def _reference_dynamic_quantized_linear( return out_fp32 +def _qdq_dynamic_quantized_linear_4bit_groupwise( + x_fp32, + x_eps, + weight_i4, + weight_scale, + weight_zero_point, + bias_fp32, + group_size, +): + # Dynamic quantization of activation + x_mapping_type = MappingType.ASYMMETRIC + per_token_block_size = _get_per_token_block_size(x_fp32) + x_quant_min = -128 + x_quant_max = 127 + x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine( + x_fp32, + x_mapping_type.name, + per_token_block_size, + torch.int8, + x_quant_min, + x_quant_max, + x_eps, + torch.float32, + torch.int32, + ) + x_i8 = torch.ops.torchao.quantize_affine( + x_fp32, + per_token_block_size, + x_scale, + x_zero_point, + torch.int8, + x_quant_min, + x_quant_max, + ) + x_fp32 = torch.ops.torchao.dequantize_affine( + x_i8, + per_token_block_size, + x_scale, + x_zero_point, + torch.int8, + x_quant_min, + x_quant_max, + torch.float32, + ) + + assert group_size > 0, "Group size must be positive" + assert weight_i4.shape[1] % group_size == 0, ( + "Weight must be divisible by group_size" + ) + assert weight_i4.dim() == 2, "Weight must be 2D tensor" + block_size = (1, group_size) + weight_fp32 = torch.ops.torchao.dequantize_affine( + weight_i4, + block_size, + weight_scale, + weight_zero_point, + torch.int8, + -8, + 7, + ) + + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + return out_fp32 + + +@register_custom_op +def _reference_dqlinear_int4( + x_fp32: torch.Tensor, + x_eps: float, + weight_i4: torch.Tensor, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, # Not used because assuming weight is symmetric + bias_fp32: Optional[torch.Tensor], + group_size: List[int], +) -> torch.Tensor: + """ + Reference implementation for dynamically quantized linear 4-bit groupwise operation. + This implementation emulates actual numerics of on-device integer compute. + + Args: + x_fp32: Input activation tensor in fp32 + x_eps: Epsilon for quantization parameter computation + weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7]) + weight_scale: Groupwise scales for weight dequantization + weight_zero_point: Groupwise zero points for weight (unused for symmetric) + bias_fp32: Optional bias tensor in fp32 + group_size: Size of each group for groupwise quantization + + Returns: + Output tensor in fp32 + """ + # Dynamic quantization of activation + group_size = group_size[1] + x_mapping_type = MappingType.ASYMMETRIC + per_token_block_size = _get_per_token_block_size(x_fp32) + x_quant_min = -128 + x_quant_max = 127 + x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine( + x_fp32, + x_mapping_type.name, + per_token_block_size, + torch.int8, + x_quant_min, + x_quant_max, + x_eps, + torch.float32, + torch.int32, + ) + x_i8 = torch.ops.torchao.quantize_affine( + x_fp32, + per_token_block_size, + x_scale, + x_zero_point, + torch.int8, + x_quant_min, + x_quant_max, + ) + + # For groupwise quantization, we need to handle the computation differently + # weight_i4 shape: [out_features, in_features] + # weight_scale shape: [out_features, in_features // group_size] + # weight_zero_point shape: [out_features, in_features // group_size] + out_features, in_features = weight_i4.shape + num_groups = in_features // group_size + + # scales in xnnpack are stored as bf16 and converted to fp32 for computation + weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32) + + # Reshape for group-wise processing + # x: [batch_size, in_features] -> [batch_size, num_groups, group_size] + x_orig_shape = x_i8.shape + k_dim = x_i8.shape[-1] + x_i8 = x_i8.view(-1, k_dim) + batch_size = x_i8.shape[0] + x_i8_grouped = x_i8.view(batch_size, num_groups, group_size) + + # weight: [out_features, in_features] -> [out_features, num_groups, group_size] + weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size) + + # Convert to int16 for computation + x_i32_grouped = x_i8_grouped.to(torch.int32) + weight_i32_grouped = weight_i4_grouped.to(torch.int32) + + # Perform groupwise integer linear operation + acc_fp32 = torch.zeros( + batch_size, out_features, dtype=torch.float32, device=x_fp32.device + ) + out_shape = list(x_orig_shape) + out_shape[-1] = out_features + + if weight_scale.ndim == 1: + weight_scale = weight_scale.unsqueeze(0) + + for group_idx in range(num_groups): + # Extract current group + x_group = x_i32_grouped[:, group_idx, :] # [batch_size, group_size] + weight_group = weight_i32_grouped[:, group_idx, :] # [out_features, group_size] + weight_group_col_sum = weight_group.sum(dim=-1) # [out_features] + + # Get scale for this group + weight_scale_group = weight_scale[:, group_idx] # [out_features] + + # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features] + group_acc = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_group, + weight_group, + None, + ) + + # Output has to be scaled by x_scale * weight_scale_group + # However we will first scale by weight_scale_group, that is accounting + # only for scale of weight, and then scale by x_scale at the end because + # x_scale applies to all groups + acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view( + 1, -1 + ) + + # we must also subtract x_zero_point * weight_group_sum + # since (X - x_zero_point) * W = X * W - x_zero_point * W + weights_col_sum_adjusted = ( + weight_group_col_sum.to(torch.float32).view(1, -1) + * x_zero_point.view(-1, 1) + * weight_scale_group.view(1, -1) + ) + acc_fp32 = acc_fp32 - weights_col_sum_adjusted + x_scale_multiplier = x_scale.view(-1, 1) + out_fp32 = acc_fp32 * x_scale_multiplier + if bias_fp32 is not None: + out_fp32 = out_fp32 + bias_fp32 + + return out_fp32.view(out_shape) + + +def _reference_dynamic_quantized_linear_4bit_groupwise( + x_fp32, + x_eps, + weight_i4, + weight_scale, + weight_zero_point, # Not used because assuming weight is symmetric + bias_fp32, + group_size, +): + """ + Reference implementation for dynamically quantized linear 4-bit groupwise operation. + This function now delegates to the custom op implementation. + """ + return torch.ops.torchao.reference_dqlinear_int4( + x_fp32, + x_eps, + weight_i4, + weight_scale, + weight_zero_point, + bias_fp32, + (1, group_size), + ) + + +def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise( + match, + original_graph, + pattern_graph, +) -> bool: + weight_is_int4 = False + act_quant_is_int8 = False + for node in match.nodes_map.values(): + if ( + isinstance(node, torch.fx.Node) + and node.op == "call_function" + and node.target == torch.ops.torchao.dequantize_affine.default + ): + args = node.args + if len(args) >= 7: + weight_is_int4 = args[5] == -8 and args[6] == 7 + if ( + isinstance(node, torch.fx.Node) + and node.op == "call_function" + and node.target == torch.ops.torchao.quantize_affine.default + ): + args = node.args + if len(args) >= 5: + act_quant_is_int8 = args[4] == torch.int8 + return weight_is_int4 and act_quant_is_int8 + + +def _port_metadata_for_dynamic_quantized_linear_4bit_groupwise( + replacement_pattern: ReplacedPatterns, +): + """ + Port metadata for dynamically quantized linear 4-bit groupwise operation. + It custom_op node's metadata with corresponding linear node's metadata. + """ + from torch.fx.traceback import NodeSource, NodeSourceAction + + linear_node = None + int4_custom_op_node = None + for _, g_n in replacement_pattern.nodes_map.items(): + if g_n.target == torch.ops.aten.linear.default: + linear_node = g_n + break + if len(replacement_pattern.replacements) > 0: + int4_custom_op_node = replacement_pattern.replacements[-1] + if linear_node is not None and int4_custom_op_node is not None: + int4_custom_op_node.meta = linear_node.meta.copy() + int4_custom_op_node.meta["from_node"] = [ + NodeSource( + linear_node, + "ReplaceInt4DynamicQuantWithCustomOp", + NodeSourceAction.REPLACE, + ) + ] + + def _qdq_quantized_conv2d( x_i8, x_scale, @@ -627,6 +907,11 @@ class _RewriteInfo: # post transformation on the exported pattern and replacement GraphModule pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + filter_fn: Optional[ + list[Callable[["InternalMatch", torch.fx.Graph, torch.fx.Graph], bool]] + ] = None + ignore_literals: bool = False + port_metadata_fn: Optional[Callable[["ReplacedPatterns"], None]] = None def reference_representation_rewrite(model: GraphModule) -> GraphModule: @@ -738,6 +1023,31 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: 127, ) + _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1 = ( + torch.randn((1, 32), dtype=torch.float), # x_fp32 + torch.finfo(torch.float32).eps, # x_eps + torch.randint(-8, 7, (8, 32), dtype=torch.int8), # weight_i4 (stored as int8) + torch.randn(8, 4, dtype=torch.float), # weight_scale [out_features, num_groups] + torch.zeros( + 8, 4, dtype=torch.int + ), # weight_zero_point [out_features, num_groups] + torch.randn(8, dtype=torch.float), # bias_fp32 + 8, # group_size + ) + + # just saw that we can match again > 2 dim input. Hacky. + _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2 = ( + torch.randn((1, 1, 32), dtype=torch.float), # x_fp32 + torch.finfo(torch.float32).eps, # x_eps + torch.randint(-8, 7, (8, 32), dtype=torch.int8), # weight_i4 (stored as int8) + torch.randn(8, 4, dtype=torch.float), # weight_scale [out_features, num_groups] + torch.zeros( + 8, 4, dtype=torch.int + ), # weight_zero_point [out_features, num_groups] + torch.randn(8, dtype=torch.float), # bias_fp32 + 8, # group_size + ) + _REWRITE_INFO_LIST = [ _RewriteInfo( _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, @@ -752,6 +1062,50 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3}, ), ), + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1, + WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise), + WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + torch.finfo(torch.float32).eps: 1, + (1, 8): 6, + }, + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + torch.finfo(torch.float32).eps: 1, + (1, 8): 6, + }, + ), + filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise], + ignore_literals=True, + port_metadata_fn=_port_metadata_for_dynamic_quantized_linear_4bit_groupwise, + ), + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2, + WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise), + WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + torch.finfo(torch.float32).eps: 1, + (1, 8): 6, + }, + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + torch.finfo(torch.float32).eps: 1, + (1, 8): 6, + }, + ), + filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise], + ignore_literals=True, + port_metadata_fn=_port_metadata_for_dynamic_quantized_linear_4bit_groupwise, + ), _RewriteInfo( _QUANTIZED_LINEAR_EXAMPLE_INPUTS, WrapperModule(_qdq_quantized_linear), @@ -830,6 +1184,15 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: replacement = replacement_post_trans(replacement) pattern.recompile() # type: ignore[attr-defined] replacement.recompile() # type: ignore[attr-defined] - replace_pattern(model, pattern, replacement) + matches = replace_pattern_with_filters( + model, + pattern, + replacement, + match_filters=rewrite_info.filter_fn, + ignore_literals=rewrite_info.ignore_literals, + ) # type: ignore[arg-type] + if rewrite_info.port_metadata_fn: + for m in matches: + rewrite_info.port_metadata_fn(m) # type: ignore[arg-type] return model diff --git a/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py new file mode 100644 index 0000000000..5161e130a0 --- /dev/null +++ b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py @@ -0,0 +1,438 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +import torch.nn as nn + +from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_ +from torchao.quantization.pt2e.reference_representation_rewrite import ( + _qdq_dynamic_quantized_linear_4bit_groupwise, + _reference_dynamic_quantized_linear_4bit_groupwise, + reference_representation_rewrite, +) +from torchao.utils import unwrap_tensor_subclass + + +class TestReferenceRepresentationRewrite(unittest.TestCase): + """Test cases for dynamically quantized linear 4-bit groupwise implementations.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + # This is a bit hacked since it makes all tests pass + # purpose of these tests is to catch no wild regressions and 1e-1 + # is ok for now + torch.manual_seed(78) + + def _get_default_quantization_params(self): + """Get default quantization parameters.""" + return { + "x_eps": torch.finfo(torch.float32).eps, + } + + def _create_test_tensors( + self, batch_size, in_features, out_features, group_size, bias=True + ): + """Create test tensors for the given dimensions.""" + + # Create input activation + x_fp32 = torch.randn(batch_size, in_features, dtype=torch.float32) + + # Create 4-bit quantized weight (stored as int8 with values in [-8, 7]) + weight_i4 = torch.randint(-8, 7, (out_features, in_features), dtype=torch.int8) + + # Create groupwise scales and zero points + num_groups = in_features // group_size + weight_scale = ( + torch.randn(out_features, num_groups, dtype=torch.float32).abs() + 0.01 + ) + weight_zero_point = torch.zeros( + out_features, num_groups, dtype=torch.int8 + ) # Symmetric quantization + + # Create bias if requested + bias_fp32 = torch.randn(out_features, dtype=torch.float32) if bias else None + + return { + "x_fp32": x_fp32, + "weight_i4": weight_i4, + "weight_scale": weight_scale, + "weight_zero_point": weight_zero_point, + "bias_fp32": bias_fp32, + } + + def _run_qdq_implementation(self, tensors, quant_params, group_size): + """Run the QDQ implementation with given tensors and parameters.""" + return _qdq_dynamic_quantized_linear_4bit_groupwise( + x_fp32=tensors["x_fp32"], + x_eps=quant_params["x_eps"], + weight_i4=tensors["weight_i4"], + weight_scale=tensors["weight_scale"], + weight_zero_point=tensors["weight_zero_point"], + bias_fp32=tensors["bias_fp32"], + group_size=group_size, + ) + + def _run_reference_implementation(self, tensors, quant_params, group_size): + """Run the reference implementation with given tensors and parameters.""" + return _reference_dynamic_quantized_linear_4bit_groupwise( + x_fp32=tensors["x_fp32"], + x_eps=quant_params["x_eps"], + weight_i4=tensors["weight_i4"], + weight_scale=tensors["weight_scale"], + weight_zero_point=tensors["weight_zero_point"], + bias_fp32=tensors["bias_fp32"], + group_size=group_size, + ) + + def _assert_basic_properties(self, result, expected_shape): + """Assert basic properties of the result tensor.""" + self.assertEqual(result.shape, expected_shape) + self.assertEqual(result.dtype, torch.float32) + + def _assert_implementations_close( + self, qdq_result, ref_result, atol=1e-1, rtol=1e-1, msg_suffix="" + ): + """Assert that QDQ and reference implementations produce similar results.""" + torch.testing.assert_close( + qdq_result, + ref_result, + atol=atol, + rtol=rtol, + msg=f"QDQ and reference results differ significantly{msg_suffix}", + ) + + def test_qdq_dynamic_quantized_linear_4bit_groupwise_basic(self): + """Test that QDQ implementation runs without errors and produces reasonable output.""" + # Test-specific parameters + batch_size, in_features, out_features, group_size = 2, 32, 8, 8 + + quant_params = self._get_default_quantization_params() + tensors = self._create_test_tensors( + batch_size, in_features, out_features, group_size + ) + + result = self._run_qdq_implementation(tensors, quant_params, group_size) + self._assert_basic_properties(result, (batch_size, out_features)) + + def test_reference_dynamic_quantized_linear_4bit_groupwise_basic(self): + """Test that reference implementation runs without errors and produces reasonable output.""" + # Test-specific parameters + batch_size, in_features, out_features, group_size = 2, 32, 8, 8 + + quant_params = self._get_default_quantization_params() + tensors = self._create_test_tensors( + batch_size, in_features, out_features, group_size + ) + + result = self._run_reference_implementation(tensors, quant_params, group_size) + self._assert_basic_properties(result, (batch_size, out_features)) + + def test_both_implementations_no_bias(self): + """Test both implementations without bias.""" + # Test-specific parameters + batch_size, in_features, out_features, group_size = 1, 16, 4, 8 + + quant_params = self._get_default_quantization_params() + tensors = self._create_test_tensors( + batch_size, in_features, out_features, group_size, bias=False + ) + + qdq_result = self._run_qdq_implementation(tensors, quant_params, group_size) + ref_result = self._run_reference_implementation( + tensors, quant_params, group_size + ) + + self._assert_basic_properties(qdq_result, (batch_size, out_features)) + self._assert_basic_properties(ref_result, (batch_size, out_features)) + self._assert_implementations_close( + qdq_result, ref_result, msg_suffix=" for no-bias case" + ) + + def test_edge_cases_group_size_validation(self): + """Test edge cases and error conditions.""" + # Test-specific parameters + batch_size, in_features, out_features = 1, 32, 8 + + quant_params = self._get_default_quantization_params() + tensors = self._create_test_tensors( + batch_size, in_features, out_features, 8 + ) # Valid group size for tensor creation + + # Test with group_size that doesn't divide in_features evenly + with self.assertRaises(AssertionError): + self._run_qdq_implementation( + tensors, quant_params, 7 + ) # 32 is not divisible by 7 + + # Test with zero group_size + with self.assertRaises(AssertionError): + self._run_qdq_implementation(tensors, quant_params, 0) + + def test_weight_dimension_validation(self): + """Test weight dimension validation.""" + # Test-specific parameters + batch_size, in_features, out_features, group_size = 1, 32, 8, 8 + + quant_params = self._get_default_quantization_params() + tensors = self._create_test_tensors( + batch_size, in_features, out_features, group_size + ) + + # Create 1D weight tensor (should fail) + tensors["weight_i4"] = torch.randint(-8, 7, (in_features,), dtype=torch.int8) + + with self.assertRaises((AssertionError, IndexError)): + self._run_qdq_implementation(tensors, quant_params, group_size) + + def test_different_group_sizes(self): + """Test with different valid group sizes.""" + # Test-specific parameters + batch_size, in_features, out_features = 2, 64, 16 + group_sizes = [8, 16, 32] + + quant_params = self._get_default_quantization_params() + + for group_size in group_sizes: + with self.subTest(group_size=group_size): + tensors = self._create_test_tensors( + batch_size, in_features, out_features, group_size + ) + + qdq_result = self._run_qdq_implementation( + tensors, quant_params, group_size + ) + ref_result = self._run_reference_implementation( + tensors, quant_params, group_size + ) + + self._assert_basic_properties(qdq_result, (batch_size, out_features)) + self._assert_basic_properties(ref_result, (batch_size, out_features)) + self._assert_implementations_close( + qdq_result, ref_result, msg_suffix=f" for group_size={group_size}" + ) + + def test_qdq_vs_reference_implementation_comparison(self): + """Test that QDQ and reference implementations produce similar results with various configurations.""" + # Test-specific parameters + test_cases = [ + (1, 32, 8, 8), + (2, 64, 16, 16), + (4, 128, 32, 32), + ] + + quant_params = self._get_default_quantization_params() + + for batch_size, in_features, out_features, group_size in test_cases: + with self.subTest( + batch_size=batch_size, + in_features=in_features, + out_features=out_features, + group_size=group_size, + ): + # Test with bias + tensors_with_bias = self._create_test_tensors( + batch_size, + in_features, + out_features, + group_size, + bias=True, + ) + + qdq_result = self._run_qdq_implementation( + tensors_with_bias, quant_params, group_size + ) + ref_result = self._run_reference_implementation( + tensors_with_bias, quant_params, group_size + ) + + self.assertEqual(qdq_result.shape, ref_result.shape) + self.assertEqual(qdq_result.shape, (batch_size, out_features)) + + self._assert_implementations_close( + qdq_result, + ref_result, + msg_suffix=f" for shape ({batch_size}, {in_features}, {out_features}) with group_size={group_size}", + ) + + # Test without bias + tensors_no_bias = self._create_test_tensors( + batch_size, + in_features, + out_features, + group_size, + bias=False, + ) + + qdq_result_no_bias = self._run_qdq_implementation( + tensors_no_bias, quant_params, group_size + ) + ref_result_no_bias = self._run_reference_implementation( + tensors_no_bias, quant_params, group_size + ) + + self._assert_implementations_close( + qdq_result_no_bias, + ref_result_no_bias, + msg_suffix=f" for no-bias case with shape ({batch_size}, {in_features}, {out_features}) and group_size={group_size}", + ) + + +class SimpleLinearModel(nn.Module): + """Simple model with linear layers for testing model rewrite functionality.""" + + def __init__(self, input_size=128, hidden_size=64, output_size=32): + super().__init__() + self.linear1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + +class TestModelRewrite(unittest.TestCase): + """Test cases for model rewrite functionality with 8da4w quantization.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + torch.manual_seed(42) + + def test_export_and_rewrite_workflow(self): + """Test the complete export and rewrite workflow.""" + # Create model + model = SimpleLinearModel(input_size=64, hidden_size=32, output_size=16) + example_input = torch.randn(1, 64) + + # Apply 8da4w quantization + quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)) + + # Unwrap tensor subclasses for export compatibility + model = unwrap_tensor_subclass(model) + + # Export model + exported_model = torch.export.export(model, (example_input,)) + + # Check that export was successful + self.assertIsNotNone(exported_model) + self.assertTrue(hasattr(exported_model, "graph_module")) + + # Test the exported model + with torch.no_grad(): + original_output = exported_model.module()(example_input) + + # Create a copy for rewriting + rewritten_model = copy.deepcopy(exported_model) + + # Apply reference representation rewrite + reference_representation_rewrite(rewritten_model.graph_module) + + # Test the rewritten model + with torch.no_grad(): + rewritten_output = rewritten_model.module()(example_input) + + # Check that outputs are close + self.assertEqual(original_output.shape, rewritten_output.shape) + self.assertEqual(original_output.dtype, rewritten_output.dtype) + + # The outputs should be close (allowing for some numerical differences) + torch.testing.assert_close( + original_output, rewritten_output, atol=5e-2, rtol=5e-2 + ) + + def test_different_group_sizes_rewrite(self): + """Test rewrite functionality with different group sizes.""" + group_sizes = [16, 32, 64] + + for group_size in group_sizes: + with self.subTest(group_size=group_size): + # Create model + model = SimpleLinearModel(input_size=64, hidden_size=32, output_size=16) + example_input = torch.randn(1, 2, 64) + + # Apply quantization with specific group size + quantize_( + model, Int8DynamicActivationInt4WeightConfig(group_size=group_size) + ) + + # Unwrap tensor subclasses for export compatibility + model = unwrap_tensor_subclass(model) + + # Export and test rewrite + exported_model = torch.export.export(model, (example_input,)) + + # Test the exported model + with torch.no_grad(): + original_output = exported_model.module()(example_input) + + # Create a copy for rewriting + rewritten_model = copy.deepcopy(exported_model) + + # Apply reference representation rewrite + reference_representation_rewrite(rewritten_model.graph_module) + + # Test the rewritten model + with torch.no_grad(): + rewritten_output = rewritten_model.module()(example_input) + + # The outputs should be close (allowing for some numerical differences) + torch.testing.assert_close( + original_output, + rewritten_output, + atol=5e-2, + rtol=5e-2, + msg=f"Rewrite failed for group_size={group_size}", + ) + + def test_model_without_bias_rewrite(self): + """Test rewrite functionality with linear layers that have no bias.""" + # Create model without bias + model = SimpleLinearModel(input_size=32, hidden_size=16, output_size=8) + model.linear1.bias = None + model.linear2.bias = None + + example_input = torch.randn(1, 32) + + # Apply quantization + quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=16)) + + # Unwrap tensor subclasses for export compatibility + model = unwrap_tensor_subclass(model) + + # Export and test rewrite + exported_model = torch.export.export(model, (example_input,)) + + # Test the exported model + with torch.no_grad(): + original_output = exported_model.module()(example_input) + + # Create a copy for rewriting + rewritten_model = copy.deepcopy(exported_model) + + # Apply reference representation rewrite + reference_representation_rewrite(rewritten_model.graph_module) + + # Test the rewritten model + with torch.no_grad(): + rewritten_output = rewritten_model.module()(example_input) + + # The outputs should be close (allowing for some numerical differences) + torch.testing.assert_close( + original_output, + rewritten_output, + atol=5e-2, + rtol=5e-2, + msg="Rewrite failed for model without bias", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py index dc5f802fb8..7ff1dbc619 100644 --- a/torchao/quantization/pt2e/utils.py +++ b/torchao/quantization/pt2e/utils.py @@ -525,7 +525,11 @@ def get_attr_name(i: int): def create_getattr_from_value( - module: torch.nn.Module, graph: Graph, prefix: str, value: Any + module: torch.nn.Module, + graph: Graph, + prefix: str, + value: Any, + device: Optional[torch.device] = None, ) -> Node: """ Given a value of any type, creates a getattr node corresponding to the value and @@ -533,7 +537,8 @@ def create_getattr_from_value( """ get_new_attr_name = get_new_attr_name_with_prefix(prefix) attr_name = get_new_attr_name(module) - device = _assert_and_get_unique_device(module) + if device is None: + device = _assert_and_get_unique_device(module) new_value = ( value.detach().clone() if isinstance(value, torch.Tensor) @@ -671,6 +676,7 @@ def fold_bn_weights_into_conv_node( conv_bias_node: Optional[Node], bn_node: Node, m: GraphModule, + fake_fuse: bool = False, # removes the BN nodes but doesn't change the conv weights ) -> None: # conv args: input, weight, bias, stride, padding, dilation, ... conv_w = _get_tensor_constant_from_node(conv_weight_node, m) @@ -703,6 +709,16 @@ def fold_bn_weights_into_conv_node( if len(conv_args) == 2: conv_args.append(None) + if fake_fuse: + fused_weight, fused_bias = ( + torch.nn.Parameter(conv_w, conv_w.requires_grad), + torch.nn.Parameter(conv_b, conv_b.requires_grad), + ) + else: + fused_weight, fused_bias = fuse_conv_bn_weights( + conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose + ) + # calling data since the fused_weight and fused_bias are nn.Parameter weight_attr_name = conv_weight_node.target assert isinstance(weight_attr_name, str) @@ -758,7 +774,7 @@ def fold_bn_weights_into_conv_node( # since the node refers to a mutating op. Here we still need to call DCE first # to get rid of the unused getitem nodes that consume the BN node. m.graph.eliminate_dead_code() - if len(bn_node.users) == 0: + if not bn_node._erased and len(bn_node.users) == 0: m.graph.erase_node(bn_node) @@ -767,6 +783,9 @@ def _fuse_conv_bn_(m: GraphModule) -> None: has_bn = any(_is_bn_node(n) for n in m.graph.nodes) if not has_bn: return + + # track which conv weights have been fused to avoid double fusing + fused_convs_weight_nodes = set() for n in m.graph.nodes: if n.op != "call_function" or n.target not in ( torch.ops.aten._native_batch_norm_legit_no_training.default, @@ -781,9 +800,14 @@ def _fuse_conv_bn_(m: GraphModule) -> None: conv_weight_node = conv_node.args[1] conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None fold_bn_weights_into_conv_node( - conv_node, conv_weight_node, conv_bias_node, bn_node, m + conv_node, + conv_weight_node, + conv_bias_node, + bn_node, + m, + (conv_weight_node in fused_convs_weight_nodes), ) - + fused_convs_weight_nodes.add(conv_weight_node) m.graph.eliminate_dead_code() m.recompile() @@ -815,7 +839,7 @@ def _get_aten_graph_module_for_pattern( [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs] ) - aten_pattern = torch.export.export_for_training( + aten_pattern = torch.export.export( pattern, # type: ignore[arg-type] example_inputs, kwargs, @@ -1031,6 +1055,8 @@ def replacement(x_i8, scale, zero_point, quant_min, quant_max): continue new_args = [] for arg in node.args: + if isinstance(arg, list): + arg = tuple(arg) # type: ignore[assignment] if ( _is_literal(arg) and arg not in exclude_literals diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md index eee1047199..9a11aa7b51 100644 --- a/torchao/quantization/qat/README.md +++ b/torchao/quantization/qat/README.md @@ -67,79 +67,88 @@ def train_loop(m: torch.nn.Module): optimizer.zero_grad() ``` + ### quantize_ API (recommended) -The recommended way to run QAT in torchao is through the `quantize_` API: -1. **Prepare:** specify how weights and/or activations are to be quantized through -[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242) -2. **Convert:** quantize the model using the standard post-training quantization (PTQ) -functions such as [`Int8DynamicActivationInt4WeightConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606) +The recommended way to run QAT in torchao is through the `quantize_` API. -For example: +1. **Prepare:** The main [`QATConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.QATConfig.html) +accepts a post-training quantization (PTQ) config and automatically infers +the corresponding fake quantization configs to use. +2. **Convert:** quantize the model using the base config provided +Currently only the following PTQ base configs are supported: +- [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html) +- [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig.html) + +For example (most use cases): ```python -from torchao.quantization import ( - quantize_, - Int8DynamicActivationInt4WeightConfig, -) -from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, - IntXQuantizationAwareTrainingConfig, -) +from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig +from torchao.quantization.qat import QATConfig + model = get_model() -# prepare: insert fake quantization ops -# swaps `torch.nn.Linear` with `FakeQuantizedLinear` -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) -quantize_( - model, - IntXQuantizationAwareTrainingConfig(activation_config, weight_config), -) +# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear` +base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) +quantize_(model, QATConfig(base_config, step="prepare")) # train train_loop(model) -# convert: transform fake quantization ops into actual quantized ops -# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts -# quantized activation and weight tensor subclasses -quantize_(model, FromIntXQuantizationAwareTrainingConfig()) -quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)) +# convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config` +quantize_(model, QATConfig(base_config, step="convert")) # inference or generate ``` +The `quantize_` API also allows more general quantization settings that +may not have a corresponding PTQ base config, e.g. for experimentation +purposes. Users can specify custom fake quantization configs for activations +and/or weights. For example, the following usage is numerically equivalent +to the above: + +```python +from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig +from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig + +model = get_model() + +# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear` +activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) +qat_config = QATConfig( + activation_config=activation_config, + weight_config=weight_config, + step="prepare", +) +quantize_(model, qat_config) + +# train +train_loop(model) + +# convert: (not shown, same as before) +``` + To fake quantize embedding in addition to linear, you can additionally call the following with a filter function during the prepare step: ``` -# first apply linear transformation to the model as above -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) -quantize_( - model, - IntXQuantizationAwareTrainingConfig(activation_config, weight_config), -) - -# then apply weight-only transformation to embedding layers -# activation fake quantization is not supported for embedding layers -quantize_( - m, - IntXQuantizationAwareTrainingConfig(weight_config=weight_config), - filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding) -) +# First apply linear transformation to the model as above +# Then apply weight-only transformation to embedding layers +# (activation fake quantization is not supported for embedding layers) +qat_config = QATConfig(weight_config=weight_config, step="prepare") +quantize_(m, qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)) ``` ### Quantizer API (legacy) Alternatively, torchao provides a few hardcoded quantization settings through -the following Quantizers: -- [Int8DynActInt4QATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L126) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight -- [Int4WeightOnlyQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308) (linear), targeting int4 per-group asymmetric weight using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training) -- [Int4WeightOnlyEmbeddingQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/embedding.py#L94) (embedding), targeting int4 per-group symmetric weight +the following Quantizers, but these may be removed soon: +- [Int8DynActInt4QATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer.html#torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight +- [Int4WeightOnlyQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int4WeightOnlyQATQuantizer.html#torchao.quantization.qat.Int4WeightOnlyQATQuantizer) (linear), targeting int4 per-group asymmetric weight using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training) +- [Int4WeightOnlyEmbeddingQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer.html#torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer) (embedding), targeting int4 per-group symmetric weight For example: ```python @@ -162,7 +171,7 @@ model = qat_quantizer.convert(model) ``` To use multiple Quantizers in the same model for different layer types, -users can also leverage the [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242) +users can also leverage the [ComposableQATQuantizer](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.ComposableQATQuantizer.html#torchao.quantization.qat.ComposableQATQuantizer) as follows: ```python diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 4a4359e682..4218c763e2 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -1,31 +1,60 @@ from .api import ( ComposableQATQuantizer, - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, IntXQuantizationAwareTrainingConfig, + QATConfig, + QATStep, from_intx_quantization_aware_training, initialize_fake_quantizers, intx_quantization_aware_training, ) from .embedding import ( + FakeQuantizedEmbedding, Int4WeightOnlyEmbeddingQATQuantizer, ) +from .fake_quantize_config import ( + FakeQuantizeConfig, + FakeQuantizeConfigBase, + Float8FakeQuantizeConfig, + IntxFakeQuantizeConfig, +) +from .fake_quantizer import ( + FakeQuantizer, + FakeQuantizerBase, + Float8FakeQuantizer, + IntxFakeQuantizer, +) from .linear import ( + FakeQuantizedLinear, Float8ActInt4WeightQATQuantizer, Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, ) __all__ = [ + "QATConfig", + "QATStep", + "FakeQuantizeConfigBase", + "FakeQuantizerBase", + "Float8FakeQuantizeConfig", + "Float8FakeQuantizer", + "IntxFakeQuantizeConfig", + "IntxFakeQuantizer", + "FakeQuantizedLinear", + "FakeQuantizedEmbedding", + # Prototype + "initialize_fake_quantizers", + # Legacy quantizers "ComposableQATQuantizer", - "FakeQuantizeConfig", "Float8ActInt4WeightQATQuantizer", - "FromIntXQuantizationAwareTrainingConfig", "Int4WeightOnlyEmbeddingQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", - "IntXQuantizationAwareTrainingConfig", - "initialize_fake_quantizers", - "intx_quantization_aware_training", + # for BC + "FakeQuantizer", + "FakeQuantizeConfig", "from_intx_quantization_aware_training", + "FromIntXQuantizationAwareTrainingConfig", + "intx_quantization_aware_training", + "IntXQuantizationAwareTrainingConfig", ] diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py index 80ecd173c2..dab63b3a00 100644 --- a/torchao/quantization/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -20,16 +20,12 @@ ) from torchao.utils import TorchAOBaseTensor -from .utils import ( - _UnwrapAffineFakeQuantizedTensor, -) - aten = torch.ops.aten class _ToAffineFakeQuantized(torch.autograd.Function): """ - Differentiable constructor for `AffineFakeQuantizedTensor`, + Differentiable constructor for `_AffineFakeQuantizedTensor`, needed for input activation fake quantization. """ @@ -47,12 +43,12 @@ def forward( zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - ) -> "AffineFakeQuantizedTensor": + ) -> "_AffineFakeQuantizedTensor": if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") def apply_fake_quant_fn(t: torch.Tensor): - assert isinstance(t, AffineFakeQuantizedTensor) + assert isinstance(t, _AffineFakeQuantizedTensor) qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero: scale, zero_point = _choose_qparams_affine_tinygemm( @@ -102,7 +98,7 @@ def apply_fake_quant_fn(t: torch.Tensor): ) return fq - return AffineFakeQuantizedTensor( + return _AffineFakeQuantizedTensor( original_tensor, apply_fake_quant_fn, fake_quant_enabled=True, @@ -113,7 +109,7 @@ def backward(ctx, gy): return gy, None, None, None, None, None, None, None, None, None, None -class AffineFakeQuantizedTensor(TorchAOBaseTensor): +class _AffineFakeQuantizedTensor(TorchAOBaseTensor): """ Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: @@ -212,7 +208,7 @@ def get_value(self) -> torch.Tensor: if self.fake_quant_enabled: return self.apply_fake_quant_fn(self) else: - return _UnwrapAffineFakeQuantizedTensor.apply(self) + return self.original_tensor def _get_to_kwargs(self, *args, **kwargs): device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) @@ -243,14 +239,14 @@ def to(self, *args, **kwargs): def _apply_fn_to_data(self, fn: Callable): """ - Create a new `AffineFakeQuantizedTensor` with `fn` applied to the + Create a new `_AffineFakeQuantizedTensor` with `fn` applied to the original tensor, to be called within __torch_dispatch__. """ return self._create_new(fn(self.original_tensor)) def _create_new(self, new_value: torch.Tensor): """ - Create a new `AffineFakeQuantizedTensor` with a new value, + Create a new `_AffineFakeQuantizedTensor` with a new value, to be called within __torch_dispatch__. Note: `requires_grad` must be False here because tensors created @@ -267,7 +263,7 @@ def _create_new(self, new_value: torch.Tensor): ) -implements = AffineFakeQuantizedTensor.implements +implements = _AffineFakeQuantizedTensor.implements @implements(torch.nn.functional.linear) @@ -277,9 +273,9 @@ def _(func, types, args, kwargs): args[1], args[2] if len(args) > 2 else None, ) - if isinstance(input_tensor, AffineFakeQuantizedTensor): + if isinstance(input_tensor, _AffineFakeQuantizedTensor): input_tensor = input_tensor.get_value() - if isinstance(weight_tensor, AffineFakeQuantizedTensor): + if isinstance(weight_tensor, _AffineFakeQuantizedTensor): weight_tensor = weight_tensor.get_value() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) @@ -288,9 +284,9 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): input_tensor = args[0] weight_tensor = args[1] - if isinstance(input_tensor, AffineFakeQuantizedTensor): + if isinstance(input_tensor, _AffineFakeQuantizedTensor): input_tensor = input_tensor.get_value() - if isinstance(weight_tensor, AffineFakeQuantizedTensor): + if isinstance(weight_tensor, _AffineFakeQuantizedTensor): weight_tensor = weight_tensor.get_value() return func(input_tensor, weight_tensor) @@ -300,9 +296,9 @@ def _(func, types, args, kwargs): bias = args[0] input_tensor = args[1] weight_tensor = args[2] - if isinstance(input_tensor, AffineFakeQuantizedTensor): + if isinstance(input_tensor, _AffineFakeQuantizedTensor): input_tensor = input_tensor.get_value() - if isinstance(weight_tensor, AffineFakeQuantizedTensor): + if isinstance(weight_tensor, _AffineFakeQuantizedTensor): weight_tensor = weight_tensor.get_value() return func(bias, input_tensor, weight_tensor) @@ -348,10 +344,10 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}" new_args = pytree.tree_map_only( - AffineFakeQuantizedTensor, lambda x: x.original_tensor, args + _AffineFakeQuantizedTensor, lambda x: x.original_tensor, args ) first_afq_tensor = ( - args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1] + args[0] if isinstance(args[0], _AffineFakeQuantizedTensor) else args[1] ) new_value = func(*new_args, **kwargs) out = first_afq_tensor._create_new(new_value) @@ -384,4 +380,4 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, out) -to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float +_to_affine_fake_quantized = _AffineFakeQuantizedTensor.from_float diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index 60370ee52b..1287126bac 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,284 +5,277 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from enum import Enum +from typing import Any, List, Optional, Tuple import torch from torchao.core.config import AOBaseConfig -from torchao.quantization.granularity import ( - Granularity, - PerAxis, - PerGroup, - PerToken, -) -from torchao.quantization.quant_primitives import ( - _SUB_BYTE_INT_BOUNDS, - _SUB_BYTE_UINT_BOUNDS, - MappingType, - TorchAODType, - ZeroPointDomain, -) from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) from torchao.quantization.unified import TwoStepQuantizer +from .embedding import FakeQuantizedEmbedding +from .fake_quantize_config import ( + FakeQuantizeConfig, # noqa: F401, for BC + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, + _infer_fake_quantize_configs, +) +from .linear import FakeQuantizedLinear +from .utils import _log_deprecation_warning + + +class QATStep(str, Enum): + """ + Enum value for the `step` field in :class:`~torchao.quantization.qat.QATConfig`. + """ + + PREPARE = "prepare" + CONVERT = "convert" + @dataclass -class FakeQuantizeConfig: +class QATConfig(AOBaseConfig): """ - Config for how to fake quantize weights or activations. - - args: - dtype: dtype to simulate during fake quantization, e.g. torch.int8. - For PyTorch versions older than 2.6, you may use `TorchAODType` to represent - torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. - granularity: granularity of scales and zero points, e.g. PerGroup(32). - We also support the following strings: - 1) 'per_token': equivalent to PerToken() - 2) 'per_channel': equivalent to PerAxis(0) - 3) 'per_group': equivalent to PerGroup(group_size), must be combined - with separate `group_size` kwarg, Alternatively, just set the - `group_size` kwarg and leave this field empty. - mapping_type: whether to use symmetric (default) or asymmetric quantization - Alternatively, set `is_symmetric` (bool) and leave this field empty. - scale_precision: scale dtype (default torch.fp32) - zero_point_precision: zero point dtype (default torch.int32) - zero_point_domain: whether zero point is in integer (default) or float domain - is_dynamic: whether to use dynamic (default) or static scale and zero points - range_learning (prototype): whether to learn scale and zero points during training - (default false), not compatible with `is_dynamic`. - - kwargs (optional): - group_size: size of each group in per group fake quantization, - can be set instead of `granularity` - is_symmetric: whether to use symmetric or asymmetric quantization, - can be set instead of `mapping_type` + Config for applying quantization-aware training (QAT) to a `torch.nn.Module`, + to be used with :func:`~torchao.quantization.quant_api.quantize_`. + + This config has two steps, "prepare" and "convert". The prepare step applies + "fake" quantization to the model and should be applied before training, while + the convert step converts the model into an actual quantized model. Fake + quantization here refers to simulating the quantization numerics (e.g. int4) + using high precision arithmetic (e.g. bf16), with the goal of reducing + eventual degradation from quantization. + + There are two ways to use this config. The first involves passing a base + post-training quantization (PTQ) config, which we will use to automatically + infer the corresponding fake quantization schemes to use in the prepare phase. + In the convert phase, we will then apply the base PTQ config to the model. + This will be the most common use case. Example usage:: - # Per token asymmetric quantization - FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) + from torchao.quantization import ( + quantize_, + Int8DynamicActivationInt4WeightConfig, + ) + from torchao.quantization.qat import QATConfig + + base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) + quantize_(model, QATConfig(base_config, step="prepare")) + train_loop(model) + quantize_(model, QATConfig(base_config, step="convert")) + + Currently only the following are supported as base configs: - # Per channel symmetric quantization - FakeQuantizeConfig(torch.int4, "per_channel") - FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) + - :class:`~torchao.quantization.Int8DynamicActivationInt4WeightConfig` + - :class:`~torchao.quantization.Int4WeightOnlyConfig` + + The second way to use this config involves specifying the fake quantization + schemes directly. Users will pass in :class:`~torchao.quantization.qat.FakeQuantizeConfigBase` + for weights and/or activations instead of the base PTQ config. This use case + is mostly for experimentation, e.g. when the corresponding PTQ config does + not exist yet. + + Example usage:: + + from torchao.quantization import quantize_ + from torchao.quantization.qat import IntxFakeQuantizeConfig - # Per group symmetric quantization - FakeQuantizeConfig(torch.int4, group_size=32) - FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) + activation_config = IntxFakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False, + ) + weight_config = IntxFakeQuantizeConfig( + torch.int4, group_size=32, is_symmetric=True, + ) + qat_config = QATConfig( + # must specify one of `base_config` or `weight_config` + activation_config=act_config, + weight_config=weight_config, + step="prepare", + ) + quantize_(model, qat_config) + + Args: + base_config (Optional[AOBaseConfig]): Base PTQ config to infer the fake + quantization configs during the prepare phase, and to apply directly + during the convert phase. + activation_config (Optional[FakeQuantizeConfigBase]): Custom fake + quantization config for input activations, always optional. + Must be None if `base_config` is used. + weight_config (Optional[FakeQuantizeConfigBase]): Custom fake quantization + config for weights. Must be None if `base_config` is used. + + Keyword args: + step (str): One of "prepare" or "convert", determines the QAT phase + + Raises: + ValueError: If `base_config` and `activation_config` are both specified + ValueError: If `base_config` and `weight_config` are both specified + ValueError: If none of `base_config`, `activation_config`, or + `weight_config` are specified + ValueError: If either `activation_config` or `weight_config` is specified + and `step` is "convert" + ValueError: If `step` is not one of "prepare" or "convert" + ValueError: If the config is applied on a module that is not a + `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on + `torch.nn.Embedding` with an activation config """ - dtype: Union[torch.dtype, TorchAODType] - granularity: Granularity - mapping_type: MappingType - scale_precision: torch.dtype - zero_point_precision: torch.dtype - zero_point_domain: ZeroPointDomain - is_dynamic: bool = True - range_learning: bool = False - eps: Optional[float] = None + base_config: Optional[AOBaseConfig] + activation_config: Optional[FakeQuantizeConfigBase] + weight_config: Optional[FakeQuantizeConfigBase] + step: QATStep + # Express `step` as a keyword argument + # TODO: Use `kw_only=True` instead, added in python 3.10 def __init__( self, - dtype: Union[torch.dtype, TorchAODType], - granularity: Union[Granularity, str, None] = None, - mapping_type: Optional[MappingType] = None, - scale_precision: torch.dtype = torch.float32, - zero_point_precision: torch.dtype = torch.int32, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - is_dynamic: bool = True, - range_learning: bool = False, - eps: Optional[float] = None, + base_config: Optional[AOBaseConfig] = None, + activation_config: Optional[FakeQuantizeConfigBase] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, *, - group_size: Optional[int] = None, - is_symmetric: Optional[bool] = None, + step: QATStep = "prepare", ): - if zero_point_domain is None: - raise ValueError("Please use ZeroPointDomain.NONE instead of None") - self.dtype = dtype - self.granularity = self._get_granularity(granularity, group_size) - self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) - self.scale_precision = scale_precision - self.zero_point_precision = zero_point_precision - self.zero_point_domain = zero_point_domain - self.is_dynamic = is_dynamic - self.range_learning = range_learning - self.eps = eps - - # Validate dtype - all_dtypes = [torch.int8, torch.uint8] - all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) - all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) - if dtype not in all_dtypes: + self.base_config = base_config + self.activation_config = activation_config + self.weight_config = weight_config + self.step = step + self.__post_init__() + + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.qat.QATConfig") + self.step = self.step.lower() + all_step_values = [s.value for s in QATStep] + if self.step not in all_step_values: + raise ValueError(f"`step` must be one of {all_step_values}") + if self.base_config is not None and self.activation_config is not None: raise ValueError( - "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes) + "Cannot specify both `base_config` and `activation_config`" ) - - # Dynamic is not compatible with range learning - if is_dynamic and range_learning: - raise ValueError("`is_dynamic` is not compatible with `range_learning`") - - def _get_granularity( - self, - granularity: Union[Granularity, str, None], - group_size: Optional[int], - ) -> Granularity: - """ - Parse the `Granularity` represented in the args. - - Granularity can be specified in one of three ways: - 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) - 2) str: one of 'per_token', 'per_channel', and 'per_group' - 3) None: `group_size` must be set instead, represents per group granularity - """ - # If group_size is set, then granularity must be either "per_group" or None - if ( - group_size is not None - and granularity != "per_group" - and granularity is not None + if self.base_config is not None and self.weight_config is not None: + raise ValueError("Cannot specify both `base_config` and `weight_config`") + if self.step == QATStep.PREPARE and not any( + (self.base_config, self.activation_config, self.weight_config) ): raise ValueError( - "`group_size` conflicts with granularity '%s'" % granularity - ) - - # Case 1: Granularity object - if isinstance(granularity, Granularity): - if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): - raise ValueError("Granularity '%s' is not supported" % granularity) - if isinstance(granularity, PerAxis) and granularity.axis != 0: - raise ValueError("Only axis=0 is supported for PerAxis granularity") - return granularity - - # Case 2: str granularity - if granularity == "per_token": - return PerToken() - elif granularity == "per_channel": - return PerAxis(axis=0) - elif granularity == "per_group": - if group_size is None: - raise ValueError( - "Granularity was 'per_group' but no `group_size` was set" - ) - return PerGroup(group_size) - elif isinstance(granularity, str): - raise ValueError( - "Unexpected granularity: '%s', must be one of %s" - % (granularity, ["per_token", "per_channel", "per_group"]) + "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step" ) - # Case 3: None granularity + group_size was specified - if granularity is not None: + if self.step == QATStep.CONVERT and ( + self.activation_config is not None or self.weight_config is not None + ): raise ValueError( - "Granularity '%s' has unexpected type %s" - % (granularity, type(granularity)) + "Cannot specify `weight_config` or `activation_config` in the convert step" ) - if group_size is None: + if isinstance(self.base_config, FakeQuantizeConfigBase): + config_type = self.base_config.__class__.__name__ raise ValueError( - "At least one of `granularity` or `group_size` must be set" + f"{config_type} was passed as `base_config`. Did you mean to do the following instead?\n" + " qat_config = QATConfig(\n" + f" activation_config={config_type}(...),\n" + f" weight_config={config_type}(...),\n" + ' step="prepare",\n' + " )" ) - return PerGroup(group_size) - def _get_mapping_type( - self, - mapping_type: Optional[MappingType], - is_symmetric: Optional[bool], - ) -> MappingType: - """ - Parse the `MappingType` represented in the args. - - Mapping type can be specified in one of two ways: - 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC - 2): is_symmetric bool - """ - if mapping_type is not None and is_symmetric is not None: - raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") - - # Case 0: Default to symmetric - if mapping_type is None and is_symmetric is None: - return MappingType.SYMMETRIC - - # Case 1: MappingType object - if mapping_type is not None: - if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: - raise ValueError("MappingType '%s' is not supported" % mapping_type) - return mapping_type - - # Case 2: is_symmetric flag - assert is_symmetric is not None - if is_symmetric: - return MappingType.SYMMETRIC + +@register_quantize_module_handler(QATConfig) +def _qat_config_transform( + module: torch.nn.Module, + config: QATConfig, +) -> torch.nn.Module: + """ + During the prepare step, perform module swap to apply fake quantization. + If the base PTQ config is specified, derive the fake quantization configs from it. + + During the convert step, first perform module swap to revert all fake quantized + modules to the corresponding built-in `torch.nn.Module`s, then apply the + base config directly to quantize the module. + """ + # Prepare step + # Swap nn.Linear -> FakeQuantizedLinear + # Swap nn.Embedding -> FakeQuantizedEmbedding + base_config = config.base_config + step = config.step + if step == QATStep.PREPARE: + if base_config is not None: + (act_config, weight_config) = _infer_fake_quantize_configs(base_config) else: - return MappingType.ASYMMETRIC - - @property - def group_size(self) -> int: - """ - If this is per group granularity, return the group size. - Otherwise, throw an error. - """ - if isinstance(self.granularity, PerGroup): - return self.granularity.group_size + act_config = config.activation_config + weight_config = config.weight_config + if isinstance(module, torch.nn.Linear): + return FakeQuantizedLinear.from_linear(module, act_config, weight_config) + elif isinstance(module, torch.nn.Embedding): + if act_config is not None: + raise ValueError( + "Activation fake quantization is not supported for embedding" + ) + return FakeQuantizedEmbedding.from_embedding(module, weight_config) else: raise ValueError( - "`group_size` is undefined for %s granularity" % self.granularity + "Module of type '%s' does not have QAT support" % type(module) + ) + else: + # Convert step + assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step + assert config.activation_config is None, "unexpected `activation_config`" + assert config.weight_config is None, "unexpected `weight_config`" + + # Ignore unrelated modules + if not isinstance(module, (FakeQuantizedLinear, FakeQuantizedEmbedding)): + return module + + # Optionally pass custom scales and zero points to base config handler + # This is only for range learning and only applies to weights + kwargs = {} + weight_config = module.weight_fake_quantizer.config + if ( + isinstance(weight_config, IntxFakeQuantizeConfig) + and weight_config.range_learning + ): + kwargs["custom_scale"] = module.weight_fake_quantizer.scale + kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point + + # Swap FakeQuantizedLinear -> nn.Linear + # Swap FakeQuantizedEmbedding -> nn.Embedding + # Then apply the base config's transform function to quantize the model + # If there is no base config, then simply perform the module swap + if isinstance(module, FakeQuantizedLinear): + module = module.to_linear() + elif isinstance(module, FakeQuantizedEmbedding): + module = module.to_embedding() + else: + raise ValueError( + f"Encountered unexpected module {module}, should never happen" + ) + if base_config is not None: + return _QUANTIZE_CONFIG_HANDLER[type(base_config)]( + module, base_config, **kwargs ) - - @property - def is_symmetric(self) -> bool: - """ - Return True if mapping type is symmetric, else False (asymmetric). - """ - return self.mapping_type == MappingType.SYMMETRIC - - def __setattr__(self, name: str, value: Any): - """ - Support setting `group_size` and `is_symmetric`. - """ - if name == "group_size": - super().__setattr__("granularity", PerGroup(value)) - elif name == "is_symmetric": - mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC - super().__setattr__("mapping_type", mapping_type) else: - super().__setattr__(name, value) + return module @dataclass class IntXQuantizationAwareTrainingConfig(AOBaseConfig): - activation_config: Optional[FakeQuantizeConfig] = None - weight_config: Optional[FakeQuantizeConfig] = None - - -# for BC -intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig - - -@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig) -def _intx_quantization_aware_training_transform( - module: torch.nn.Module, - config: IntXQuantizationAwareTrainingConfig, -) -> torch.nn.Module: """ - THIS IS NOT A PUBLIC API - any usage of this outside of torchao - can break at any time. + (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead. - Apply fake quantization to a `torch.nn.Module`. + Config for applying fake quantization to a `torch.nn.Module`. to be used with :func:`~torchao.quantization.quant_api.quantize_`. Example usage:: from torchao.quantization import quantize_ - from torchao.quantization.qat import FakeQuantizeConfig - activation_config = FakeQuantizeConfig( + from torchao.quantization.qat import IntxFakeQuantizeConfig + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( torch.int4, group_size=32, is_symmetric=True, ) quantize_( @@ -290,14 +283,29 @@ def _intx_quantization_aware_training_transform( IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) - Note: If the returned function is applied on a module that is not + Note: If the config is applied on a module that is not `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on `torch.nn.Embedding` with an activation config, then we will raise ValueError as these are not supported. """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear + activation_config: Optional[FakeQuantizeConfigBase] = None + weight_config: Optional[FakeQuantizeConfigBase] = None + + def __post_init__(self): + _log_deprecation_warning(self) + + +# for BC +class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig): + pass + + +@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig) +def _intx_quantization_aware_training_transform( + module: torch.nn.Module, + config: IntXQuantizationAwareTrainingConfig, +) -> torch.nn.Module: mod = module activation_config = config.activation_config weight_config = config.weight_config @@ -318,9 +326,12 @@ def _intx_quantization_aware_training_transform( raise ValueError("Module of type '%s' does not have QAT support" % type(mod)) +@dataclass class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig): """ - Object that knows how to convert a model with fake quantized modules, + (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead. + + Config for converting a model with fake quantized modules, such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear` and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`, back to model with the original, corresponding modules without @@ -336,11 +347,13 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig): ) """ - pass + def __post_init__(self): + _log_deprecation_warning(self) # for BC -from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig +class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig): + pass @register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig) @@ -352,9 +365,6 @@ def _from_intx_quantization_aware_training_transform( If the given module is a fake quantized module, return the original corresponding version of the module without fake quantization. """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear - if isinstance(mod, FakeQuantizedLinear): return mod.to_linear() elif isinstance(mod, FakeQuantizedEmbedding): @@ -384,6 +394,7 @@ class ComposableQATQuantizer(TwoStepQuantizer): """ def __init__(self, quantizers: List[TwoStepQuantizer]): + torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer") self.quantizers = quantizers def prepare( @@ -407,14 +418,16 @@ def initialize_fake_quantizers( ) -> None: """ (Prototype) Initialize the scales and zero points on all - :class:`~`torchao.quantization.qat.fake_quantizer.FakeQuantizer` + :class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizerBase` in the model based on the provided example inputs. """ + torch._C._log_api_usage_once("torchao.quantization.qat.initialize_fake_quantizers") + # avoid circular dependencies - from torchao.quantization.qat.fake_quantizer import FakeQuantizer + from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer def _set_initialized(m: torch.nn.Module): - if isinstance(m, FakeQuantizer): + if isinstance(m, IntxFakeQuantizer): m._initialized = True model.apply(_set_initialized) diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index aec23712ed..a1a6484772 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -13,8 +13,11 @@ from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric -from .api import FakeQuantizeConfig -from .fake_quantizer import FakeQuantizer +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, +) +from .fake_quantizer import FakeQuantizerBase from .utils import ( _get_qmin_qmax, ) @@ -29,7 +32,7 @@ class FakeQuantizedEmbedding(torch.nn.Embedding): Example usage:: - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, group_size=8, symmetric=True, @@ -47,7 +50,7 @@ def __init__( norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, - weight_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, *args, **kwargs, ) -> None: @@ -62,8 +65,9 @@ def __init__( *args, **kwargs, ) + torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedEmbedding") if weight_config is not None: - self.weight_fake_quantizer = FakeQuantizer(weight_config) + self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config) else: self.weight_fake_quantizer = None @@ -105,7 +109,7 @@ def to_embedding(self) -> torch.nn.Embedding: def from_embedding( cls, mod: torch.nn.Embedding, - weight_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, ): new_embedding = FakeQuantizedEmbedding( mod.num_embeddings, @@ -145,6 +149,9 @@ def __init__( zero_point_precision: torch.dtype = torch.int32, ) -> None: super().__init__() + torch._C._log_api_usage_once( + "torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer" + ) self.bit_width = 4 self.group_size: int = group_size self.scale_precision: torch.dtype = scale_precision @@ -285,7 +292,7 @@ def __init__( *args, **kwargs, ): - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=TorchAODType.INT4, group_size=group_size, is_symmetric=True, diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py new file mode 100644 index 0000000000..ebc9864f3d --- /dev/null +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -0,0 +1,505 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import torch + +from torchao.core.config import AOBaseConfig +from torchao.float8.config import e4m3_dtype +from torchao.float8.inference import ( + FP8Granularity, + _normalize_granularity, +) +from torchao.quantization.granularity import ( + Granularity, + PerAxis, + PerGroup, + PerRow, + PerTensor, + PerToken, +) +from torchao.quantization.quant_primitives import ( + _SUB_BYTE_INT_BOUNDS, + _SUB_BYTE_UINT_BOUNDS, + MappingType, + TorchAODType, + ZeroPointDomain, +) +from torchao.quantization.quantize_.workflows import Int4PackingFormat +from torchao.utils import _is_float8_type + +from .utils import _log_deprecation_warning + + +class FakeQuantizeConfigBase(abc.ABC): + """ + Base class for representing fake quantization config. + """ + + pass + + +@dataclass +class Float8FakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for float8 fake quantization, targeting :class:`~torchao.quantization.Float8Tensor`. + + Args: + dtype (torch.dtype): the dtype for float8 Tensor + granularity (FP8Granularity): the granularity for the Tensor, currently either PerRow() or PerTensor() + hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale + hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale + """ + + dtype: torch.dtype = e4m3_dtype + granularity: FP8Granularity = PerRow() + hp_value_lb: Optional[float] = None + hp_value_ub: Optional[float] = None + + def __post_init__(self): + """ + Verify dtype and granularity are the ones we support. + """ + if not _is_float8_type(self.dtype): + raise ValueError(f"{self.dtype} is not a float8 dtype") + if isinstance(self.granularity, type): + raise ValueError( + "Please specify the granularity object instead of the class, e.g. PerRow() instead of PerRow" + ) + if type(self.granularity) not in [PerRow, PerTensor]: + raise ValueError( + f"Expected PerRow or PerTensor granularity, got {self.granularity}" + ) + + +@dataclass +class Int4WeightFakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for pint4 weight fake quantization that targets the numerics in the following preshuffled kernel: + torch.ops.fbgemm.f8i4bf16_shuffled + torch.ops.fbgemm.bf16i4bf16_shuffled + torch.ops.fbgemm.bf16i4bf16_rowwise + + Currently this only supports float8 input activations. It is expected to be used in conjunction with + :class:`~torchao.quantization.Float8DynamicActivationInt4WeightConfig`. In the future, we may extend + this to support bfloat16 as well. + """ + + group_size: int = 128 + activation_dtype: torch.dtype = e4m3_dtype + + def __post_init__(self): + if self.activation_dtype not in [e4m3_dtype, torch.bfloat16]: + raise ValueError( + f"Only {e4m3_dtype} or torch.bfloat16 activation are supported" + ) + + +@dataclass +class IntxFakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for how to fake quantize weights or activations, + targeting integer dtypes up to torch.int8. + + Args: + dtype: dtype to simulate during fake quantization, e.g. torch.int8. + For PyTorch versions older than 2.6, you may use `TorchAODType` to represent + torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. + granularity: granularity of scales and zero points, e.g. PerGroup(32). + We also support the following strings: + 1) 'per_token': equivalent to PerToken() + 2) 'per_channel': equivalent to PerAxis(0) + 3) 'per_group': equivalent to PerGroup(group_size), must be combined + with separate `group_size` kwarg, Alternatively, just set the + `group_size` kwarg and leave this field empty. + mapping_type: whether to use symmetric (default) or asymmetric quantization + Alternatively, set `is_symmetric` (bool) and leave this field empty. + scale_precision: scale dtype (default torch.fp32) + zero_point_precision: zero point dtype (default torch.int32) + zero_point_domain: whether zero point is in integer (default) or float domain + is_dynamic: whether to use dynamic (default) or static scale and zero points + range_learning (prototype): whether to learn scale and zero points during training + (default false), not compatible with `is_dynamic`. + + Keyword args: + group_size: size of each group in per group fake quantization, + can be set instead of `granularity` + is_symmetric: whether to use symmetric or asymmetric quantization, + can be set instead of `mapping_type` + + Example usage:: + + # Per token asymmetric quantization + IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + IntxFakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) + + # Per channel symmetric quantization + IntxFakeQuantizeConfig(torch.int4, "per_channel") + IntxFakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) + IntxFakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) + + # Per group symmetric quantization + IntxFakeQuantizeConfig(torch.int4, group_size=32) + IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) + IntxFakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) + IntxFakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) + """ + + dtype: Union[torch.dtype, TorchAODType] + granularity: Granularity + mapping_type: MappingType + scale_precision: torch.dtype + zero_point_precision: torch.dtype + zero_point_domain: ZeroPointDomain + is_dynamic: bool = True + range_learning: bool = False + eps: Optional[float] = None + + def __init__( + self, + dtype: Union[torch.dtype, TorchAODType], + granularity: Union[Granularity, str, None] = None, + mapping_type: Optional[MappingType] = None, + scale_precision: torch.dtype = torch.float32, + zero_point_precision: torch.dtype = torch.int32, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + is_dynamic: bool = True, + range_learning: bool = False, + eps: Optional[float] = None, + *, + group_size: Optional[int] = None, + is_symmetric: Optional[bool] = None, + ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + self.dtype = dtype + self.granularity = self._get_granularity(granularity, group_size) + self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) + self.scale_precision = scale_precision + self.zero_point_precision = zero_point_precision + self.zero_point_domain = zero_point_domain + self.is_dynamic = is_dynamic + self.range_learning = range_learning + self.eps = eps + + # Validate dtype + all_dtypes = [torch.int8, torch.uint8] + all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) + all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) + if dtype not in all_dtypes: + raise ValueError( + "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes) + ) + + # Dynamic is not compatible with range learning + if is_dynamic and range_learning: + raise ValueError("`is_dynamic` is not compatible with `range_learning`") + + self.__post_init__() + + def __post_init__(self): + """ + For deprecation only, can remove after https://github.com/pytorch/ao/issues/2630. + """ + pass + + def _get_granularity( + self, + granularity: Union[Granularity, str, None], + group_size: Optional[int], + ) -> Granularity: + """ + Parse the `Granularity` represented in the args. + + Granularity can be specified in one of three ways: + 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) + 2) str: one of 'per_token', 'per_channel', and 'per_group' + 3) None: `group_size` must be set instead, represents per group granularity + """ + # If group_size is set, then granularity must be either "per_group" or None + if ( + group_size is not None + and granularity != "per_group" + and granularity is not None + ): + raise ValueError( + "`group_size` conflicts with granularity '%s'" % granularity + ) + + # Case 1: Granularity object + if isinstance(granularity, Granularity): + if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): + raise ValueError("Granularity '%s' is not supported" % granularity) + if isinstance(granularity, PerAxis) and granularity.axis != 0: + raise ValueError("Only axis=0 is supported for PerAxis granularity") + return granularity + + # Case 2: str granularity + if granularity == "per_token": + return PerToken() + elif granularity == "per_channel": + return PerAxis(axis=0) + elif granularity == "per_group": + if group_size is None: + raise ValueError( + "Granularity was 'per_group' but no `group_size` was set" + ) + return PerGroup(group_size) + elif isinstance(granularity, str): + raise ValueError( + "Unexpected granularity: '%s', must be one of %s" + % (granularity, ["per_token", "per_channel", "per_group"]) + ) + + # Case 3: None granularity + group_size was specified + if granularity is not None: + raise ValueError( + "Granularity '%s' has unexpected type %s" + % (granularity, type(granularity)) + ) + if group_size is None: + raise ValueError( + "At least one of `granularity` or `group_size` must be set" + ) + return PerGroup(group_size) + + def _get_mapping_type( + self, + mapping_type: Optional[MappingType], + is_symmetric: Optional[bool], + ) -> MappingType: + """ + Parse the `MappingType` represented in the args. + + Mapping type can be specified in one of two ways: + 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC + 2): is_symmetric bool + """ + if mapping_type is not None and is_symmetric is not None: + raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") + + # Case 0: Default to symmetric + if mapping_type is None and is_symmetric is None: + return MappingType.SYMMETRIC + + # Case 1: MappingType object + if mapping_type is not None: + if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: + raise ValueError("MappingType '%s' is not supported" % mapping_type) + return mapping_type + + # Case 2: is_symmetric flag + assert is_symmetric is not None + if is_symmetric: + return MappingType.SYMMETRIC + else: + return MappingType.ASYMMETRIC + + @property + def group_size(self) -> int: + """ + If this is per group granularity, return the group size. + Otherwise, throw an error. + """ + if isinstance(self.granularity, PerGroup): + return self.granularity.group_size + else: + raise ValueError( + "`group_size` is undefined for %s granularity" % self.granularity + ) + + @property + def is_symmetric(self) -> bool: + """ + Return True if mapping type is symmetric, else False (asymmetric). + """ + return self.mapping_type == MappingType.SYMMETRIC + + def __setattr__(self, name: str, value: Any): + """ + Support setting `group_size` and `is_symmetric`. + """ + if name == "group_size": + super().__setattr__("granularity", PerGroup(value)) + elif name == "is_symmetric": + mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC + super().__setattr__("mapping_type", mapping_type) + else: + super().__setattr__(name, value) + + +# For BC +class FakeQuantizeConfig(IntxFakeQuantizeConfig): + """ + (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizeConfig` instead. + """ + + def __post_init__(self): + _log_deprecation_warning(self) + + +def _infer_fake_quantize_configs( + base_config: AOBaseConfig, +) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]: + """ + Given a base post-training quantization (PTQ) config, infer the corresponding + `FakeQuantizeConfigBase`s for both the activations and the weights. + This is called during the prepare phase of QAT. + + Return a 2-tuple of (activation_config, weight_config) for fake quantization. + """ + # TODO: rewrite using registration API so we don't need to import here + # avoid circular imports + from torchao.prototype.mx_formats import ( + NVFP4InferenceConfig, + NVFP4MMConfig, + ) + from torchao.prototype.qat import ( + NVFP4FakeQuantizeConfig, + ) + from torchao.quantization import ( + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + ) + + if isinstance(base_config, Int8DynamicActivationInt4WeightConfig): + act_config = IntxFakeQuantizeConfig( + dtype=torch.int8, + granularity="per_token", + is_symmetric=base_config.act_mapping_type == MappingType.SYMMETRIC, + ) + weight_config = IntxFakeQuantizeConfig( + dtype=torch.int4, + group_size=base_config.group_size, + is_symmetric=base_config.mapping_type == MappingType.SYMMETRIC, + ) + elif isinstance(base_config, Int4WeightOnlyConfig): + act_config = None + if base_config.version == 2: + supported_packing_formats = [ + Int4PackingFormat.PLAIN, + Int4PackingFormat.PRESHUFFLED, + ] + if base_config.int4_packing_format not in supported_packing_formats: + raise ValueError( + f"Packing format must be one of {supported_packing_formats}" + ) + weight_config = Int4WeightFakeQuantizeConfig( + group_size=128, + activation_dtype=torch.bfloat16, + ) + elif base_config.version == 1: + # For BC + from torchao.quantization.quant_api import ( + LAYOUT_TO_ZERO_POINT_DOMAIN, + ) + + if base_config.zero_point_domain == ZeroPointDomain.NONE: + zp_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(base_config.layout)][0] + else: + zp_domain = base_config.zero_point_domain + weight_config = IntxFakeQuantizeConfig( + dtype=torch.uint4, + group_size=base_config.group_size, + is_symmetric=False, + zero_point_domain=zp_domain, + ) + else: + raise ValueError(f"Unknown version on base config {type(base_config)}") + elif isinstance(base_config, Float8DynamicActivationFloat8WeightConfig): + if base_config.version != 2: + raise ValueError(f"Only version 2 of {type(base_config)} is supported") + (act_granularity, weight_granularity) = _normalize_granularity( + base_config.granularity + ) + act_config = Float8FakeQuantizeConfig( + dtype=base_config.activation_dtype, + granularity=act_granularity, + hp_value_lb=base_config.activation_value_lb, + hp_value_ub=base_config.activation_value_ub, + ) + weight_config = Float8FakeQuantizeConfig( + dtype=base_config.weight_dtype, + granularity=weight_granularity, + ) + elif isinstance(base_config, Float8DynamicActivationInt4WeightConfig): + act_config = Float8FakeQuantizeConfig( + dtype=e4m3_dtype, + granularity=PerRow(), + ) + weight_config = Int4WeightFakeQuantizeConfig( + group_size=128, + activation_dtype=e4m3_dtype, + ) + elif isinstance(base_config, NVFP4InferenceConfig): + # Note: today the PTQ config does not allow the user to specify + # `per_tensor_scales` due to serialization concerns. In the future + # we may add a way to compute these dynamically (for activations), + # but for now QAT will mimic the existing behavior of not having + # `per_tensor_scales` (subject to change) + if NVFP4MMConfig.DYNAMIC: + act_config = NVFP4FakeQuantizeConfig(False) + else: + act_config = None + weight_config = NVFP4FakeQuantizeConfig(False) + elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig): + assert base_config.version >= 2, "Only version 2+ is supported" + assert base_config.intx_packing_format == "unpacked_to_int8", ( + "Only unpacked_to_int8 is supported" + ) + assert base_config.weight_dtype != torch.int1, "Only int2+ is supported" + assert base_config.act_mapping_type == MappingType.ASYMMETRIC, ( + "Only asymmetric activation mapping is supported" + ) + assert base_config.weight_mapping_type == MappingType.SYMMETRIC, ( + "Only symmetric weight mapping is supported" + ) + assert base_config.weight_scale_dtype is None, ( + "Specifying weight_scale_dtype is not supported" + ) + + act_config = IntxFakeQuantizeConfig( + torch.int8, + "per_token", + is_symmetric=False, + scale_precision=base_config.weight_scale_dtype, + ) + weight_config = IntxFakeQuantizeConfig( + dtype=base_config.weight_dtype, + granularity=base_config.weight_granularity, + mapping_type=base_config.weight_mapping_type, + scale_precision=base_config.weight_scale_dtype, + ) + elif isinstance(base_config, IntxWeightOnlyConfig): + assert base_config.version >= 2, "Only version 2+ is supported" + assert base_config.intx_packing_format == "unpacked_to_int8", ( + "Only unpacked_to_int8 is supported" + ) + assert base_config.mapping_type == MappingType.SYMMETRIC, ( + "Only symmetric mapping is supported" + ) + assert base_config.weight_dtype != torch.int1, "Only int2+ is supported" + assert base_config.scale_dtype is None, ( + "Specifying scale_dtype is not supported" + ) + + act_config = None + weight_config = IntxFakeQuantizeConfig( + dtype=base_config.weight_dtype, + granularity=base_config.granularity, + mapping_type=base_config.mapping_type, + scale_precision=base_config.scale_dtype, + ) + else: + raise ValueError("Unexpected base config: %s" % base_config) + return (act_config, weight_config) diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index b7ad792dc1..9c06264be8 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -11,38 +11,192 @@ from torchao.quantization.granularity import ( PerAxis, PerGroup, + PerRow, PerToken, ) from torchao.quantization.quant_primitives import ( _DTYPE_TO_BIT_WIDTH, _DTYPE_TO_QVALUE_BOUNDS, MappingType, + _choose_scale_float8, + _dequantize_affine_float8, + _fake_quantize_affine, + _quantize_affine_float8, _Round, choose_qparams_affine, ) from torchao.quantization.utils import ( _get_per_token_block_size, + get_block_size, get_group_qparams_symmetric, get_groupwise_affine_qparams, ) -from .api import ( - FakeQuantizeConfig, +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + Float8FakeQuantizeConfig, + Int4WeightFakeQuantizeConfig, + IntxFakeQuantizeConfig, ) from .utils import ( _fake_quantize_per_channel_group, _fake_quantize_per_token, - _Float8RowwiseFakeQuantize, + _log_deprecation_warning, ) -class FakeQuantizer(torch.nn.Module): +class FakeQuantizerBase(torch.nn.Module): """ Generic module for applying fake quantization to a tensor, as specified in the config. """ - def __init__(self, config: FakeQuantizeConfig): + config: FakeQuantizeConfigBase + + def __repr__(self) -> str: + """ + Return a human readable representation of this `FakeQuantizer` with config details. + """ + return "FakeQuantizer(%s)" % self.config + + @staticmethod + def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase": + # TODO: rewrite using registration API so we don't need to import here + from torchao.prototype.qat import ( + NVFP4FakeQuantizeConfig, + NVFP4FakeQuantizer, + ) + + if isinstance(config, IntxFakeQuantizeConfig): + return IntxFakeQuantizer(config) + elif isinstance(config, Int4WeightFakeQuantizeConfig): + return Int4WeightFakeQuantizer(config) + elif isinstance(config, Float8FakeQuantizeConfig): + return Float8FakeQuantizer(config) + elif isinstance(config, NVFP4FakeQuantizeConfig): + return NVFP4FakeQuantizer(config) + else: + raise ValueError(f"Unknown config type: {config}") + + +class Float8FakeQuantizer(FakeQuantizerBase): + """ + Generic module for applying float8 fake quantization to a tensor, as specified in the config. + """ + + def __init__(self, config: Float8FakeQuantizeConfig): + super().__init__() + self.config = config + torch._C._log_api_usage_once("torchao.quantization.qat.Float8FakeQuantizer") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + original_dtype = x.dtype + block_size = get_block_size(x.shape, self.config.granularity) + scale = _choose_scale_float8( + x, + block_size, + self.config.dtype, + hp_value_lb=self.config.hp_value_lb, + hp_value_ub=self.config.hp_value_ub, + ) + q = _quantize_affine_float8(x, scale, self.config.dtype) + dq = _dequantize_affine_float8(q, scale, original_dtype) + return dq + + +class Int4WeightFakeQuantizer(FakeQuantizerBase): + """ + Generic module for applying int4 fake quantization to a weight tensor, + targeting the following FBGEMM kernels: + torch.ops.fbgemm.f8i4bf16_shuffled + torch.ops.fbgemm.bf16i4bf16_shuffled + torch.ops.fbgemm.bf16i4bf16_rowwise + """ + + def __init__(self, config: Int4WeightFakeQuantizeConfig): + super().__init__() + self.config = config + torch._C._log_api_usage_once("torchao.quantization.qat.Int4WeightFakeQuantizer") + + def forward(self, w: torch.Tensor) -> torch.Tensor: + if self.config.activation_dtype == torch.float8_e4m3fn: + return self._fp8_activations_forward(w) + elif self.config.activation_dtype == torch.bfloat16: + return self._bf16_activations_forward(w) + else: + raise ValueError(f"Unknown activation dtype {self.config.activation_dtype}") + + def _fp8_activations_forward(self, w: torch.Tensor) -> torch.Tensor: + """ + Apply int4 fake quantization to the weight tensor where the input activations + are expected to be rowwise fp8, using the following as a reference: + https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L136 + """ + assert w.dim() == 2 + assert self.config.activation_dtype == torch.float8_e4m3fn + + # First quantize weights to fp8 per row + # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row + per_row_block_size = get_block_size(w.shape, PerRow()) + fp8_scale = _choose_scale_float8( + w, + per_row_block_size, + torch.float8_e4m3fn, + hp_value_lb=1e-12, + ) + w_fp8 = _quantize_affine_float8(w, fp8_scale, torch.float8_e4m3fn) + w_fp8 = _dequantize_affine_float8(w_fp8, fp8_scale, w.dtype) + + # Now quantize to int4 per group + # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize + eps = 1e-6 + fbgemm_scale_quant_max = 8 + w_fp8_grouped = w_fp8.view(w_fp8.shape[0], -1, self.config.group_size) + max_abs = torch.amax(torch.abs(w_fp8_grouped), dim=-1, keepdim=False) + scale = torch.clamp(max_abs / fbgemm_scale_quant_max, min=eps) + zero_point = torch.zeros_like(scale) + per_group_block_size = (1, self.config.group_size) + fq = _fake_quantize_affine( + w_fp8, + per_group_block_size, + scale, + zero_point, + quant_dtype=torch.int8, + quant_min=-8, + quant_max=7, + ) + return fq.to(w.dtype) + + def _bf16_activations_forward(self, w: torch.Tensor) -> torch.Tensor: + """ + Apply int4 fake quantization to the weight tensor where the input activations + are expected to be bf16, using the following as a reference: + https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L152 + """ + assert w.dim() == 2 + assert self.config.activation_dtype == torch.bfloat16 + + eps = 1e-6 + qmin, qmax = 0, 15 + fbgemm_symmetric_qmax = 8 + w_grouped = w.to(torch.float32).view(w.shape[0], -1, self.config.group_size) + max_val = torch.amax(w_grouped, dim=-1, keepdim=True) + min_val = torch.amin(w_grouped, dim=-1, keepdim=True) + scale = torch.clamp(max_val - min_val, min=eps) / qmax + zero_point = min_val + scale * fbgemm_symmetric_qmax + fq = _Round.apply((w_grouped - min_val) / scale).clamp(qmin, qmax) + fq = fq - fbgemm_symmetric_qmax + fq = fq * scale + zero_point + return fq.view(w.shape).to(w.dtype) + + +class IntxFakeQuantizer(FakeQuantizerBase): + """ + Generic module for applying integer fake quantization to a tensor, as specified in the config. + """ + + def __init__(self, config: IntxFakeQuantizeConfig): super().__init__() + torch._C._log_api_usage_once("torchao.quantization.qat.IntxFakeQuantizer") self.config = config self.enabled = True self.scale: Optional[torch.Tensor] = None @@ -177,33 +331,21 @@ def _maybe_update_qparams_for_range_learning(self) -> None: qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] # Stabilize range learning scale = torch.clamp(scale, min=self._scale_eps) - zero_point = _Round.apply(zero_point) - zero_point = torch.clamp(zero_point, qmin, qmax) self.scale = torch.nn.Parameter(scale, requires_grad=True) - self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True) - - def __repr__(self) -> str: - """ - Return a human readable representation of this `FakeQuantizer` with config details. - """ - return "FakeQuantizer(%s)" % self.config + if self.config.is_symmetric: + self.zero_point.zero_() + else: + zero_point = _Round.apply(zero_point) + zero_point = torch.clamp(zero_point, qmin, qmax) + self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True) -class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module): +# For BC +class FakeQuantizer(IntxFakeQuantizer): """ - Simple fake quantizer for float8 rowwise fake quantization, intended for activations only. + (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizer` instead. """ - def __init__(self): - super().__init__() - self.enabled = True - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.enabled: - return _Float8RowwiseFakeQuantize.apply( - x, - torch.float8_e4m3fn, - -1, - ) - else: - return x + def __init__(self, config: FakeQuantizeConfigBase): + super().__init__(config) + _log_deprecation_warning(self) diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 567b87f342..61f783ab8c 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from torchao.dtypes.utils import is_device -from torchao.quantization.granularity import PerGroup +from torchao.quantization.granularity import PerGroup, PerRow from torchao.quantization.linear_quant_modules import ( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, @@ -25,12 +25,14 @@ ) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 -from .api import FakeQuantizeConfig +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + Float8FakeQuantizeConfig, + IntxFakeQuantizeConfig, +) from .fake_quantizer import ( - FakeQuantizer, - _Float8RowwiseActivationFakeQuantizer, + FakeQuantizerBase, ) from .utils import ( _get_qmin_qmax, @@ -46,12 +48,12 @@ class FakeQuantizedLinear(torch.nn.Linear): Example usage:: - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( dtype=torch.int8, granularity="per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, group_size=8, is_symmetric=True, @@ -67,8 +69,8 @@ def __init__( in_features: int, out_features: int, bias: bool = False, - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, + activation_config: Optional[FakeQuantizeConfigBase] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, *args, **kwargs, ) -> None: @@ -79,22 +81,27 @@ def __init__( *args, **kwargs, ) + torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear") # initialize activation fake quantizer if activation_config is not None: - self.activation_fake_quantizer = FakeQuantizer(activation_config) + self.activation_fake_quantizer = FakeQuantizerBase.from_config( + activation_config + ) else: self.activation_fake_quantizer = None # initialize weight fake quantizer if weight_config is not None: - if isinstance(weight_config.granularity, PerGroup): + if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance( + weight_config.granularity, PerGroup + ): group_size = weight_config.group_size if group_size is not None and in_features % group_size != 0: raise ValueError( "in_features (%s) %% group_size (%s) must be == 0" % (in_features, group_size) ) - self.weight_fake_quantizer = FakeQuantizer(weight_config) + self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config) else: self.weight_fake_quantizer = None @@ -127,8 +134,8 @@ def to_linear(self) -> torch.nn.Linear: def from_linear( cls, mod: torch.nn.Linear, - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, + activation_config: Optional[FakeQuantizeConfigBase] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, ): new_linear = FakeQuantizedLinear( mod.in_features, @@ -148,29 +155,12 @@ def from_linear( return new_linear -# =========================== -# | QAT quantizer interface | -# =========================== - - -class _LegacyQATQuantizer(TwoStepQuantizer): - """ - Base class for sharing common methods across legacy QAT quantizers. - """ - - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: - return None - - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: - return None - - def enable_linear_fake_quant( mod: torch.nn.Module, enabled: bool = True, ): """ - Helper function to enable fake quantization in `FakeQuantizerLinear`. + Helper function to enable fake quantization in `FakeQuantizedLinear`. """ if isinstance(mod, FakeQuantizedLinear): if mod.activation_fake_quantizer is not None: @@ -181,11 +171,28 @@ def enable_linear_fake_quant( def disable_linear_fake_quant(mod: torch.nn.Module): """ - Helper function to disable fake quantization in `FakeQuantizerLinear`. + Helper function to disable fake quantization in `FakeQuantizedLinear`. """ enable_linear_fake_quant(mod, enabled=False) +# =========================== +# | QAT quantizer interface | +# =========================== + + +class _LegacyQATQuantizer(TwoStepQuantizer): + """ + Base class for sharing common methods across legacy QAT quantizers. + """ + + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: + return None + + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: + return None + + # =========================================== # | int8 dynamic activations + int4 weights | # =========================================== @@ -206,6 +213,9 @@ def __init__( scales_precision: torch.dtype = torch.float32, ) -> None: super().__init__() + torch._C._log_api_usage_once( + "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer" + ) self.groupsize: int = groupsize self.padding_allowed: bool = padding_allowed self.precision: torch.dtype = precision @@ -281,10 +291,10 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module): else: self._convert_qat_linear_8da4w(child) - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_8da4w_activation_config(self.activation_scales_precision) - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_8da4w_weight_config(self.groupsize, self.scales_precision) @@ -339,7 +349,7 @@ def disable_fake_quant(self): # TODO: remove these in favor of enable_linear_fake_quant def enable_8da4w_fake_quant(mod: torch.nn.Module): """ - Enable fake quantization for `Int8DynActInt4WeightQATLinear`. + (deprecated) Enable fake quantization for `Int8DynActInt4WeightQATLinear`. """ if isinstance(mod, Int8DynActInt4WeightQATLinear): mod.enable_fake_quant() @@ -348,19 +358,21 @@ def enable_8da4w_fake_quant(mod: torch.nn.Module): # TODO: remove in favor of disable_linear_fake_quant def disable_8da4w_fake_quant(mod: torch.nn.Module): """ - Disable fake quantization for `Int8DynActInt4WeightQATLinear`. + (deprecated) Disable fake quantization for `Int8DynActInt4WeightQATLinear`. """ if isinstance(mod, Int8DynActInt4WeightQATLinear): mod.disable_fake_quant() -def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig: +def _get_8da4w_activation_config( + qparams_precision: torch.dtype, +) -> IntxFakeQuantizeConfig: """ - Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. + Return the activation `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. """ # TODO: generalize this assert qparams_precision == torch.float32 - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=torch.int8, granularity="per_token", is_symmetric=False, @@ -374,11 +386,11 @@ def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantize def _get_8da4w_weight_config( group_size: int, qparams_precision: torch.dtype, -) -> FakeQuantizeConfig: +) -> IntxFakeQuantizeConfig: """ - Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. + Return the weight `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. """ - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=TorchAODType.INT4, group_size=group_size, is_symmetric=True, @@ -407,6 +419,9 @@ def __init__( scales_precision: torch.dtype = torch.bfloat16, ) -> None: super().__init__() + torch._C._log_api_usage_once( + "torchao.quantization.qat.Int4WeightOnlyQATQuantizer" + ) assert inner_k_tiles in [2, 4, 8] assert groupsize in [32, 64, 128, 256] self.inner_k_tiles = inner_k_tiles @@ -464,10 +479,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): n_bit, config.group_size, ) - if ( - is_device(q_weight.device.type, "cpu") - and TORCH_VERSION_AT_LEAST_2_6 - ): + if is_device(q_weight.device.type, "cpu"): q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( q_weight.to(child.weight.device), child.inner_k_tiles, @@ -482,7 +494,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): else: self._convert_qat_linear_4w(child) - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_4w_weight_config(self.groupsize, self.scales_precision) @@ -535,7 +547,7 @@ def disable_fake_quant(self): # TODO: remove these in favor of enable_linear_fake_quant def enable_4w_fake_quant(mod: torch.nn.Module): """ - Enable fake quantization for `Int4WeightOnlyQATLinear`. + (deprecated) Enable fake quantization for `Int4WeightOnlyQATLinear`. """ if isinstance(mod, Int4WeightOnlyQATLinear): mod.enable_fake_quant() @@ -544,7 +556,7 @@ def enable_4w_fake_quant(mod: torch.nn.Module): # TODO: remove these in favor of disable_linear_fake_quant def disable_4w_fake_quant(mod: torch.nn.Module): """ - Disable fake quantization for `Int4WeightOnlyQATLinear`. + (deprecated) Disable fake quantization for `Int4WeightOnlyQATLinear`. """ if isinstance(mod, Int4WeightOnlyQATLinear): mod.disable_fake_quant() @@ -553,11 +565,11 @@ def disable_4w_fake_quant(mod: torch.nn.Module): def _get_4w_weight_config( group_size: int, qparams_precision: torch.dtype, -) -> FakeQuantizeConfig: +) -> IntxFakeQuantizeConfig: """ - Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`. + Return the weight `IntxFakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`. """ - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=torch.uint4, group_size=group_size, is_symmetric=False, @@ -591,11 +603,18 @@ def __init__( group_size: Optional[int] = 64, scale_precision: torch.dtype = torch.bfloat16, ): + torch._C._log_api_usage_once( + "torchao.quantization.qat.Float8ActInt4WeightQATQuantizer" + ) if group_size is not None: weight_granularity = "per_group" else: weight_granularity = "per_channel" - self._weight_config = FakeQuantizeConfig( + self._activation_config = Float8FakeQuantizeConfig( + dtype=torch.float8_e4m3fn, + granularity=PerRow(), + ) + self._weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, granularity=weight_granularity, group_size=group_size, @@ -613,14 +632,11 @@ def prepare( """ for name, child in model.named_children(): if isinstance(child, torch.nn.Linear): - # TODO: add a config for float8? new_linear = FakeQuantizedLinear.from_linear( child, + activation_config=self._activation_config, weight_config=self._weight_config, ) - new_linear.activation_fake_quantizer = ( - _Float8RowwiseActivationFakeQuantizer() - ) setattr(model, name, new_linear) else: self.prepare(child) @@ -632,8 +648,8 @@ def convert( ) -> torch.nn.Module: raise NotImplementedError - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet") - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return self.weight_config diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py index 4f3323a1e8..c5f339c945 100644 --- a/torchao/quantization/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import warnings +from typing import Any import torch @@ -16,64 +18,6 @@ ) -class _Float8RowwiseFakeQuantize(torch.autograd.Function): - """ - Implementation of float8 rowwise fake quantize with backward STE. - """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - x: torch.Tensor, - float8_dtype: torch.dtype, - axiswise_dim: int, - ): - # compute rowwise scale based on `torchao.float8.float8_utils.tensor_to_scale` - eps = 1e-12 - amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) - amax = amax.to(torch.float64) - scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=eps) - scale = scale.to(torch.float32) - - # fake quantize - max_value = torch.finfo(float8_dtype).max - x_fq = x.to(torch.float32) * scale - x_fq = x_fq.clamp(min=-max_value, max=max_value) - x_fq = x_fq.to(float8_dtype).to(x.dtype) - x_fq = x_fq / scale - return x_fq.to(x.dtype) - - @staticmethod - def backward(ctx, gy): - return gy, None, None - - -# TODO: delete? -class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function): - """ - Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring - gradients are still passed to the tensor subclass. This is used in place of - `_GenericFakeQuantize` when fake quant is disabled. - """ - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - input: torch.Tensor, - ) -> torch.Tensor: - # avoid circular dependencies - from torchao.quantization.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - ) - - assert isinstance(input, AffineFakeQuantizedTensor) - return input.original_tensor - - @staticmethod - def backward(ctx, gy): - return (gy,) - - def _fake_quantize_per_channel_group( input: torch.Tensor, scales: torch.Tensor, @@ -130,3 +74,33 @@ def _get_qmin_qmax(n_bit: int, symmetric: bool = True): qmin = 0 qmax = 2**n_bit - 1 return (qmin, qmax) + + +def _log_deprecation_warning(old_api_object: Any): + """ + Log a helpful deprecation message pointing users to the new QAT API, + only once per deprecated class. + """ + warnings.warn( + """'%s' is deprecated and will be removed in a future release. Please use the following API instead: + + base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) + quantize_(model, QATConfig(base_config, step="prepare")) + # train (not shown) + quantize_(model, QATConfig(base_config, step="convert")) + +Alternatively, if you prefer to pass in fake quantization configs: + + activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) + qat_config = QATConfig( + activation_config=activation_config, + weight_config=weight_config, + step="prepare", + ) + quantize_(model, qat_config) + +Please see https://github.com/pytorch/ao/issues/2630 for more details. + """ + % old_api_object.__class__.__name__ + ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7287ae2bc0..15caddcadc 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024-2025 Arm Limited and affiliates. # All rights reserved. - # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. @@ -15,7 +15,6 @@ and mixed GEMM kernels """ -import importlib.util import logging import types import warnings @@ -47,8 +46,6 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, - to_fbgemm_fp8, - to_fbgemm_int4, to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( @@ -67,19 +64,35 @@ from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, ) -from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size +from torchao.quantization.observer import AffineQuantizedObserverBase +from torchao.quantization.quantize_.common import ( + KernelPreference, +) +from torchao.quantization.quantize_.workflows import ( + Float8Tensor, + Int4ChooseQParamsAlgorithm, + Int4MarlinSparseTensor, + Int4OpaqueTensor, + Int4PackingFormat, + Int4PlainInt32Tensor, + Int4PreshuffledTensor, + Int4Tensor, + Int4TilePackedTo4dTensor, + IntxOpaqueTensor, + IntxPackingFormat, + IntxUnpackedToInt8Tensor, + QuantizeTensorToFloat8Kwargs, +) from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) +from torchao.quantization.utils import get_block_size from torchao.quantization.weight_tensor_linear_activation_quantization import ( to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - is_fbcode, + _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -111,11 +124,9 @@ _DTYPE_TO_QVALUE_BOUNDS, MappingType, ZeroPointDomain, + quantize_affine, ) from .subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) from .unified import Quantizer, TwoStepQuantizer @@ -123,6 +134,7 @@ logger = logging.getLogger(__name__) +# TODO: revisit this list? __all__ = [ "swap_conv2d_1x1_to_linear", "Quantizer", @@ -147,7 +159,6 @@ "Int8DynActInt4WeightQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", - "FbgemmConfig", ] LAYOUT_TO_ZERO_POINT_DOMAIN = { @@ -165,109 +176,6 @@ } -###### -# TO BE DEPRECATED START -###### -def _in_features_greater_than_16(mod, *args): - return hasattr(mod, "in_features") and mod.in_features > 16 - - -def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` - Tensor subclass, effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. - """ - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) - - if filter_fn is None: - filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( - *args - ) - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs - ), - filter_fn, - ) - - -def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the - `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_weight_only_int8_quant while not modifying the linear modules. - """ - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs - ), - _is_linear if filter_fn is None else filter_fn, - ) - - -def change_linear_weights_to_int4_woqtensors( - model, - groupsize=128, - inner_k_tiles=8, - filter_fn=None, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, -): - """ - Converts all linear weight tensors to the - `Int4WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. - Args: - `groupsize`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained, choices are [256, 128, 64, 32] - `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] - `filter_fn`: function that takes a nn.Module instance and fully qualified name of the module, \ - returns True if we want to run `config` on - `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, \ - ZeroPointDomain.INT, ZeroPointDomain.NONE] - `preserve_zero`: whether to preserve zero, default is False - """ - if TORCH_VERSION_AT_LEAST_2_4: - raise ImportError( - "This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs" - ) - - if filter_fn is None: - filter_fn = _is_linear - - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter( - Int4WeightOnlyQuantizedLinearWeight, - enable_parametrization=False, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - ), - filter_fn, - ) - - -######## -# TO BE DEPRECATED END -######## - - def _replace_with_custom_fn_if_matches_filter( model, replacement_fn, @@ -369,7 +277,7 @@ def _replace_with_custom_fn_if_matches_filter_with_name( def _is_linear(mod, *args): # avoid circular dependencies from torchao.quantization.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, + _AffineFakeQuantizedTensor, ) # adding weight tensor subclass isinstance check to make sure the weight is only quantized once @@ -381,7 +289,7 @@ def _is_linear(mod, *args): and not isinstance(mod.weight, AutoQuantizableLinearWeight) and not isinstance(mod.weight, AffineQuantizedTensor) and not isinstance(mod.weight, LinearActivationQuantizedTensor) - and not isinstance(mod.weight, AffineFakeQuantizedTensor) + and not isinstance(mod.weight, _AffineFakeQuantizedTensor) and not isinstance(mod, nn.modules.linear.NonDynamicallyQuantizableLinear) ) @@ -543,16 +451,20 @@ def _quantization_type(weight: torch.Tensor): if hasattr(weight, "_quantization_type"): return f"{weight.__class__.__name__}({weight._quantization_type()})" - if type(weight) is torch.Tensor: - return "not quantized" + if type(weight) is torch.Tensor or isinstance(weight, torch.nn.Parameter): + return f"Tensor: {type(weight)}" - return "not recognized" + return f"not recognized: {type(weight)}" def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" +def _embedding_extra_repr(self): + return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}" + + def _get_linear_subclass_inserter( constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs ): @@ -599,18 +511,19 @@ def quantize_( # optimized execution paths or kernels (e.g. int4 tinygemm kernel) # also customizable with arguments # currently options are - # int8_dynamic_activation_int4_weight (for executorch) - # int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile) - # int4_weight_only (optimized with int4 tinygemm kernel and torch.compile) - # int8_weight_only (optimized with int8 mm op and torch.compile + # Int8DynamicActivationInt4WeightConfig (for executorch) + # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) + # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) + # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile from torchao.quantization.quant_api import int4_weight_only m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - quantize_(m, int4_weight_only(group_size=32)) + quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1)) """ - filter_fn = _is_linear if filter_fn is None else filter_fn + torch._C._log_api_usage_once("torchao.quantization.quantize_") + filter_fn = _is_linear if filter_fn is None else filter_fn if isinstance(config, ModuleFqnToConfig): _replace_with_custom_fn_if_matches_filter_with_name( model, @@ -645,20 +558,15 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: scale_dtype = torch.float32 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int8 - if TORCH_VERSION_AT_LEAST_2_6: - return to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - ) - else: - return to_affine_quantized_intx( - x, mapping_type, _get_per_token_block_size(x), target_dtype - ) + return to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: @@ -669,27 +577,17 @@ def _uint8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: zero_point_dtype = torch.int32 quant_min = 0 quant_max = 255 - if TORCH_VERSION_AT_LEAST_2_6: - out = to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - ) - else: - out = to_affine_quantized_intx( - x, - mapping_type, - _get_per_token_block_size(x), - target_dtype, - quant_min=quant_min, - quant_max=quant_max, - ) + out = to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + ) return out @@ -733,14 +631,25 @@ class Int8DynamicActivationInt4WeightConfig(AOBaseConfig): act_mapping_type: MappingType = MappingType.ASYMMETRIC set_inductor_config: bool = True + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Int8DynamicActivationInt4WeightConfig" + ) + # for BC -int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig +int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( + "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig +) @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) def _int8_dynamic_activation_int4_weight_transform( - module: torch.nn.Module, config: Int8DynamicActivationInt4WeightConfig + module: torch.nn.Module, + config: Int8DynamicActivationInt4WeightConfig, + *, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, ): group_size = config.group_size layout = config.layout @@ -793,6 +702,8 @@ def _int8_dynamic_activation_int4_weight_transform( quant_min=0, quant_max=15, _layout=layout, + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, ) else: weight = to_affine_quantized_intx( @@ -803,6 +714,8 @@ def _int8_dynamic_activation_int4_weight_transform( quant_min, quant_max, _layout=layout, + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, ) weight = to_linear_activation_quantized(weight, input_quant_func) module.weight = torch.nn.Parameter(weight, requires_grad=False) @@ -821,18 +734,28 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): are the same. However, this layout is more general and supports other weight dtypes. args: - weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. - torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6 - weight_granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(axis=0). - weight_mapping_type: The type of mapping to use for the weight quantization. + `weight_dtype`: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. + ` weight_granularity`: The granularity to use for weight quantization. Must be PerGroup or PerAxis(axis=0). + `weight_mapping_type`: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. MappingType.SYMMETRIC requires ZeroPointDomain.NONE - weight_scale_dtype: The dtype to use for the weight scale. - act_mapping_type: The type of mapping to use for the activation quantization. + `weight_scale_dtype`: The dtype to use for the weight scale. + `act_mapping_type`: The type of mapping to use for the activation quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. - layout: The layout to use for the packed weight tensor: + `layout`: The layout to use for the packed weight tensor: - PackedLinearInt8DynamicActivationIntxWeightLayout: this layout is optimized for CPU performance. - QDQLayout: this layout represents the quantization with Q/DQ quant primitives, and is intended for export applications like ExecuTorch. + `intx_packing_format`: The format to use for the packed weight tensor (version 2 only). + - unpacked_to_int8: this format is the default and is intended for export applications like ExecuTorch. + - opaque_torchao_auto: this format is optimized for CPU performance. + `version`: version of the config to use, only subset of above args are valid based on version, see note for more details. + + Note: + + Current state for Int8DynamicActivationIntxWeightConfig is that it supports both v1 (legacy) and v2. + + * `intx_packing_format` is used for version 2. + * `layout` is only used for version 1. """ weight_dtype: torch.dtype = torch.int8 @@ -842,10 +765,13 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig): weight_scale_dtype: Optional[torch.dtype] = None act_mapping_type: MappingType = MappingType.ASYMMETRIC layout: Layout = QDQLayout() + intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8 + + version: int = 2 def __post_init__(self): - assert TORCH_VERSION_AT_LEAST_2_6, ( - "Int8DynamicActivationIntxWeightConfig requires torch 2.6+" + torch._C._log_api_usage_once( + "torchao.quantization.Int8DynamicActivationIntxWeightConfig" ) assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" @@ -860,8 +786,9 @@ def __post_init__(self): assert self.weight_mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, + MappingType.SYMMETRIC_NO_CLIPPING_ERR, ], ( - f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.weight_mapping_type}" + f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}" ) assert self.act_mapping_type in [ MappingType.ASYMMETRIC, @@ -888,33 +815,92 @@ def __post_init__(self): ) -@register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig) -def _int8_dynamic_activation_intx_weight_transform( - module: torch.nn.Module, config: Int8DynamicActivationIntxWeightConfig -) -> torch.nn.Module: - weight = module.weight - bias = module.bias +def _int8_dynamic_activation_intx_weight_quantize_tensor( + weight, + bias, + config, + *, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, +): weight_dtype = config.weight_dtype weight_granularity = config.weight_granularity weight_mapping_type = config.weight_mapping_type weight_scale_dtype = config.weight_scale_dtype act_mapping_type = config.act_mapping_type layout = config.layout + intx_packing_format = config.intx_packing_format - assert weight.dim() == 2, f"weight must be 2D, but got {weight.dim()}D" + assert weight.dim() == 2, ( + f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" + ) if isinstance(weight_granularity, PerGroup): group_size = weight_granularity.group_size elif isinstance(weight_granularity, PerAxis): - assert weight_granularity.axis == 0, "axis must be 0" + assert weight_granularity.axis == 0, ( + f"axis must be 0 with PerAxis, but got {weight_granularity.axis}" + ) group_size = weight.shape[-1] else: raise ValueError( f"weight_granularity must be PerGroup or PerAxis, got {weight_granularity}" ) + block_size = (1, group_size) + + if config.version == 2: + assert act_mapping_type == MappingType.ASYMMETRIC + opaque_formats = [ + IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI, + IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI, + IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT, + ] + assert ( + intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8 + or intx_packing_format in opaque_formats + ), f"Unsupported packing format: {intx_packing_format}" + if custom_zero_point is not None and custom_zero_point.dtype == torch.int32: + custom_zero_point = custom_zero_point.to(torch.int8) + new_weight = IntxUnpackedToInt8Tensor.from_hp( + weight, + block_size, + weight_dtype, + mapping_type=weight_mapping_type, + activation_quantization="int8_asym_per_token", + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, + ) + if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype: + _adjust_scale_dtype_in_intx_unpacked_tensor( + new_weight, weight, weight_scale_dtype + ) + + new_bias = bias + + # Create packed tensor + if intx_packing_format in opaque_formats: + new_weight = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor( + new_weight, bias=new_bias, intx_packing_format=intx_packing_format + ) + new_bias = None # bias is packed with weights + + return new_weight, new_bias + + # Version 1 + assert config.version == 1 + warnings.warn( + "Config Deprecation: version 1 of Int8DynamicActivationIntxWeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2967 for more details" + ) quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype] # We quantize with QDQLayout, and then construct the packed weight tensor later + # set preserve_zero based on weight mapping type + preserve_zero = weight_mapping_type in [ + MappingType.SYMMETRIC, + MappingType.SYMMETRIC_NO_CLIPPING_ERR, + ] + weight = to_affine_quantized_intx( input_float=weight, mapping_type=weight_mapping_type, @@ -924,7 +910,7 @@ def _int8_dynamic_activation_intx_weight_transform( quant_max=quant_max, scale_dtype=weight_scale_dtype, zero_point_dtype=torch.int8, - preserve_zero=(weight_mapping_type == MappingType.SYMMETRIC), + preserve_zero=preserve_zero, zero_point_domain=ZeroPointDomain.INT, _layout=QDQLayout(), ) @@ -965,9 +951,29 @@ def _int8_dynamic_activation_intx_weight_transform( # bias is packed with weights if present bias = None - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.bias = bias - module.extra_repr = types.MethodType(_linear_extra_repr, module) + return weight, bias + + +@register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig) +def _int8_dynamic_activation_intx_weight_transform( + module: torch.nn.Module, + config: Int8DynamicActivationIntxWeightConfig, + *, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, +) -> torch.nn.Module: + new_weight, new_bias = _int8_dynamic_activation_intx_weight_quantize_tensor( + module.weight, + module.bias, + config, + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + if new_bias is None: + module.bias = None + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -987,9 +993,16 @@ class Int4DynamicActivationInt4WeightConfig(AOBaseConfig): act_mapping_type: MappingType = MappingType.SYMMETRIC set_inductor_config: bool = True + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Int4DynamicActivationInt4WeightConfig" + ) + # for bc -int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig +int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( + "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig +) @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) @@ -1043,9 +1056,16 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig): mode: Optional[str] = "weight_only" set_inductor_config: bool = True + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.GemliteUIntXWeightOnlyConfig" + ) + # for BC -gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig +gemlite_uintx_weight_only = _ConfigDeprecationWrapper( + "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig +) @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) @@ -1083,26 +1103,29 @@ def _gemlite_uintx_weight_only_transform( @dataclass class Int4WeightOnlyConfig(AOBaseConfig): """ - Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using - "tensor_core_tiled" layout for speedup with tinygemm kernel - - Note: - This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm` - and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference - of quantization algorithm compared to the more traditional type of integer quantization is the following: - 1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`) - 2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`) - please follow the relevant code in `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` - to learn about how the quantization parameters are chosen and how the Tensor is quantized/dequantized for tinygemm + Configuration for int4 weight only quantization, only groupwise quantization is supported + right now, and we support version 1 and version 2, that are implemented differently although with + same support. In version 2, different target are mainly distinguished by `packing_format` arg, and in version 1, mainly by `layout`. Args: `group_size`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained, choices are [256, 128, 64, 32] - `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)` - `use_hqq`: whether to use hqq or default quantization mode, default is False - `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE] - `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. - `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT + size is more fine grained, choices are [256, 128, 64, 32], used in both version 1 and 2 + `int4_packing_format`: the packing format for int4 tensor, used in version 2 only + `int4_choose_qparams_algorithm`: variants of choose qparams algorithm to use for int4, + currently support TINYGEMM ("tinygemm") and HQQ ("hqq"), used in version 2 only + `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`, used in version 1 only + `use_hqq`: whether to use hqq or default quantization mode, default is False, used in version 1 only + `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE], used in version 1 only + `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. used in both version 1 and 2 + `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT, used in version 1 only + `version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 2, see note for more details + + Note: + Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2 + + For v2 (version = 2), only `group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored + For v1 (version = 1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid, we plan to deprecate v1 in torchao 0.15 to make this config + less confusing """ group_size: int = 128 @@ -1111,11 +1134,20 @@ class Int4WeightOnlyConfig(AOBaseConfig): zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE set_inductor_config: bool = True preserve_zero: Optional[bool] = None + # only used in version >= 2 + int4_packing_format: Int4PackingFormat = Int4PackingFormat.PLAIN + int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = ( + Int4ChooseQParamsAlgorithm.TINYGEMM + ) + version: int = 2 + + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") # for BC # TODO maybe change other callsites -int4_weight_only = Int4WeightOnlyConfig +int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig) def _int4_weight_only_quantize_tensor(weight, config): @@ -1127,7 +1159,9 @@ def _int4_weight_only_quantize_tensor(weight, config): group_size = config.group_size layout = config.layout use_hqq = config.use_hqq + int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm zero_point_domain = config.zero_point_domain + int4_packing_format = config.int4_packing_format if weight.shape[-1] % group_size != 0: logger.info( @@ -1135,8 +1169,68 @@ def _int4_weight_only_quantize_tensor(weight, config): ) return weight + block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size]) + + if config.version == 2: + block_size = list(block_size) + + if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ: + assert int4_packing_format in [ + Int4PackingFormat.TILE_PACKED_TO_4D, + Int4PackingFormat.OPAQUE, + ], ( + f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, " + f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D and Int4PackingFormat.OPAQUE currently" + ) + + if int4_packing_format == Int4PackingFormat.PRESHUFFLED: + new_weight = Int4PreshuffledTensor.from_hp( + weight, + block_size, + activation_dtype=torch.bfloat16, + ) + return new_weight + elif int4_packing_format == Int4PackingFormat.PLAIN: + new_weight = Int4Tensor.from_hp( + weight, + block_size, + ) + return new_weight + elif int4_packing_format == Int4PackingFormat.PLAIN_INT32: + new_weight = Int4PlainInt32Tensor.from_hp( + weight, + block_size, + ) + return new_weight + elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE: + new_weight = Int4MarlinSparseTensor.from_hp( + weight, + block_size, + ) + return new_weight + elif int4_packing_format == Int4PackingFormat.OPAQUE: + new_weight = Int4OpaqueTensor.from_hp( + weight, + block_size, + int4_choose_qparams_algorithm=int4_choose_qparams_algorithm, + ) + return new_weight + elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D: + new_weight = Int4TilePackedTo4dTensor.from_hp( + weight, + block_size, + int4_choose_qparams_algorithm=int4_choose_qparams_algorithm, + ) + return new_weight + else: + raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}") + + assert config.version == 1 + + warnings.warn( + "Config Deprecation: version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2948 for more details" + ) mapping_type = MappingType.ASYMMETRIC - block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size]) target_dtype = torch.int32 quant_min = 0 quant_max = 15 @@ -1208,18 +1302,67 @@ def _int4_weight_only_transform( return module +@dataclass +class Float8DynamicActivationInt4WeightConfig(AOBaseConfig): + """Configuration for apply float8 dynamic per row quantization and int4 + per group weight quantization to linear + (only group_size 128 is supported right now since underlying kernel used only supports 128 + and above and no benefits of making it bigger) + + Args: + `int4_packing_format`: how the weight is packed, only preshuffled is supported + """ + + int4_packing_format: Int4PackingFormat = "preshuffled" + + +@register_quantize_module_handler(Float8DynamicActivationInt4WeightConfig) +def _float8_dynamic_activation_int4_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationInt4WeightConfig +) -> torch.nn.Module: + assert hasattr(module, "weight"), ( + "applying int8 weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + int4_packing_format = config.int4_packing_format + + assert int4_packing_format == "preshuffled", ( + f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}" + ) + weight = module.weight + group_size = 128 + block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size]) + new_weight = Int4PreshuffledTensor.from_hp( + module.weight, + block_size, + activation_dtype=torch.float8_e4m3fn, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + @dataclass class Int8WeightOnlyConfig(AOBaseConfig): """ Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers. + + Args: + group_size: Optional[int] = None - Controls the granularity of quantization. If None, applies per-channel quantization. + Otherwise, applies per-group quantization with the specified group size. + set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values + for better performance with this quantization scheme. """ group_size: Optional[int] = None set_inductor_config: bool = True + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") + # for BC -int8_weight_only = Int8WeightOnlyConfig +int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig) def _int8_weight_only_quantize_tensor(weight, config): @@ -1355,7 +1498,17 @@ def _float8_cutlass_quant_sparse( class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): """ Configuration for applying int8 dynamic symmetric per-token activation and int8 per-channel weight - quantization to linear layers + quantization to linear layers. + + Args: + layout: Optional[Layout] = PlainLayout() - Tensor layout for the quantized weights. Controls how the + quantized data is stored and accessed. + act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC - Mapping type for activation quantization. + SYMMETRIC uses symmetric quantization around zero. + weight_only_decode: bool = False - If True, only quantizes weights during forward pass and keeps activations + in original precision during decode operations. + set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values + for better performance with this quantization scheme. """ layout: Optional[Layout] = PlainLayout() @@ -1363,9 +1516,16 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): weight_only_decode: bool = False set_inductor_config: bool = True + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Int8DynamicActivationInt8WeightConfig" + ) + # for BC -int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig +int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper( + "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig +) def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): @@ -1377,7 +1537,7 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): # int8 dynamic quantization only has benefit when in_feature > 16 if in_features <= 16: logger.info( - f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" + f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}" f" because `in_feature` is <= 16: {in_features}" ) return weight @@ -1441,12 +1601,14 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. """ - warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead. + warnings.warn( + """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in Int8DynamicActivationInt8WeightConfig instead. from torchao.dtypes import SemiSparseLayout - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""") + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()""" + ) - return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) + return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()) @dataclass @@ -1457,6 +1619,7 @@ class Float8WeightOnlyConfig(AOBaseConfig): Args: weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. + version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default) Note: The actual matmul will be computed in original precision of the weight tensor. @@ -1464,23 +1627,39 @@ class Float8WeightOnlyConfig(AOBaseConfig): weight_dtype: torch.dtype = e4m3_dtype set_inductor_config: bool = True + version: int = 2 + + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") # for BC -float8_weight_only = Float8WeightOnlyConfig +float8_weight_only = _ConfigDeprecationWrapper( + "float8_weight_only", Float8WeightOnlyConfig +) def _float8_weight_only_quant_tensor(weight, config): - from torchao.dtypes import to_affine_quantized_floatx + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Float8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details" + ) + from torchao.dtypes import to_affine_quantized_floatx - block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) - new_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=config.weight_dtype, - scale_dtype=None, - _layout=Float8Layout(mm_config=None), - ) + block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) + new_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=config.weight_dtype, + scale_dtype=None, + _layout=Float8Layout(mm_config=None), + ) + else: + assert config.version == 2, f"Unexpected version: {config.version}" + weight_dtype = config.weight_dtype + new_weight = Float8Tensor.from_hp( + weight, float8_dtype=weight_dtype, granularity=PerRow() + ) return new_weight @@ -1578,13 +1757,17 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): Args: activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. - granularity: + granularity (Optional[Union[FP8Granularity, List[FP8Granularity]]]): The granularity for quantization. Can be either a single granularity (applied to both activations and weights) or a tuple of two granularities (one for activations, one for weights). If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And only PerTensor and PerRow are supported. mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + activation_value_lb (Optional[float]): the lower bound for activation value for calculating scale + activation_value_ub (Optional[float]): the upper bound for activation value for calculating scale + kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. by defalut (KernelPreference.AUTO) it will be chosen for user based on hardware or other information, this only needs to be set in weight set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. + version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default) """ @@ -1592,12 +1775,18 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): weight_dtype: torch.dtype = e4m3_dtype granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None mm_config: Optional[Float8MMConfig] = None + activation_value_lb: Optional[float] = None + activation_value_ub: Optional[float] = None + kernel_preference: KernelPreference = KernelPreference.AUTO set_inductor_config: bool = True + version: int = 2 def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Float8DynamicActivationFloat8WeightConfig" + ) if self.mm_config is None: self.mm_config = Float8MMConfig(use_fast_accum=True) - activation_granularity, weight_granularity = _normalize_granularity( self.granularity ) @@ -1605,7 +1794,9 @@ def __post_init__(self): # for bc -float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig +float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper( + "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig +) def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): @@ -1613,6 +1804,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): weight_dtype = config.weight_dtype granularity = config.granularity mm_config = config.mm_config + activation_value_lb = config.activation_value_lb + activation_value_ub = config.activation_value_ub + kernel_preference = config.kernel_preference # Ensure works on device _check_hardware_support(granularity) @@ -1622,31 +1816,56 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked return weight + if isinstance(weight_granularity, PerRow): assert weight.dtype == torch.bfloat16, ( "PerRow quantization only works for bfloat16 precision input weight" ) - block_size = get_block_size(weight.shape[-2:], weight_granularity) - if weight.dim() == 3: - block_size = tuple([1] + list(block_size)) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details" + ) - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } + block_size = get_block_size(weight.shape[-2:], weight_granularity) + if weight.dim() == 3: + block_size = tuple([1] + list(block_size)) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) + + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } + + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) + else: + assert config.version == 2, f"Unexpected version: {config.version}" + act_quant_kwargs = QuantizeTensorToFloat8Kwargs( + activation_dtype, + activation_granularity, + hp_value_lb=activation_value_lb, + hp_value_ub=activation_value_ub, + kernel_preference=kernel_preference, + ) + + quantized_weight = Float8Tensor.from_hp( + weight, + float8_dtype=weight_dtype, + granularity=weight_granularity, + mm_config=mm_config, + kernel_preference=kernel_preference, + act_quant_kwargs=act_quant_kwargs, + ) - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs - ) return quantized_weight @@ -1687,6 +1906,11 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig): activation_dtype: torch.dtype = e5m2_dtype weight_dtype: torch.dtype = e4m3_dtype + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig" + ) + @register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig) def _float8_dynamic_activation_float8_semi_sparse_weight_transform( @@ -1735,16 +1959,19 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): granularity: Optional[ Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]] ] = None - mm_config: Optional[Float8MMConfig] = None + mm_config: Optional[Float8MMConfig] = Float8MMConfig(use_fast_accum=True) set_inductor_config: bool = True def __post_init__(self): - if self.mm_config is None: - self.mm_config = Float8MMConfig(use_fast_accum=True) + torch._C._log_api_usage_once( + "torchao.quantization.Float8StaticActivationFloat8WeightConfig" + ) # for bc -float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig +float8_static_activation_float8_weight = _ConfigDeprecationWrapper( + "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig +) @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) @@ -1822,9 +2049,14 @@ class UIntXWeightOnlyConfig(AOBaseConfig): use_hqq: bool = False set_inductor_config: bool = True + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig") + # for BC -uintx_weight_only = UIntXWeightOnlyConfig +uintx_weight_only = _ConfigDeprecationWrapper( + "uintx_weight_only", UIntXWeightOnlyConfig +) @register_quantize_module_handler(UIntXWeightOnlyConfig) @@ -1860,7 +2092,7 @@ def _uintx_weight_only_transform( if use_hqq: if dtype == torch.uint4: logger.warning( - "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance" + "Recommended to use `Int4WeightOnlyConfig(group_size, use_hqq=True, version=1)` for the best performance" ) quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] dtype = torch.uint8 @@ -1896,6 +2128,31 @@ def _uintx_weight_only_transform( return module +def _adjust_scale_dtype_in_intx_unpacked_tensor( + intx_unpacked_tensor: IntxUnpackedToInt8Tensor, + hp_tensor: torch.Tensor, + scale_dtype: torch.dtype, +) -> None: + """ + Adjusts the scale_dtype on IntxUnpackedToInt8Tensor. + Updating the scale dtype requires updating the qdata because qdata is calculated after the scale. + This is used in IntxWeightOnlyConfig and Int8DynamicActivationIntxWeightConfig to make + version=2 and version=1 numerically equivalent when the scale_dtype differs from the input dtype + """ + assert isinstance(intx_unpacked_tensor, IntxUnpackedToInt8Tensor) + intx_unpacked_tensor.scale = intx_unpacked_tensor.scale.to(scale_dtype) + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[intx_unpacked_tensor.target_dtype] + intx_unpacked_tensor.qdata = quantize_affine( + hp_tensor, + intx_unpacked_tensor.block_size, + intx_unpacked_tensor.scale, + intx_unpacked_tensor.zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + + @dataclass class IntxWeightOnlyConfig(AOBaseConfig): """ @@ -1903,15 +2160,23 @@ class IntxWeightOnlyConfig(AOBaseConfig): Weights are quantized with scales/zeros in a groupwise or channelwise manner using the number of bits specified by weight_dtype. args: - weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. - torch.intx with x < 8 requires TORCH_VERSION_AT_LEAST_2_6 - granularity: The granularity to use for weight quantization. Must be PerGroup or PerAxis(0). - mapping_type: The type of mapping to use for the weight quantization. + `weight_dtype`: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. + `granularity`: The granularity to use for weight quantization. Must be PerGroup or PerAxis(0). + `mapping_type`: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. - scale_dtype: The dtype to use for the weight scale. - layout: The layout to use for the packed weight tensor: + `scale_dtype`: The dtype to use for the weight scale. + `layout`: The layout to use for the packed weight tensor: - QDQLayout: this layout is designed for export to ExecuTorch.this layout represents the quantization with Q/DQ quant primitives, and is intended for export applications like ExecuTorch. + `intx_packing_format`: The format to use for the packed weight tensor (version 2 only). + `version`: version of the config to use, only subset of above args are valid based on version, see note for more details. + + Note: + + Current state for IntxWeightOnlyConfig is that it supports both v1 (legacy) and v2. + + * `intx_packing_format` is used for version 2. + * `layout` is only used for version 1. """ weight_dtype: torch.dtype = torch.int8 @@ -1919,9 +2184,11 @@ class IntxWeightOnlyConfig(AOBaseConfig): mapping_type: MappingType = MappingType.SYMMETRIC scale_dtype: Optional[torch.dtype] = None layout: Layout = QDQLayout() + intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8 + version: int = 2 def __post_init__(self): - assert TORCH_VERSION_AT_LEAST_2_6, "IntxWeightOnlyConfig requires torch 2.6+" + torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig") assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" ) @@ -1932,21 +2199,27 @@ def __post_init__(self): assert self.granularity.axis == 0, ( f"axis must be 0 with PerAxis, but got {self.granularity.axis}" ) - assert self.mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], ( + assert self.mapping_type in [ + MappingType.ASYMMETRIC, + MappingType.SYMMETRIC, + ], ( f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" ) -@register_quantize_module_handler(IntxWeightOnlyConfig) -def _intx_weight_only_transform( - module: torch.nn.Module, config: IntxWeightOnlyConfig -) -> torch.nn.Module: - weight = module.weight +def _intx_weight_only_quantize_tensor( + weight, + config, + *, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, +): weight_dtype = config.weight_dtype granularity = config.granularity mapping_type = config.mapping_type scale_dtype = config.scale_dtype layout = config.layout + intx_packing_format = config.intx_packing_format assert weight.dim() == 2, ( f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" @@ -1961,11 +2234,39 @@ def _intx_weight_only_transform( else: raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}") + block_size = (1, group_size) + + if config.version == 2: + if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8: + if custom_zero_point is not None and custom_zero_point.dtype == torch.int32: + custom_zero_point = custom_zero_point.to(torch.int8) + new_weight = IntxUnpackedToInt8Tensor.from_hp( + weight, + block_size, + weight_dtype, + mapping_type=mapping_type, + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, + ) + if scale_dtype is not None and scale_dtype != weight.dtype: + _adjust_scale_dtype_in_intx_unpacked_tensor( + new_weight, weight, scale_dtype + ) + + return new_weight + else: + raise ValueError(f"Unsupported packing format: {intx_packing_format}") + + # Version 1 + assert config.version == 1 + warnings.warn( + "Config Deprecation: version 1 of IntxWeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2967 for more details" + ) quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype] weight = to_affine_quantized_intx( input_float=weight, mapping_type=mapping_type, - block_size=(1, group_size), + block_size=block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, @@ -1975,7 +2276,34 @@ def _intx_weight_only_transform( zero_point_domain=ZeroPointDomain.INT, _layout=layout, ) - module.weight = torch.nn.Parameter(weight, requires_grad=False) + return weight + + +@register_quantize_module_handler(IntxWeightOnlyConfig) +def _intx_weight_only_transform( + module: torch.nn.Module, + config: IntxWeightOnlyConfig, + *, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, +) -> torch.nn.Module: + assert hasattr(module, "weight"), ( + "applying intx weight only quant requires module to have weight attribute" + + " but {module} does not have one" + ) + new_weight = _intx_weight_only_quantize_tensor( + module.weight, + config, + custom_scale=custom_scale, + custom_zero_point=custom_zero_point, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + elif isinstance(module, nn.Embedding): + module.extra_repr = types.MethodType(_embedding_extra_repr, module) + return module @@ -1995,9 +2323,12 @@ class FPXWeightOnlyConfig(AOBaseConfig): mbits: int set_inductor_config: bool = True + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig") + # for BC -fpx_weight_only = FPXWeightOnlyConfig +fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig) @register_quantize_module_handler(FPXWeightOnlyConfig) @@ -2030,73 +2361,6 @@ def _fpx_weight_only_transform( return module -@dataclass -class FbgemmConfig(AOBaseConfig): - """Quantization Config for fbgemm-genai kernels - Args: - input_dtype (torch.dtype): input dtype of the kernel - weight_dtype (torch.dtype): weight dtype of the kernel - output_dtype (torch.dtype): output dtype of the kernel - group_size (int): The group size for weight - """ - - input_dtype: torch.dtype - weight_dtype: torch.dtype - output_dtype: torch.dtype - block_size: Optional[List[int]] = None - activation_scale_ub: Optional[float] = None - transpose_input: bool = False - - -@register_quantize_module_handler(FbgemmConfig) -def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: - # TODO: use is_package_at_least("fbgemm_gpu", "1.2.0") when - # https://github.com/pytorch/FBGEMM/issues/4198 is fixed - if importlib.util.find_spec("fbgemm_gpu") is None: - raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - - import fbgemm_gpu.experimental.gen_ai # noqa: F401 - - if not is_fbcode() and fbgemm_gpu.__version__ < "1.2.0": - raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - - _SUPPORTED_DTYPES = { - (torch.bfloat16, torch.int4, torch.bfloat16), - (e4m3_dtype, e4m3_dtype, torch.bfloat16), - } - - if ( - (config.input_dtype == torch.bfloat16) - and (config.weight_dtype == torch.int4) - and (config.output_dtype == torch.bfloat16) - ): - weight = to_fbgemm_int4( - module.weight, - config.block_size, - config.transpose_input, - ) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - elif ( - (config.input_dtype == e4m3_dtype) - and (config.weight_dtype == e4m3_dtype) - and (config.output_dtype == torch.bfloat16) - ): - weight = to_fbgemm_fp8( - module.weight, - config.activation_scale_ub, - config.transpose_input, - ) - module.weight = torch.nn.Parameter(weight, requires_grad=False) - module.extra_repr = types.MethodType(_linear_extra_repr, module) - return module - else: - raise NotImplementedError( - f"{config} is not supported. supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}" - ) - - @dataclass class ModuleFqnToConfig(AOBaseConfig): """Per module configurations for torchao quantize_ API @@ -2113,6 +2377,9 @@ class ModuleFqnToConfig(AOBaseConfig): default_factory=dict ) + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig") + def _module_fqn_to_config_handler( module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig @@ -2132,16 +2399,15 @@ def _module_fqn_to_config_handler( return module -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals( - [ - _int8_asymm_per_token_quant, - _int8_symm_per_token_reduced_range_quant, - _input_activation_quant_func_fp8, - _int4_symm_cutlass_quant, - _int8_symm_cutlass_quant, - _float8_cutlass_quant, - _float8_cutlass_quant_sparse, - Target, - ] - ) +torch.serialization.add_safe_globals( + [ + _int8_asymm_per_token_quant, + _int8_symm_per_token_reduced_range_quant, + _input_activation_quant_func_fp8, + _int4_symm_cutlass_quant, + _int8_symm_cutlass_quant, + _float8_cutlass_quant, + _float8_cutlass_quant_sparse, + Target, + ] +) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 56e8422197..cdfbc00c3a 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -16,10 +16,8 @@ _n_ones, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, _register_custom_op, + _register_meta_op, ) __all__ = [ @@ -35,7 +33,7 @@ "_choose_qparams_affine_floatx", "_choose_qparams_and_quantize_affine_hqq", "_choose_qparams_and_quantize_affine_qqq", - "_choose_qparams_affine_float8", + "_choose_scale_float8", "_choose_qparams_gguf", "_quantize_affine_no_zero_point", "_quantize_affine_tinygemm", @@ -106,8 +104,7 @@ class TorchAODType(Enum): INT7 = auto() -if TORCH_VERSION_AT_LEAST_2_5: - torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) +torch.serialization.add_safe_globals([MappingType, ZeroPointDomain]) FP8_TYPES = { torch.float8_e4m3fn, @@ -151,53 +148,49 @@ class TorchAODType(Enum): TorchAODType.INT7: (-(2**6), 2**6 - 1), } -# torch.uintX available only in PyTorch 2.3+ -if TORCH_VERSION_AT_LEAST_2_3: - _SUB_BYTE_UINT_BOUNDS = { - torch.uint1: (0, 2**1 - 1), - torch.uint2: (0, 2**2 - 1), - torch.uint3: (0, 2**3 - 1), - torch.uint4: (0, 2**4 - 1), - torch.uint5: (0, 2**5 - 1), - torch.uint6: (0, 2**6 - 1), - torch.uint7: (0, 2**7 - 1), +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} +_DTYPE_TO_BIT_WIDTH.update( + { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, } - _DTYPE_TO_BIT_WIDTH.update( - { - torch.uint1: 1, - torch.uint2: 2, - torch.uint3: 3, - torch.uint4: 4, - torch.uint5: 5, - torch.uint6: 6, - torch.uint7: 7, - } - ) - -# torch.intX available only in PyTorch 2.6+ -if TORCH_VERSION_AT_LEAST_2_6: - _SUB_BYTE_INT_BOUNDS.update( - { - torch.int1: (-(2**0), 2**0 - 1), - torch.int2: (-(2**1), 2**1 - 1), - torch.int3: (-(2**2), 2**2 - 1), - torch.int4: (-(2**3), 2**3 - 1), - torch.int5: (-(2**4), 2**4 - 1), - torch.int6: (-(2**5), 2**5 - 1), - torch.int7: (-(2**6), 2**6 - 1), - } - ) - _DTYPE_TO_BIT_WIDTH.update( - { - torch.int1: 1, - torch.int2: 2, - torch.int3: 3, - torch.int4: 4, - torch.int5: 5, - torch.int6: 6, - torch.int7: 7, - } - ) +) + +_SUB_BYTE_INT_BOUNDS.update( + { + torch.int1: (-(2**0), 2**0 - 1), + torch.int2: (-(2**1), 2**1 - 1), + torch.int3: (-(2**2), 2**2 - 1), + torch.int4: (-(2**3), 2**3 - 1), + torch.int5: (-(2**4), 2**4 - 1), + torch.int6: (-(2**5), 2**5 - 1), + torch.int7: (-(2**6), 2**6 - 1), + } +) +_DTYPE_TO_BIT_WIDTH.update( + { + torch.int1: 1, + torch.int2: 2, + torch.int3: 3, + torch.int4: 4, + torch.int5: 5, + torch.int6: 6, + torch.int7: 7, + } +) _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) _DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS) @@ -226,6 +219,20 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor: return gy +class _RoundToFloat8(torch.autograd.Function): + """ + Implementation of `tensor.to(float8_dtype)` with backward STE. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor: + return x.to(float8_dtype) + + @staticmethod + def backward(ctx, gy: torch.Tensor) -> torch.Tensor: + return gy, None + + # TODO: decide on if we want to allow custom quant_min/quant_max here def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): """Get quant_min and quant_max args based on dtype and also verify bounds. @@ -1172,7 +1179,7 @@ def _do_fake_quantize_affine( elif zero_point_domain == ZeroPointDomain.FLOAT: _quantize_affine = _quantize_affine_tinygemm_no_dtype_cast _dequantize_affine = _dequantize_affine_tinygemm_no_dtype_check - elif ZeroPointDomain == ZeroPointDomain.NONE: + elif zero_point_domain == ZeroPointDomain.NONE: _quantize_affine = _quantize_affine_no_zero_point_no_dtype_cast _dequantize_affine = _dequantize_affine_no_zero_point_no_dtype_check else: @@ -2178,25 +2185,32 @@ def _dequantize_affine_floatx( return tensor -def _choose_qparams_affine_float8( +@register_custom_op +def _choose_scale_float8( tensor: torch.Tensor, + block_size: List[int], float8_dtype: torch.dtype = torch.float8_e4m3fn, scale_dtype: torch.dtype = torch.float32, - block_size: Optional[Tuple[int, ...]] = None, + hp_value_lb: Optional[float] = None, + hp_value_ub: Optional[float] = None, ) -> torch.Tensor: """ - Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. + Calculates float8 scaling factor for the given high precision tensor. Args: tensor (torch.Tensor): Input tensor to be quantized. float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). scale_dtype (torch.dtype): Data type of the scaling factor (e.g., torch.float32). block_size (Optional[Tuple[int, ...]]): Block size for block-wise quantization. If None, tensorwise quantization is used. + hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale + hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale """ quant_max = torch.finfo(float8_dtype).max - # only tensorwise scaling is supported for now: - if block_size is None: + if len(block_size) == 0: + # tensorwise max_abs = tensor.abs().max() + if hp_value_lb is not None or hp_value_ub is not None: + max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub) scale = max_abs / quant_max else: shape_for_reduction, reduction_dims = _get_reduction_params( @@ -2204,7 +2218,8 @@ def _choose_qparams_affine_float8( ) tensor_reshaped = tensor.view(shape_for_reduction) max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True) - + if hp_value_lb is not None or hp_value_ub is not None: + max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub) scale = max_abs / quant_max # Reshape scale back to match the expected output shape # The scale tensor should have the same shape as the input divided by block_size @@ -2220,11 +2235,12 @@ def _choose_qparams_affine_float8( return scale.to(dtype=torch.float32) -def _expand_scale_to_tensor_shape( +def _maybe_expand_scale_to_tensor_shape( scale: torch.Tensor, target_shape: torch.Size ) -> torch.Tensor: """ Expand a scale tensor to match the target tensor shape for block-wise quantization. + If this is rowwise quantization, however, just return the scale as is. Args: scale (torch.Tensor): Scale tensor with shape corresponding to block structure @@ -2241,6 +2257,11 @@ def _expand_scale_to_tensor_shape( # Scalar scale - can broadcast naturally return scale + # If the scale can be broadcast as is, then we don't need to expand it + # E.g. for rowwise quantization, scale = [256, 1] and target_shape = [256, 512] + if all(a == b or a == 1 for a, b in zip(scale.shape, target_shape)): + return scale + # Calculate block sizes from shape difference if len(scale.shape) != len(target_shape): raise ValueError( @@ -2270,7 +2291,6 @@ def _expand_scale_to_tensor_shape( return expanded_scale -@_register_custom_op(quant_lib, False) def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, @@ -2282,16 +2302,48 @@ def _quantize_affine_float8( tensor_fp32 = tensor.to(torch.float32) # Expand scale to match tensor dimensions for block-wise quantization - scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape) + scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape) tensor_scaled = tensor_fp32 / scale_expanded max_value = torch.finfo(float8_dtype).max tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) - fp8_tensor = tensor_clamped.to(float8_dtype) - return fp8_tensor + return _RoundToFloat8.apply(tensor_clamped, float8_dtype) -@torch.library.impl(quant_lib, "quantize_affine_float8", "Meta") +def _dequantize_affine_float8( + tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Dequantizes the float8 tensor to high precision tensor. + """ + fp8_tensor = tensor.to(torch.float32) + + # Expand scale to match tensor dimensions for block-wise quantization + scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape) + + hp_tensor = fp8_tensor * scale_expanded + return hp_tensor.to(output_dtype) + + +@_register_custom_op(quant_lib, False) +def _quantize_affine_float8_non_decomposed( + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype: torch.dtype = torch.float8_e4m3fn, +) -> torch.Tensor: + """ + Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. + """ + return _quantize_affine_float8( + tensor=tensor, + scale=scale, + float8_dtype=float8_dtype, + ) + + +@_register_meta_op(quant_lib, "quantize_affine_float8_non_decomposed") def _quantize_affine_float8_meta( tensor: torch.Tensor, scale: torch.Tensor, @@ -2301,7 +2353,7 @@ def _quantize_affine_float8_meta( @_register_custom_op(quant_lib, False) -def _dequantize_affine_float8( +def _dequantize_affine_float8_non_decomposed( tensor: torch.Tensor, scale: torch.Tensor, output_dtype: torch.dtype = torch.float32, @@ -2309,16 +2361,14 @@ def _dequantize_affine_float8( """ Dequantizes the float8 tensor to high precision tensor. """ - fp8_tensor = tensor.to(torch.float32) - - # Expand scale to match tensor dimensions for block-wise quantization - scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape) - - hp_tensor = fp8_tensor * scale_expanded - return hp_tensor.to(output_dtype) + return _dequantize_affine_float8( + tensor=tensor, + scale=scale, + output_dtype=output_dtype, + ) -@torch.library.impl(quant_lib, "dequantize_affine_float8", "Meta") +@_register_meta_op(quant_lib, "dequantize_affine_float8_non_decomposed") def _dequantize_affine_float8_meta( tensor: torch.Tensor, scale: torch.Tensor, diff --git a/torchao/quantization/quantize_/__init__.py b/torchao/quantization/quantize_/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/quantize_/common/__init__.py b/torchao/quantization/quantize_/common/__init__.py new file mode 100644 index 0000000000..19f6e26807 --- /dev/null +++ b/torchao/quantization/quantize_/common/__init__.py @@ -0,0 +1,15 @@ +from .kernel_preference import KernelPreference +from .packing_format import PackingFormat +from .protocol import SupportsActivationPreScaling +from .quantize_tensor_kwargs import ( + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) + +__all__ = [ + "QuantizeTensorKwargs", + "KernelPreference", + "PackingFormat", + "SupportsActivationPreScaling", + "_choose_quant_func_and_quantize_tensor", +] diff --git a/torchao/quantization/quantize_/common/kernel_preference.py b/torchao/quantization/quantize_/common/kernel_preference.py new file mode 100644 index 0000000000..8f53f55c6a --- /dev/null +++ b/torchao/quantization/quantize_/common/kernel_preference.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + +import torch + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class KernelPreference(str, Enum): + """Enum for specifying the groups of kernels that's used for quantization, matrix multiplication + or other compute ops for quantized tensor + + Examples of how options affects the selected kernels can be found in tensor subclass implementations under torchao/quantization/quantize_/workflows + """ + + """Use the most efficient quantize and mm kernels chosen for user based on hardware and library availabilities and versions etc. + """ + AUTO = "auto" + + """Use torch native quantize and quantized mm kernels + """ + TORCH = "torch" + + """Use quantize and quantized mm kernels from fbgemm_gpu_genai library, requires fbgemm_gpu_genai library + """ + FBGEMM = "fbgemm" + + +torch.serialization.add_safe_globals([KernelPreference]) diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py new file mode 100644 index 0000000000..c6546c55f9 --- /dev/null +++ b/torchao/quantization/quantize_/common/packing_format.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class PackingFormat(str, Enum): + """Packing format for quantized data in Tensor subclasses in torchao, represents how + the values are packed and laid out in the quantized data. + """ + + """ + plain means the format that quantized Tensor data lays out elements in Tensor sequentially, + for example: for a Tensor of shape (4, 6): + a_0_0, a_0_1, ..., a_0_5, + ... + a_3_0, a_3_1, ..., a_3_5 + + Note that it's different for different dtypes, for example for int4, we will + pack two adjacent int4 elements into one uint8/int8 value for plain packing format + """ + PLAIN = "plain" + + """ + Opaque packing format that's used for tensors that does not have a predefined packing format + (that may be decided on hardware, tensor shape, library availability etc.) and it's not + needed for the rest of the system to understand the specific format that's adopted. + """ + OPAQUE = "opaque" diff --git a/torchao/quantization/quantize_/common/protocol.py b/torchao/quantization/quantize_/common/protocol.py new file mode 100644 index 0000000000..2266dc7e25 --- /dev/null +++ b/torchao/quantization/quantize_/common/protocol.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +"""Protocols for some functionalities in tensor subclasses""" + +from typing import Optional, Protocol, runtime_checkable + +import torch + + +@runtime_checkable +class SupportsActivationPreScaling(Protocol): + """Protocol for activation scale that should be multiplied with activation before quantization, + or before we use activation in matrix multiplications, used for algorithms like AWQ + + A class that have `act_pre_scale: Optional[torch.Tensor]` attribute implements the Protocol + """ + + act_pre_scale: Optional[torch.Tensor] diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py new file mode 100644 index 0000000000..0adc8c786d --- /dev/null +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import abc +from typing import ClassVar + +import torch + +__all__ = [ + "QuantizeTensorKwargs", + "_choose_quant_func_and_quantize_tensor", +] + + +class QuantizeTensorKwargs(abc.ABC): + """Base class for keyword argument container for quantized tensor creation. This is needed to support storing activation construction arguments on the weight tensor while supporting multiple types of activation quantization. + + e.g. + + class Float8Tensor(...) + @classmethod + def from_hp(cls, tensor, quant_kwargs: QuantizeTensorKwargs) + ... + """ + + # Base Version of a config + VERSION: ClassVar[int] = 1 + + +def _choose_quant_func_and_quantize_tensor( + tensor: torch.Tensor, quant_kwargs: QuantizeTensorKwargs +) -> torch.Tensor: + """Given a tensor and a kwargs container, chooses a derived dtype (float8, int8, etc) to quantize tensor to, based on the type of quant_kwargs + quantizes tensor to the derived dtype chosen in (1) + This is needed to support flexible quantization of activation to various derived dtypes. + """ + from torchao.quantization.quantize_.workflows import ( + Float8Tensor, + QuantizeTensorToFloat8Kwargs, + ) + + if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): + return Float8Tensor.from_hp( + tensor, + quant_kwargs.float8_dtype, + quant_kwargs.granularity, + quant_kwargs.mm_config, + quant_kwargs.hp_value_lb, + quant_kwargs.hp_value_ub, + quant_kwargs.kernel_preference, + ) + + raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py new file mode 100644 index 0000000000..229c94c73a --- /dev/null +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -0,0 +1,47 @@ +from .float8.float8_tensor import ( + Float8Tensor, + QuantizeTensorToFloat8Kwargs, +) +from .int4.int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm +from .int4.int4_marlin_sparse_tensor import ( + Int4MarlinSparseTensor, +) +from .int4.int4_opaque_tensor import ( + Int4OpaqueTensor, +) +from .int4.int4_packing_format import Int4PackingFormat +from .int4.int4_plain_int32_tensor import ( + Int4PlainInt32Tensor, +) +from .int4.int4_preshuffled_tensor import ( + Int4PreshuffledTensor, +) +from .int4.int4_tensor import ( + Int4Tensor, +) +from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor +from .intx.intx_opaque_tensor import ( + IntxOpaqueTensor, +) +from .intx.intx_packing_format import ( + IntxPackingFormat, +) +from .intx.intx_unpacked_to_int8_tensor import ( + IntxUnpackedToInt8Tensor, +) + +__all__ = [ + "Int4Tensor", + "Int4PreshuffledTensor", + "Int4MarlinSparseTensor", + "Int4PlainInt32Tensor", + "Int4TilePackedTo4dTensor", + "Float8Tensor", + "QuantizeTensorToFloat8Kwargs", + "Int4OpaqueTensor", + "Int4ChooseQParamsAlgorithm", + "Int4PackingFormat", + "IntxPackingFormat", + "IntxUnpackedToInt8Tensor", + "IntxOpaqueTensor", +] diff --git a/torchao/quantization/quantize_/workflows/float8/__init__.py b/torchao/quantization/quantize_/workflows/float8/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py new file mode 100644 index 0000000000..bcae1fc756 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -0,0 +1,623 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import List, Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.utils import get_out_shape +from torchao.float8.inference import ( + Float8MMConfig, + FP8Granularity, + _is_rowwise_scaled, + _is_tensorwise_scaled, + _slice_scale_for_dimension, + addmm_float8_unwrapped_inference, + preprocess_data, + preprocess_scale, +) +from torchao.quantization.granularity import PerRow, PerTensor +from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _dequantize_affine_float8, + _quantize_affine_float8, +) +from torchao.quantization.quantize_.common import ( + KernelPreference, + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) +from torchao.quantization.utils import get_block_size +from torchao.utils import ( + TorchAOBaseTensor, + _is_fbgemm_genai_gpu_available, + fill_defaults, + is_sm_at_least_90, +) + +__all__ = [ + "Float8Tensor", + "QuantizeTensorToFloat8Kwargs", +] + +aten = torch.ops.aten + + +@dataclass +class QuantizeTensorToFloat8Kwargs(QuantizeTensorKwargs): + """Tensor kwargs for creating float8 tensor (either activation or weight) + + Args: + dtype (torch.dtype): the dtype for float8 Tensor + granularity (FP8Granularity): the granularity for the Tensor, currently either PerRow() or PerTensor() + mm_config (Float8MMConfig): Configuration for the scaled_mm in the forward and backward pass. + hp_value_lb (Optional[float]): the lower bound for high precision floating point value for calculating scale + hp_value_ub (Optional[float]): the upper bound for high precision floating point value for calculating scale + kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. by defalut (None) it will be chosen for user based on hardware or other information + """ + + float8_dtype: torch.dtype = torch.float8_e4m3fn + granularity: FP8Granularity = PerRow() + mm_config: Optional[Float8MMConfig] = None + hp_value_lb: Optional[float] = None + hp_value_ub: Optional[float] = None + kernel_preference: KernelPreference = KernelPreference.AUTO + + +class Float8Tensor(TorchAOBaseTensor): + """ + Float8 Quantized (weight) Tensor, with float8 dynamic quantization for activation or bfloat16 activation. + + TODO: needs padding for cutlass kernels + + Tensor Attributes: + qdata: float8 raw data + scale: the scale for float8 Tensor + + Non-Tensor Attributes: + block_size (List[int]): the block size for float8 quantization, meaning the shape of the elements + sharing the same set of quantization parameters (scale), have the same rank as qdata or + is an empty list (representing per tensor quantization) + mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + act_quant_kwargs (QuantizeTensorToFloat8Kwargs): the kwargs for Float8Tensor.from_hp + kernel_preference (KernelPreference): the preference for quantize, mm etc. kernel to use, + by default, this will be chosen for user based on hardware, library availabilities etc. + dtype: Original Tensor dtype + """ + + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = [] + optional_tensor_attribute_names = [ + "block_size", + "mm_config", + "act_quant_kwargs", + "kernel_preference", + "dtype", + ] + + def __new__( + cls, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: Optional[List[int]] = None, + mm_config: Optional[Float8MMConfig] = None, + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, + kernel_preference: KernelPreference = KernelPreference.AUTO, + dtype: Optional[torch.dtype] = None, + ): + shape = qdata.shape + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: Optional[List[int]] = None, + mm_config: Optional[Float8MMConfig] = None, + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, + kernel_preference: KernelPreference = KernelPreference.AUTO, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.block_size = block_size + self.mm_config = mm_config + self.act_quant_kwargs = act_quant_kwargs + self.kernel_preference = kernel_preference + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " + f"{self.block_size=}, {self.mm_config=}, {self.kernel_preference=} " + f"{self.shape=}, {self.device=}, {self.dtype=})" + ) + + def _quantization_type(self): + return f"{self.act_quant_kwargs=}, {self.block_size=}, {self.mm_config=}, {self.scale.shape=}, {self.kernel_preference=}" + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + qdata, scale = self.qdata, self.scale + return _dequantize_affine_float8(qdata, scale, output_dtype) + + @classmethod + def from_hp( + cls, + hp_tensor: torch.Tensor, + float8_dtype: torch.dtype = torch.float8_e4m3fn, + granularity: FP8Granularity = PerRow(), + mm_config: Optional[Float8MMConfig] = None, + hp_value_lb: Optional[float] = None, + hp_value_ub: Optional[float] = None, + kernel_preference: KernelPreference = KernelPreference.AUTO, + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, + ): + block_size = get_block_size(hp_tensor.shape, granularity) + block_size = list(block_size) + + kernel_choice = None + if ( + kernel_preference == KernelPreference.AUTO + and _is_fbgemm_genai_gpu_available() + and is_sm_at_least_90() + and isinstance(granularity, PerRow) + and float8_dtype == torch.float8_e4m3fn + and hp_value_lb is None + ): + # if kernel_preference is AUTO and per row quantization + # we'll use fbgemm quantize kernel for best performance + kernel_choice = "fbgemm" + elif kernel_preference == KernelPreference.FBGEMM: + # if user explicitly chose FBGEMM kernel preference, we'll also use fbgemm kernel + assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), ( + "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (>= H100)" + ) + assert hp_value_lb is None, ( + "hp_value_lb should not be specified if with KerenelPreference.FBGEMM" + ) + kernel_choice = "fbgemm" + else: + # fallback quantize kernel for everything else will be torch + kernel_choice = "torch" + + if kernel_choice == "fbgemm": + assert hp_value_lb is None, f"{hp_value_lb=} is not supported" + if hp_value_ub is not None: + maybe_hp_value_ub_tensor = torch.tensor( + hp_value_ub, dtype=torch.float, device=hp_tensor.device + ) + else: + maybe_hp_value_ub_tensor = None + if isinstance(granularity, PerRow): + data, scale = torch.ops.triton.quantize_fp8_row( + hp_tensor, scale_ub=maybe_hp_value_ub_tensor + ) + scale_shape = [] + for i in range(hp_tensor.ndim): + scale_shape.append(hp_tensor.shape[i] // block_size[i]) + scale = scale.reshape(*scale_shape) + else: + assert isinstance(granularity, PerTensor), ( + f"Expected per tensor, got {granularity}" + ) + # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered + # TODO: enable after this is working + # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor( + # hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor + # ) + raise NotImplementedError( + "Currently KernelPreference.FBGEMM does not work for per tensor float8 quant" + ) + else: + assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}" + scale = _choose_scale_float8( + hp_tensor, + float8_dtype=float8_dtype, + block_size=block_size, + hp_value_lb=hp_value_lb, + hp_value_ub=hp_value_ub, + ) + data = _quantize_affine_float8(hp_tensor, scale, float8_dtype) + + hp_dtype = hp_tensor.dtype + return Float8Tensor( + data, + scale, + block_size=block_size, + mm_config=mm_config, + act_quant_kwargs=act_quant_kwargs, + kernel_preference=kernel_preference, + dtype=hp_dtype, + ) + + +implements = Float8Tensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert isinstance(weight_tensor, Float8Tensor), ( + f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}" + ) + + act_quant_kwargs = weight_tensor.act_quant_kwargs + # quantizing activation, if `act_quant_kwargs` is specified + if act_quant_kwargs is not None: + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, act_quant_kwargs + ) + + if isinstance(input_tensor, Float8Tensor): + kernel_choice = None + + if weight_tensor.kernel_preference == KernelPreference.AUTO: + kernel_choice = "torch" + if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(): + kernel_choice = "fbgemm" + elif weight_tensor.kernel_preference == KernelPreference.FBGEMM: + kernel_choice = "fbgemm" + else: + assert weight_tensor.kernel_preference == KernelPreference.TORCH, ( + f"{weight_tensor.kernel_preference=} not handled" + ) + kernel_choice = "torch" + + if kernel_choice == "fbgemm": + assert _is_fbgemm_genai_gpu_available(), ( + "Expected fbgemm_gpu_genai package to be installed" + ) + assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai" + mm_config = weight_tensor.mm_config + assert mm_config is not None + + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1]) + wq = weight_tensor.qdata + x_scale = input_tensor.scale + w_scale = weight_tensor.scale + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled(input_tensor), ( + "Input tensor must be rowwise block size" + ) + res = torch.ops.fbgemm.f8f8bf16_rowwise( + xq, + wq, + x_scale, + w_scale, + bias=bias, + use_fast_accum=mm_config.use_fast_accum, + ).reshape(out_shape) + else: + assert _is_tensorwise_scaled(weight_tensor) + assert _is_tensorwise_scaled(input_tensor) + res = torch.ops.fbgemm.f8f8bf16( + xq, + wq, + x_scale * w_scale, + use_fast_accum=mm_config.use_fast_accum, + ).reshape(out_shape) + if bias is not None: + res = res + bias + return res + else: + assert kernel_choice == "torch" + scaled_mm_config = weight_tensor.mm_config + assert scaled_mm_config is not None + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + # Extract tensor data and scales + inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1]) + w_data = weight_tensor.qdata + input_scale = input_tensor.scale + w_scale = weight_tensor.scale + + # Handle rowwise scaling + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled(input_tensor), ( + "Input tensor must be rowwise block size" + ) + w_scale = w_scale.transpose(-1, -2) + + input_scale = preprocess_scale(input_scale, input_tensor.shape) + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + else: + assert not isinstance(input_tensor, TorchAOBaseTensor), ( + "Expecting input_tensor to be unquantized" + ) + # when input is not `Float8Tensor`, we expect that it is not quantized + # so this is float8 weight only quantization + return torch.nn.functional.linear( + input_tensor, weight_tensor.dequantize(), bias + ) + + +@implements(torch.bmm) +def _(func, types, args, kwargs): + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + assert isinstance(weight_tensor, Float8Tensor), ( + f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}" + ) + + kernel_preference = weight_tensor.kernel_preference + assert kernel_preference != KernelPreference.TORCH, "bmm is not supported for TORCH" + assert _is_fbgemm_genai_gpu_available(), ( + "bmm is not supported when fbgemm_gpu_genai is not installed" + ) + + orig_act_size = input_tensor.size() + act_quant_kwargs = weight_tensor.act_quant_kwargs + if act_quant_kwargs is not None: + input_tensor = _choose_quant_func_and_quantize_tensor( + input_tensor, act_quant_kwargs + ) + + if isinstance(input_tensor, Float8Tensor): + a_data = input_tensor.qdata + a_scale = input_tensor.scale + + b_data = weight_tensor.qdata + b_scale = weight_tensor.scale.squeeze(-1) + assert b_data.is_contiguous(), "weight for bmm must be contiguous" + + assert ( + all(x == 1 for x in weight_tensor.block_size[:-1]) + and weight_tensor.block_size[-1] == weight_tensor.shape[-1] + ), "bmm only works for per row weight quantization" + assert ( + all(x == 1 for x in input_tensor.block_size[:-1]) + and input_tensor.block_size[-1] == input_tensor.shape[-1] + ), "bmm only works for per row activation quantization" + + orig_out_features = b_data.shape[-2] + + res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( + a_data, + b_data, + a_scale, + b_scale, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + else: + raise NotImplementedError( + "bmm only support float8 dynamic activation + float8 weight" + ) + + return res + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Only supports slicing for dim == 1 and dim == 2 + original tensor shape has dimension (N, K) + qdata has dimension (N, K) + scale (per row quantization) has dimension: (N,) + + since qdata has the same dimension as original tensor, we can directly slice that + for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1 + + Note that we need to call slice on the qdata and scale directly because slice + is an operation that need to preserve aliasing + """ + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + + assert self.qdata.ndim == 2, ( + f"Expected packed weight to have dim 2, got {self.qdata.dim}" + ) + + # Always slice the qdata + sliced_data = aten.slice.Tensor(self.qdata, dim, start, end, step) + + if self.scale.numel() == 1: + # Per-tensor quantization - scale doesn't change + sliced_scale = self.scale + else: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, start, end, step + ) + + # adjust block_size since the shape has changed, block_size[i] should not be greater than shape[i] + block_size = self.block_size.copy() + for i in range(len(self.block_size)): + block_size[i] = min(block_size[i], sliced_data.shape[i]) + + return return_and_correct_aliasing( + func, + args, + kwargs, + Float8Tensor( + sliced_data, + sliced_scale, + block_size, + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + dtype=self.dtype, + ), + ) + + +@implements(aten.cat.default) +def _(func, types, args, kwargs): + """Concatenate multiple float8 quantized tensors + (scale and qdata has the same rank) + If the concatenation dimension is not the same as block_size, then we can just concatenate the + qdata and scale directly + If the concatention dimension is the same as block_size, theoretically we should either + (1) check that scales from all tensors are equal and use the first scale + (2) dequantize and requantize + but for now we just use the first scale directly, which might have slight implication on accuaracy + we can improve upon this a bit later + """ + + tensors, dim = fill_defaults(args, 2, [[], 0]) + tensor_0 = tensors[0] + dim = dim % tensor_0.ndim + + for i in range(1, len(tensors)): + assert tensor_0.qdata.ndim == tensors[i].qdata.ndim + assert tensor_0.scale.ndim == tensors[i].scale.ndim + assert tensor_0.block_size == tensors[i].block_size + assert tensor_0.mm_config == tensors[i].mm_config + assert tensor_0.act_quant_kwargs == tensors[i].act_quant_kwargs + assert tensor_0.kernel_preference == tensors[i].kernel_preference + + qdatas = [t.qdata for t in tensors] + scales = [t.scale for t in tensors] + + cat_qdata = aten.cat.default(qdatas, dim=dim) + if tensor_0.block_size[dim] == 1: + cat_scale = aten.cat.default(scales, dim=dim) + else: + for i in range(1, len(tensors)): + assert torch.equal(tensor_0.scale, tensors[i].scale) + cat_scale = scales[0] + + block_size = [] + for i in range(cat_qdata.ndim): + block_size.append(cat_qdata.shape[i] // cat_scale.shape[i]) + + new = tensor_0.__class__( + cat_qdata, + cat_scale, + block_size, + tensor_0.mm_config, + tensor_0.act_quant_kwargs, + tensor_0.kernel_preference, + tensor_0.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + self, dim0, dim1 = args + qdata = self.qdata.transpose(dim0, dim1) + scale = self.scale.transpose(dim0, dim1) + block_size = self.block_size.copy() + + block_size[dim0], block_size[dim1] = block_size[dim1], block_size[dim0] + + new = self.__class__( + qdata, + scale, + block_size, + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, size = args + original_shape = self.shape + if len(original_shape) == 3 and len(size) == 2: + assert original_shape[-1] == size[-1], ( + f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + ) + qdata = self.qdata.reshape(*size) + scale = self.scale.reshape(*size) + block_size = self.block_size.copy() + block_size = [block_size[0] * block_size[1], block_size[2]] + elif len(original_shape) == 2 and len(size) == 3: + assert original_shape[-1] == size[-1], ( + f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + ) + qdata = self.qdata.reshape(*size) + block_size = self.block_size.copy() + block_size = [1, block_size[0], block_size[1]] + scale_shape = [] + for i in range(3): + scale_shape.append(qdata.shape[i] // block_size[i]) + scale = self.scale.reshape(*scale_shape) + elif len(original_shape) == len(size): + assert all(x == y or y == -1 for x, y in zip(original_shape, size)), ( + f"Only support viewing with match dimensions or -1, got: {original_shape}, {size}" + ) + qdata = self.qdata.reshape(*size) + scale_shape = [] + for i in range(3): + scale_shape.append(qdata.shape[i] // self.block_size[i]) + scale = self.scale.reshape(*scale_shape) + block_size = self.block_size + else: + assert len(original_shape) == 2 and len(size) == 3, ( + f"Only support reshaping from 2D to 3D or from 3D to 2D, requested: reshaping from {original_shape} to {size}" + ) + + new = self.__class__( + qdata, + scale, + block_size, + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.squeeze.dim) +def _(func, types, args, kwargs): + self, dim = args + assert dim == 0, f"Only dim == 0 is supported, got: {dim}" + qdata = self.qdata.squeeze(dim=dim) + scale = self.scale.squeeze(dim=dim) + block_size = [] + for i in range(len(qdata.shape)): + block_size.append(qdata.shape[i] // scale.shape[i]) + + new = self.__class__( + qdata, + scale, + block_size, + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +Float8Tensor.__module__ = "torchao.quantization" + +# Allow a model with Float8Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs]) diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/quantize_/workflows/int4/int4_choose_qparams_algorithm.py b/torchao/quantization/quantize_/workflows/int4/int4_choose_qparams_algorithm.py new file mode 100644 index 0000000000..2258b3f3e2 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_choose_qparams_algorithm.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class Int4ChooseQParamsAlgorithm(str, Enum): + """Variant of quantization algorithm to calculate scale and zero_point""" + + """ + The choose qparams algorithm native for tinygemm kernel: + scale = (max_val - min_val) / float(quant_max - quant_min), where + max_val and min_val are the max/min for the slice of input Tensor based on block_size + quant_max and quant_min and max/min for the quantized value, e.g. 0, 15 for uint4 + zero_point = min_val + scale * mid_point, where + mid_point = (quant_max + quant_min + 1) / 2 + + implemented in `torchao.quantization.quant_primitives._choose_qparams_affine_tinygemm + """ + TINYGEMM = "tinygemm" + + """ + The choose qparams based on half-quadratic quantization: https://mobiusml.github.io/hqq_blog/ + + implemented in `torchao.quantization.quant_primitives._choose_qparams_and_quantize_affine_hqq` + """ + HQQ = "hqq" diff --git a/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py new file mode 100644 index 0000000000..f71d73de1c --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + choose_qparams_affine, + quantize_affine, +) +from torchao.utils import TorchAOBaseTensor + +__all__ = [ + "Int4MarlinSparseTensor", +] + +aten = torch.ops.aten + + +class Int4MarlinSparseTensor(TorchAOBaseTensor): + tensor_data_names = ["qdata", "scale", "zero_point", "meta"] + tensor_attribute_names = ["block_size", "num_bits", "shape"] + + def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape): + super().__init__() + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.meta = meta + self.block_size = block_size + self.num_bits = num_bits + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + from torchao.sparsity.marlin import ( + const, + inject_24, # avoid circular import + pack_to_marlin_24, + ) + + """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. + - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format + - 2º: tensor is injected with 2:4 sparsity + - 3º: transposes it again because the quantization process will compute the scales for dim=-1 + """ + + w_t = w.t() + w_24, _ = inject_24(w_t, *w_t.shape) + preprocessed_w = w_24.t() + + assert block_size[-1] == 128 or block_size[-1] == preprocessed_w.shape[-1], ( + f"MarlinSparse only supports 128 group size or per channel quantization, got {block_size}" + ) + + quant_min = 0 + quant_max = 15 + target_dtype = torch.int32 + + scale, zero_point = choose_qparams_affine( + input=preprocessed_w, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=1e-6, + ) + + wq = quantize_affine( + input=preprocessed_w, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + + scale = scale.to(w.dtype) + zero_point = zero_point.to(w.dtype) + + # Linear layers are (in_features, out_features) but the qdata that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w_24 = wq.t() + # addressing the case when scale has dimension 1, happens when + # weight_shape[-1] == group_size == 128 + if scale.ndim == 1: + scale = scale.reshape(scale.shape[0], -1) + + scale_t = scale.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." + ) + + if q_w_24.dtype != torch.int32: + raise ValueError("Only `torch.int32` weights are supported.") + + in_features, out_features = q_w_24.shape + if in_features % 128 != 0 or out_features != 256 == 0: + raise ValueError( + "`in_features` must be divisible by 64 and `out_features` by 256." + ) + + # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8 + # will require a bit more work to get our current quantization flow to work with it. + # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main + num_bits = 4 if torch.max(q_w_24) < 16 else -1 + if num_bits not in [4]: + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + + group_size = in_features // scale_t.shape[0] + if group_size == 0: + group_size = in_features + assert group_size <= in_features, ( + "Group size must be less than or equal to in_features." + ) + + if group_size not in const.SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." + ) + + # Compress quantized weight to marlin 2:4 format + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( + q_w_24, scale_t, num_bits, group_size + ) + + return cls( + qdata=marlin_24_q_w_comp, + scale=marlin_24_s, + zero_point=zero_point, + meta=meta, + block_size=group_size, + shape=q_w_24.shape, + num_bits=num_bits, + ) + + +implements = Int4MarlinSparseTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + from torchao.ops import marlin_24_gemm + from torchao.sparsity.marlin import marlin_24_workspace + + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" + assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous" + assert weight_tensor.zero_point.is_contiguous(), ( + "Expected zero_point to be contiguous" + ) + + sparse_w_int4 = weight_tensor.qdata + scale = weight_tensor.scale + meta = weight_tensor.meta + original_shape = weight_tensor.shape + num_bits = weight_tensor.num_bits + + # Folds batch dimension into the first dimension + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) + + size_m = input_2d.shape[0] + size_n = scale.shape[1] + size_k = input_2d.shape[1] + workspace_24 = marlin_24_workspace(original_shape[1]) + + out = marlin_24_gemm( + input_2d, + sparse_w_int4, + meta, + scale, + workspace_24, + num_bits, + size_m, + size_n, + size_k, + ) + + # Unfold the batch dimension + out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out + + +Int4MarlinSparseTensor.__module__ = "torchao.quantization" + +# Allow a model with Int4MarlinSparseTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4MarlinSparseTensor]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py new file mode 100644 index 0000000000..57245f55a7 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from typing import List, Optional + +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + _choose_qparams_affine_tinygemm, + _choose_qparams_and_quantize_affine_hqq, + _quantize_affine_tinygemm, +) +from torchao.quantization.utils import pack_tinygemm_scales_and_zeros +from torchao.utils import ( + TorchAOBaseTensor, +) + +from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm + +__all__ = [ + "Int4OpaqueTensor", +] + +aten = torch.ops.aten + + +class Int4OpaqueTensor(TorchAOBaseTensor): + """ + int4 weight-only quantization on CPU with tinygemm (groupwise quantization only). The packing format is determined on ISA and shape. + This is an opaque tensor subclass, the packing format is not exposed to the rest of the system. See the note below for more details. + + Tensor Attributes: + qdata: preshuffled and packed int4 weight for CPU tinygemm kernel, always viewed as a 2D (N, K/2) tensor, last dimension is packed + preshuffling is specific to CPU kernels based on ISA and shape, see Note below. + scale_and_zero: (K/group_size, N, 2), dtype is the same as the original Tensor dtype + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size). + we only support group_size = 32/64/128. + shape: shape of the original Tensor + + Optional Tensor Data Attributes: + act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present, + we'll multiply activation Tensor with act_pre_scale before applying dynamic + quantization to activation or running quantized mm op + + Note on Details for data layout for CPU tinygemm kernel: + + We use AVX512 to compute TINYGEMM on CPU. We can also leverage AVX512_VNNI and AMX instructions with torch.compile and max-autotune. + For data locality, we preshuffle the data in plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64/32/16. + See https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 for more details. + """ + + tensor_data_names = ["qdata", "scale_and_zero"] + tensor_attribute_names = ["block_size", "shape"] + optional_tensor_data_names = ["act_pre_scale"] + + def __new__( + cls, + qdata, + scale_and_zero, + block_size, + shape, + act_pre_scale: Optional[torch.Tensor] = None, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale_and_zero.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale_and_zero: torch.Tensor, + block_size: List[int], + shape: torch.Size, + act_pre_scale: Optional[torch.Tensor] = None, + ): + super().__init__() + self.qdata = qdata + self.scale_and_zero = scale_and_zero + self.block_size = block_size + self.act_pre_scale = act_pre_scale + + def _quantization_type(self): + s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + if self.act_pre_scale is not None: + s += f", act_pre_scale.shape={self.act_pre_scale.shape}" + return s + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = Int4ChooseQParamsAlgorithm.TINYGEMM, + ): + assert w.ndim == 2 and w.device.type == "cpu", ( + f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}" + ) + assert len(block_size) == w.ndim + assert block_size[0] == 1 and block_size[1] in (32, 64, 128), ( + f"Expecting groupwise quantization with group size = 32/64/128, but got block_size: {block_size}" + ) + original_shape = w.shape + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + scale_dtype = None + zero_point_dtype = w.dtype + + # we support two paths for constructing a Int4OpaqueTensor + # 1. use [hqq](https://mobiusml.github.io/hqq_blog/) algorithm to compute + # scale and zero_point, then convert to the format that's compatible with tinygemm kernels + # 2. don't use hqq, use default tinygemm algorithm to compute scale and zero_point + # + # both approach should have the same performance since both are using CPU tinygemm kernel for gemm + # 1. typically will have higher accuracy compared to 2. + if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ: + nbits = int(math.log2(quant_max + 1)) + axis = 1 + group_size = block_size[-1] + int_data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq( + w, + nbits=nbits, + group_size=group_size, + axis=axis, + compute_dtype=zero_point_dtype, + device=w.device, + ) + int_data = int_data.to(target_dtype) + else: + assert ( + int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM + ), ( + f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}" + ) + + scale, zero_point = _choose_qparams_affine_tinygemm( + w, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + ) + int_data = _quantize_affine_tinygemm( + w, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + ) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + ) + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # innerKTiles is not needed for CPU + ) + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) + return Int4OpaqueTensor( + qdata=packed_weight, + scale_and_zero=scale_and_zero, + block_size=block_size, + shape=original_shape, + act_pre_scale=None, + ) + + +implements = Int4OpaqueTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert input_tensor.device.type == "cpu", ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert isinstance(weight_tensor, Int4OpaqueTensor), ( + f"Expected weight_tensor to be Int4OpaqueTensor, got: {type(weight_tensor)}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" + ) + + if weight_tensor.act_pre_scale is not None: + input_tensor = input_tensor * weight_tensor.act_pre_scale + + act_mat = input_tensor + packed_weight = weight_tensor.qdata + scale_and_zero = weight_tensor.scale_and_zero + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + + # remove out_feature padding + assert weight_tensor.ndim == 2 + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias + return y.to(orig_dtype) + + +Int4OpaqueTensor.__module__ = "torchao.quantization" + +# Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4OpaqueTensor]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py new file mode 100644 index 0000000000..b5d988ef4a --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_packing_format.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class Int4PackingFormat(str, Enum): + """Packing format for quantized data in Int4 Tensor subclasses in torchao, represents how + the values in quantized data are packed and laid out in memory. + """ + + """ + plain means the format that quantized Tensor data lays out elements in Tensor sequentially, + for example: for a Tensor of shape (4, 6): + a_0_0, a_0_1, ..., a_0_5, + ... + a_3_0, a_3_1, ..., a_3_5 + + For example for int4, we will + pack two adjacent int4 elements into one uint8/int8 value for plain packing format + """ + PLAIN = "plain" + + """ + preshuffled is referring to the preshuffled format used by fbgemm kernels + """ + PRESHUFFLED = "preshuffled" + + """ + marlin_sparse is referring to the format used by marlin kernels, requires symmetric quantization + """ + MARLIN_SPARSE = "marlin_sparse" + + """ + plain_int32 is a format that 2 adjacent int4 values are packed in a byte and 4 such packed bytes are stored in a int32 value. + """ + PLAIN_INT32 = "plain_int32" + + """ + tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization + for a Tensor of shape (n, k), the packed weight will have dimension: + [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2], where inner_k_tiles is 8 currently + for simplication of Int4TilePackedTo4dTensor API + """ + TILE_PACKED_TO_4D = "tile_packed_to_4d" + + """ + Opaque packing format that's used for tensors that does not have a predefined packing format + (that may be decided on hardware, tensor shape, library availability etc.) and it's not + needed for the rest of the system to understand the specific format that's adopted. + """ + OPAQUE = "opaque" diff --git a/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py new file mode 100644 index 0000000000..0446eed42c --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Optional + +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + choose_qparams_affine, + quantize_affine, +) +from torchao.utils import ( + TorchAOBaseTensor, +) + +__all__ = [ + "Int4PlainInt32Tensor", +] + +aten = torch.ops.aten + + +class Int4PlainInt32Tensor(TorchAOBaseTensor): + """ + int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only) + + Tensor Attributes: + qdata: (N, K/8), packed int4 weight, the data type is int32 here with 4*(int4*2), the original data type can be half and bfloat16 + scale: (K/group_size, N), dtype is the same as the original Tensor dtype + zero_point: (K/group_size, N), dtype is int8 + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity. + shape: shape of the original Tensor + + Optional Tensor Data Attributes: + act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present, + we'll multiply activation Tensor with act_pre_scale before applying dynamic + quantization to activation or running quantized mm op + + """ + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["block_size", "shape"] + optional_tensor_data_names = ["act_pre_scale"] + + def __new__( + cls, + qdata, + scale, + zero_point, + block_size, + shape, + act_pre_scale: Optional[torch.Tensor] = None, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata, + scale, + zero_point, + block_size, + shape, + act_pre_scale: Optional[torch.Tensor] = None, + ): + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.block_size = block_size + self.act_pre_scale = act_pre_scale + + def _quantization_type(self): + s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + if self.act_pre_scale is not None: + s += f", act_pre_scale.shape={self.act_pre_scale.shape}" + return s + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + assert w.ndim == 2 and w.device.type == "xpu", ( + f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}" + ) + assert len(block_size) == w.ndim + assert w.dtype in [torch.float16, torch.bfloat16], ( + f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" + ) + original_shape = w.shape + mapping_type = MappingType.ASYMMETRIC + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + scale_dtype = None + zero_point_dtype = torch.int32 + scale, zero_point = choose_qparams_affine( + w, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + ) + int_data = quantize_affine( + w, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + ) + assert int_data.dtype == torch.int32, ( + "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype" + ) + packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + packed_weight.contiguous(), 8 + ) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + return Int4PlainInt32Tensor( + packed_weight, + scale.transpose(0, 1).contiguous(), + zero_point.transpose(0, 1).contiguous().to(torch.int8), + block_size, + original_shape, + act_pre_scale=None, + ) + + +implements = Int4PlainInt32Tensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert input_tensor.device.type == "xpu", ( + f"For XPU device only but got: {input_tensor.device}" + ) + assert isinstance(weight_tensor, Int4PlainInt32Tensor), ( + f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" + ) + + if weight_tensor.act_pre_scale is not None: + input_tensor = input_tensor * weight_tensor.act_pre_scale + + act_mat = input_tensor + packed_weight = weight_tensor.qdata + scale = weight_tensor.scale + zero_point = weight_tensor.zero_point + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( + act_mat, packed_weight, groupsize, scale, zero_point + ) + + # remove out_feature padding + assert weight_tensor.ndim == 2 + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias + return y.to(orig_dtype) + + +Int4PlainInt32Tensor.__module__ = "torchao.quantization" + +# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4PlainInt32Tensor]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py new file mode 100644 index 0000000000..3f5a4e2b10 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py @@ -0,0 +1,298 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib.util +from typing import List, Optional + +import torch + +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor +from torchao.utils import ( + TorchAOBaseTensor, +) + +__all__ = [ + "Int4PreshuffledTensor", +] + +aten = torch.ops.aten + + +if ( + importlib.util.find_spec("fbgemm_gpu") is None + or importlib.util.find_spec("fbgemm_gpu.experimental") is None +): + quantize_int4_preshuffle = None + quantize_fp8_row = None + pack_int4 = None +else: + from fbgemm_gpu.experimental.gen_ai.quantize import ( + quantize_fp8_row, + quantize_int4_preshuffle, + ) + + +class Int4PreshuffledTensor(TorchAOBaseTensor): + """ + int4 quantization with preshuffled packing format (for all granularities) + + Tensor Attributes: + qdata: preshuffled and packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed + preshuffling is specific to fbgemm kernels, see Note for motivation, detailed layout doc is WIP + for bf16 activation: + group_scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size, + dtype is the same as the original Tensor dtype + group_zero: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size, + dtype is the same as the original Tensor dtype + for float8 activation: + group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor + dtype is float8 + row_scale: (N,) for 2D Tensor, (B, N) for 3D Tensor + dtype is the same as the original Tensor dtype + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size) + shape: shape of the original Tensor + + Note on Details for preshuffle for fbgemm kernel: + + We use WGMMA instruction for efficient matrix multiplication in H100 Tensor Core. + To address a major inefficiency in how WGMMA tiles are loaded into shared memory before + dispatching to tensor cores, Each thread of an FP8 WGMMA reads 4 groups for 4 elements + (or 4 groups of 2 elements for BF16) into local registers. Each of those groups thus + contains a total 32 bits, which can be efficiently loaded using a single 32-bit load instruction. + However, weights are loaded using the same format. As the INT4 weights are only 4-bits each, + one group has a total of 16 bits. Unfortunately, 16 bit loads are not any faster than 32 bit + loads so having to load all four groups is wasteful. We can optimize weight loading by shuffling + the order of elements such that all 4 groups are sequential in memory. This allows us to + perform a single 64 bit load to move all needed weights for the thread into register memory. + + Note for float8 activation int4 weight kernel: + float8 activation int4 weight kernel doesn't work with zero_point, since it use table lookup approach which + requires symmetric quantization + """ + + tensor_data_names = ["qdata", "group_scale"] + tensor_attribute_names = ["block_size", "shape"] + optional_tensor_data_names = ["group_zero", "row_scale"] + + def __new__( + cls, + qdata: torch.Tensor, + group_scale: torch.Tensor, + block_size: List[int], + shape: List[int], + group_zero: Optional[torch.Tensor] = None, + row_scale: Optional[torch.Tensor] = None, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = group_scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + group_scale: torch.Tensor, + block_size: List[int], + shape: List[int], + group_zero: Optional[torch.Tensor] = None, + row_scale: Optional[torch.Tensor] = None, + ): + super().__init__() + # one and only one of group_scale and group_zero should be None + assert group_zero is None or row_scale is None + assert not (group_zero is not None and row_scale is not None) + self.qdata = qdata + self.row_scale = row_scale + self.block_size = block_size + self.group_scale = group_scale + self.group_zero = group_zero + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + activation_dtype: torch.dtype = torch.bfloat16, + ): + assert len(block_size) == w.ndim, ( + f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" + ) + + assert all(x == 1 for x in block_size[:-1]), ( + f"Only per group quantization is supported, got block_size: {block_size}" + ) + + _SUPPORTED_DTYPE_TO_STR = { + torch.bfloat16: "bf16", + torch.float8_e4m3fn: "fp8", + } + assert activation_dtype in _SUPPORTED_DTYPE_TO_STR, ( + f"activation dtype {activation_dtype} is not supported, supported ones are: {_SUPPORTED_DTYPE_TO_STR.keys()}" + ) + + if quantize_int4_preshuffle is None: + raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + + assert all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1, ( + "Only groupwise quant is supported right now" + ) + original_shape = w.shape + group_size = block_size[-1] + + activation_dtype_str = _SUPPORTED_DTYPE_TO_STR[activation_dtype] + + if w.ndim >= 3: + wq, scales = zip( + *[ + quantize_int4_preshuffle( + i.cuda(), group_size=group_size, dtype=activation_dtype_str + ) + for i in w + ] + ) + wq = torch.stack(wq, dim=0) + group_scale, group_zero_or_row_scale = zip(*scales) + group_zero_or_row_scale = torch.stack( + group_zero_or_row_scale, dim=0 + ).contiguous() + group_scale = torch.stack(group_scale, dim=0).contiguous() + else: + wq, (group_scale, group_zero_or_row_scale) = quantize_int4_preshuffle( + w.cuda(), group_size=group_size, dtype=activation_dtype_str + ) + + if activation_dtype == torch.bfloat16: + group_zero = group_zero_or_row_scale + row_scale = None + else: + group_zero = None + row_scale = group_zero_or_row_scale + + return Int4PreshuffledTensor( + qdata=wq, + group_scale=group_scale, + block_size=block_size, + shape=original_shape, + group_zero=group_zero, + row_scale=row_scale, + ) + + @classmethod + def from_int4_tensor( + cls, + tensor: Int4Tensor, + ): + assert isinstance(tensor, Int4Tensor), ( + f"Only conversion from Int4Tensor is supportd, got: {tensor}" + ) + # currently Int4Tensor only supports weight only, we can extend it to fp8-int4 a bit later + qdata = tensor.qdata + group_scale = tensor.scale + group_zero = tensor.zero_point + block_size = tensor.block_size + original_shape = tensor.shape + row_scale = None + + # Set scales to activation type. + group_scale = group_scale.to(torch.bfloat16) + group_zero = group_zero.to(torch.bfloat16) + # pack weights and scales into efficient preshuffled format + preshuffled_qdata, group_scale = torch.ops.fbgemm.preshuffle_i4( + qdata, group_scale + ) + return Int4PreshuffledTensor( + qdata=preshuffled_qdata, + group_scale=group_scale, + block_size=block_size, + shape=original_shape, + group_zero=group_zero, + row_scale=row_scale, + ) + + +implements = Int4PreshuffledTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + orig_input_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + + wq = weight_tensor.qdata.contiguous() + group_scale = weight_tensor.group_scale.contiguous() + if weight_tensor.group_zero is not None: + # bf16 activation + group_zero = weight_tensor.group_zero.contiguous() + res = torch.ops.fbgemm.bf16i4bf16_shuffled( + input_tensor, wq, group_scale, group_zero + ) + else: + # dynamically quantizes activation to fp8 + assert weight_tensor.row_scale is not None + row_scale = weight_tensor.row_scale.contiguous() + xq, x_scale = quantize_fp8_row(input_tensor) + res = torch.ops.fbgemm.f8i4bf16_shuffled( + xq, wq, x_scale, row_scale, group_scale + ) + + res = res.reshape(*orig_input_size[:-1], orig_out_features) + if bias is not None: + res = res + bias + return res + + +@implements(torch.bmm) +def _(func, types, args, kwargs): + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + orig_input_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + + wq = weight_tensor.qdata.contiguous() + group_scale = weight_tensor.group_scale.contiguous() + if weight_tensor.group_zero is not None: + # bfloat16 activation + group_zero = weight_tensor.group_zero.contiguous() + res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched( + input_tensor, wq, group_scale, group_zero + ) + else: + # dynamically quantizes activation to fp8 + assert weight_tensor.row_scale is not None + row_scale = weight_tensor.row_scale.contiguous() + xq, x_scale = quantize_fp8_row(input_tensor) + # From: https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1654 + assert xq.dim() == 3 + B, M, _ = xq.shape + _, N, _ = wq.shape + res = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16) + for i in range(B): + res[i] = torch.ops.fbgemm.f8i4bf16_shuffled( + xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i] + ) + + res = res.reshape(*orig_input_size[:-1], orig_out_features) + return res + + +Int4PreshuffledTensor.__module__ = "torchao.quantization" + +# Allow a model with Int4PreshuffledTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4PreshuffledTensor]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py new file mode 100644 index 0000000000..cb4c520a33 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py @@ -0,0 +1,533 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.utils import TorchAOBaseTensor, fill_defaults + +__all__ = [ + "Int4Tensor", +] + +aten = torch.ops.aten + + +try: + from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4 +except: + int4_row_quantize_zp = None + pack_int4 = None + + +class Int4Tensor(TorchAOBaseTensor): + """ + int4 quantization with plain (default) packing format (for all granularities) + + Tensor Data Attributes: + qdata: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed + scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size, + dtype is the same as the original Tensor dtype + zero_point: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size, + dtype is the same as the original Tensor dtype + + Non-Tensor Data Attributes: + block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size) + shape: the shape of the original Tensor + + Optional Tensor Data Attributes: + act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present, + we'll multiply activation Tensor with act_pre_scale before applying dynamic + quantization to activation or running quantized mm op + """ + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["block_size", "shape"] + optional_tensor_data_names = ["act_pre_scale"] + + def __new__( + cls, + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: List[int], + shape: torch.Size, + act_pre_scale: Optional[torch.Tensor] = None, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: List[int], + shape: torch.Size, + act_pre_scale: Optional[torch.Tensor] = None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.block_size = block_size + self.act_pre_scale = act_pre_scale + + def _quantization_type(self): + s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + if self.act_pre_scale is not None: + s += f", act_pre_scale.shape={self.act_pre_scale.shape}" + return s + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + assert len(block_size) == w.ndim, ( + f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" + ) + if int4_row_quantize_zp is None: + raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + + assert all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1, ( + "Only groupwise quant is supported right now" + ) + + group_size = block_size[-1] + original_shape = w.shape + + if w.ndim >= 3: + wq, scale, zero_point = zip( + *[int4_row_quantize_zp(i, group_size) for i in w], strict=False + ) + wq = torch.stack([pack_int4(i) for i in wq], dim=0) + scale = torch.stack(scale, dim=0) + zero_point = torch.stack(zero_point, dim=0) + else: + wq, scale, zero_point = int4_row_quantize_zp(w, group_size) + wq = pack_int4(wq) + + scale = scale.to(w.dtype) + zero_point = zero_point.to(w.dtype) + + return Int4Tensor( + qdata=wq, + scale=scale, + zero_point=zero_point, + block_size=block_size, + shape=original_shape, + act_pre_scale=None, + ) + + +implements = Int4Tensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert isinstance(weight_tensor, Int4Tensor) + + assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" + assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous" + assert weight_tensor.zero_point.is_contiguous(), ( + "Expected zero_point to be contiguous" + ) + + if weight_tensor.act_pre_scale is not None: + input_tensor = input_tensor * weight_tensor.act_pre_scale + + orig_act_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + + input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]) + res = torch.ops.fbgemm.bf16i4bf16_rowwise( + input_tensor, + weight_tensor.qdata, + weight_tensor.scale, + weight_tensor.zero_point, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + if bias is not None: + res = res + bias + return res + + +@implements(torch.bmm) +def _(func, types, args, kwargs): + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" + assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous" + assert weight_tensor.zero_point.is_contiguous(), ( + "Expected zero_point to be contiguous" + ) + + orig_act_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( + input_tensor, + weight_tensor.qdata, + weight_tensor.scale, + weight_tensor.zero_point, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + return res + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Only supports slicing for dim == 1 and dim == 2 + qdata has dimension: (N, K/2) + scale and zero_point has dimension: (K/groups, N) + + dim, start, end, step are args that's referring to the original tensor shape + which is (N, K), and we need to map that to the transformed weight shape of qdata, + scale and zero_point + + when dim == 0: we do a slice on qdata dim 0, and on dim 1 of scale and zero_point, + also adjust the start and end indexes based on the ratio between original shape and the shape + of qdata and scale/zero_point + + when dim == 1: we do a slice on qdata dim 1 and dim 0 of scale and zero_point and do the + same adjustment based on ratio + + Note that we need to call slice on the qdata, scale and zero_point directly because slice + is an operation that need to preserve aliasing, see `test_slice_preserves_aliasing` and + `test_slice_and_copy_similar_to_vllm` in `test_int4_tensor` for more details + """ + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + + assert self.qdata.ndim == 2, ( + f"Expected packed weight to have dim 2, got {self.qdata.dim}" + ) + N, K_by_2 = self.qdata.shape + sz_dim0, sz_dim1 = self.scale.shape + + data_len = self.shape[dim] + + if dim == 0: + pw_len = N + sz_len = sz_dim1 + else: + pw_len = K_by_2 + sz_len = sz_dim0 + + sz_dim = 1 - dim + if pw_len == 0 or sz_len == 0: + return return_and_correct_aliasing( + func, + args, + kwargs, + Int4Tensor( + self.qdata, + self.scale, + self.zero_point, + block_size=self.block_size, + shape=self.shape, + act_pre_scale=self.act_pre_scale, + ), + ) + + pw_ratio = data_len / pw_len + start_pw = int(start / pw_ratio) + end_pw = int(end / pw_ratio) + + sz_ratio = data_len / sz_len + start_sz = int(start / sz_ratio) + end_sz = int(end / sz_ratio) + + qdata = aten.slice.Tensor(self.qdata, dim, start_pw, end_pw, step) + scale = aten.slice.Tensor(self.scale, sz_dim, start_sz, end_sz, step) + zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step) + packed_shape0, packed_shape1 = qdata.shape + new_shape = (packed_shape0, packed_shape1 * 2) + new = Int4Tensor( + qdata, + scale, + zero_point, + self.block_size, + new_shape, + act_pre_scale=self.act_pre_scale, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.cat.default) +def _(func, types, args, kwargs): + """Concatenate multiple Int4 quantized tensors + + For Int4Tensor, we need to concatenate qdata, scale, and zero_point tensors. + The concatenation behavior depends on the dimension and block_size configuration. + + If the concatenation dimension is not the same as the packed dimension, then we can just concatenate the + qdata, scale and zero_point directly, note that scale and zero_point has reversed dimension order in 2D + If the concatention dimension is the same as block_size, we'll check that scales from all + tensors are equal and use the first scale + """ + tensors, dim = fill_defaults(args, 2, [[], 0]) + if not tensors: + raise ValueError("Cannot concatenate empty list of tensors") + + tensor_0 = tensors[0] + dim = dim % tensor_0.ndim + + # Validate that all tensors have compatible properties + for i in range(1, len(tensors)): + assert tensor_0.qdata.ndim == tensors[i].qdata.ndim + assert tensor_0.scale.ndim == tensors[i].scale.ndim + assert tensor_0.zero_point.ndim == tensors[i].zero_point.ndim + assert tensor_0.block_size == tensors[i].block_size + + qdatas = [t.qdata for t in tensors] + scales = [t.scale for t in tensors] + zero_points = [t.zero_point for t in tensors] + + # Concatenate the quantized data along the specified dimension + cat_qdata = aten.cat.default(qdatas, dim=dim) + + # if concatenation happens in the non-packed dimension, we need to concatenation + # scale and zero_point + if tensor_0.block_size[dim] == 1: + # For scale and zero_point, the concatenation dimension depends on the dimension + # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, K/group_size, N) for 3D + if cat_qdata.ndim == 2: # 2D case + sz_dim = ( + 1 - dim + ) # If concatenating dim 0 (N), use dim 1 for scale; if dim 1 (K), use dim 0 + else: # 3D case + assert cat_qdata.ndim == 3 + if dim in [1, 2]: + sz_dim = 3 - dim + else: + sz_dim = dim + + cat_scale = aten.cat.default(scales, dim=sz_dim) + cat_zero_point = aten.cat.default(zero_points, dim=sz_dim) + + else: + # if concatenation happens in the packed dimension, we just need to verify + # that all scale and zero_points match + for i in range(1, len(tensors)): + assert torch.equal(tensor_0.scale, tensors[i].scale) + assert torch.equal(tensor_0.zero_point, tensors[i].zero_point) + cat_scale = scales[0] + cat_zero_point = zero_points[0] + + # Calculate new shape based on the concatenated qdata shape + new_shape = list(cat_qdata.shape) + new_shape[-1] *= 2 + new_shape = list(new_shape) + + new = Int4Tensor( + cat_qdata, + cat_scale, + cat_zero_point, + tensor_0.block_size, + new_shape, + act_pre_scale=tensor_0.act_pre_scale, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.transpose.int) +def _(func, types, args, kwargs): + self, dim0, dim1 = args + + # Transpose the quantized data + qdata = self.qdata.transpose(dim0, dim1).contiguous() + if self.scale.ndim == 3: + # since scale/zero_point dimension order is different + # (B, K/group_size, N), we'll need to remap the dim + remapped_dim0 = dim0 + if dim0 in [1, 2]: + remapped_dim0 = 3 - dim0 + + remapped_dim1 = dim1 + if dim1 in [1, 2]: + remapped_dim1 = 3 - dim1 + + scale = self.scale.transpose(remapped_dim0, remapped_dim1) + zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1) + else: + assert scale.ndim == 2, f"Only support ndim == 2 or 3, got: {scale.ndim}" + remapped_dim0 = 1 - dim0 + remapped_dim1 = 1 - dim1 + scale = self.scale.transpose(remapped_dim0, remapped_dim1) + zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1) + + # Update block_size by swapping the dimensions + block_size = self.block_size.copy() + block_size[dim0], block_size[dim1] = block_size[dim1], block_size[dim0] + + # Update shape by swapping the dimensions + new_shape = list(self.shape) + new_shape[dim0], new_shape[dim1] = new_shape[dim1], new_shape[dim0] + + new = Int4Tensor( + qdata, + scale, + zero_point, + block_size, + new_shape, + act_pre_scale=self.act_pre_scale, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, size = args + original_shape = self.shape + original_packing_dim = None + for i in range(len(original_shape)): + if original_shape[i] == (self.qdata.shape[i] * 2): + original_packing_dim = i + assert original_packing_dim is not None, "Didn't find a packing_dim" + + if len(original_shape) == 3 and len(size) == 2: + # only support combining the dim 0 and dim1 together + assert original_shape[-1] == size[-1], ( + f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + ) + # the dim that int4 packing happens + if original_packing_dim in [0, 1]: + packing_dim = 0 + else: + packing_dim = 1 + + block_size = self.block_size.copy() + block_size = [block_size[0] * block_size[1], block_size[2]] + + qdata_shape = size.copy() + qdata_shape[packing_dim] //= 2 + qdata = self.qdata.reshape(*qdata_shape) + sz_shape = [] + for i in range(len(size)): + sz_shape.append(size[i] // block_size[i]) + # scale and zero_point have reversed dimensions + sz_shape[0], sz_shape[1] = sz_shape[1], sz_shape[0] + + scale = self.scale.reshape(*sz_shape) + zero_point = self.zero_point.reshape(*sz_shape) + elif len(original_shape) == 2 and len(size) == 3: + # only support extending the dim 0 to 2, `t.unflatten(0, (num_experts, -1))` + assert original_shape[-1] == size[-1], ( + f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" + ) + if original_packing_dim == 0: + packing_dim = 1 + else: + # original_packing_dim is 1 + packing_dim = 2 + + block_size = self.block_size.copy() + block_size = [1, block_size[0], block_size[1]] + + qdata_shape = size.copy() + qdata_shape[packing_dim] //= 2 + qdata = self.qdata.reshape(*qdata_shape) + + sz_shape = [] + for i in range(len(size)): + sz_shape.append(size[i] // block_size[i]) + + # scale and zero_point have reversed dimensions + sz_shape[1], sz_shape[2] = sz_shape[2], sz_shape[1] + + scale = self.scale.reshape(*sz_shape) + zero_point = self.zero_point.reshape(*sz_shape) + elif len(original_shape) == len(size): + assert all(x == y or y == -1 for x, y in zip(original_shape, size)), ( + f"Only support viewing with match dimensions or -1, got: {original_shape}, {size}" + ) + packing_dim = original_packing_dim + block_size = self.block_size + else: + assert len(original_shape) == 2 and len(size) == 3, ( + f"Only support reshaping from 2D to 3D or from 3D to 2D or between sam ranges, requested: reshaping from {original_shape} to {size}" + ) + + shape = list(qdata.shape) + for i in range(len(shape)): + if i == packing_dim: + shape[i] *= 2 + + new = Int4Tensor( + qdata, + scale, + zero_point, + block_size, + shape, + act_pre_scale=self.act_pre_scale, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.squeeze.dim) +def _(func, types, args, kwargs): + self, dim = args + + # Squeeze qdata + qdata = self.qdata.squeeze(dim=dim) + + # For scale and zero_point, we need to squeeze based on the tensor layout + # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, N, K/group_size) for 3D + if self.qdata.ndim == 2: # 2D case + # qdata is (N, K/2), scale/zero_point is (K/group_size, N) + # When squeezing qdata dim, we need to squeeze scale/zero_point in reverse order + sz_dim = 1 - dim + else: # 3D case + # qdata is (B, N, K/2), scale/zero_point is (B, N, K/group_size) + sz_dim = dim + + scale = self.scale.squeeze(dim=sz_dim) + zero_point = self.zero_point.squeeze(dim=sz_dim) + + # Update block_size by removing the squeezed dimension + new_block_size = list(self.block_size) + if len(qdata.shape) < len(new_block_size): + new_block_size.pop(dim) + + # Update shape by removing the squeezed dimension + new_shape = list(self.shape) + if len(qdata.shape) < len(new_shape): + assert new_shape[dim] == 1 + new_shape.pop(dim) + + new = Int4Tensor( + qdata, + scale, + zero_point, + new_block_size, + new_shape, + act_pre_scale=self.act_pre_scale, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +Int4Tensor.__module__ = "torchao.quantization" + +# Allow a model with Int4Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4Tensor]) diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py new file mode 100644 index 0000000000..6c80198b9f --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int4/int4_tile_packed_to_4d_tensor.py @@ -0,0 +1,347 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from typing import List + +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + _choose_qparams_affine_tinygemm, + _choose_qparams_and_quantize_affine_hqq, + _quantize_affine_tinygemm, +) +from torchao.quantization.utils import pack_tinygemm_scales_and_zeros +from torchao.utils import TorchAOBaseTensor, fill_defaults, find_multiple + +from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm + +__all__ = [ + "Int4TilePackedTo4dTensor", +] + +aten = torch.ops.aten + + +class Int4TilePackedTo4dTensor(TorchAOBaseTensor): + """ + int4 quantization with tile packed to 4d packing format for groupwise quantization + + Tensor Attributes: + qdata: tile packed to 4d int4 weight, 4-d tensor of dimension: + [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] + (unpacked Tensor shape is n * k) + (inner_k_tiles is fixed to 8 for Int4TilePackedTo4dTensor) + scale_and_zero: combined scale and zero point tensor packed for tinygemm kernels + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity, + for example groupwise quantization will have block_size (1, group_size) + shape: shape of the original Tensor + + Note on Details for tile packed to 4d packing format: + + This is used by tinygemm kernels `_weight_int4pack_mm`. The weight is stored as + a 4-d packed tensor with specific packing format for efficient computation on tensor cores. + The packing format optimizes for tensor core matrix multiplication performance. + """ + + tensor_data_names = ["qdata", "scale_and_zero"] + tensor_attribute_names = ["block_size", "shape"] + + def __new__( + cls, + qdata: torch.Tensor, + scale_and_zero: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = torch.bfloat16 # This tensor subclass only supports bfloat16 + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale_and_zero: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + self.qdata = qdata + self.scale_and_zero = scale_and_zero + self.block_size = block_size + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" + + @classmethod + def from_hp( + cls, + hp_tensor: torch.Tensor, + block_size: List[int], + int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = Int4ChooseQParamsAlgorithm.TINYGEMM, + ): + assert len(block_size) == hp_tensor.ndim, ( + f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {hp_tensor.ndim=}" + ) + + assert all(x == 1 for x in block_size[:-1]), ( + f"Only per group quantization is supported, got block_size: {block_size}" + ) + + assert hp_tensor.dtype == torch.bfloat16, ( + f"Only bfloat16 is supported for Int4TilePackedTo4dTensor, got {hp_tensor.dtype}" + ) + + original_shape = hp_tensor.shape + # use a fixed inner_k_tiles value to simplify the argument list and config + # for Int4TilePackedTo4dTensor + inner_k_tiles = 8 + + # Validate kernel requirements + orig_out_features, orig_in_features = hp_tensor.shape[-2:] + # TODO: relax checks to enable quantizing in other platoforms and run in A100 + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Cannot use tinygemm int4 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for tensor core kernels." + ) + + # Pre-process: pad to required dimensions + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + hp_tensor_padded = torch.nn.functional.pad( + hp_tensor, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + + # Quantize + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + + # we support two paths for constructing a Int4TilePackedTo4dTensor + # 1. use [hqq](https://mobiusml.github.io/hqq_blog/) algorithm to compute + # scale and zero_point, then convert to the format that's compatible with tinygemm kernels + # 2. don't use hqq, use default tinygemm algorithm to compute scale and zero_point + # + # both approach should have the same speed since both are using tinygemm kernel for gemm + # 1. typically will have higher accuracy compared to 2. + if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ: + nbits = int(math.log2(quant_max + 1)) + axis = 1 + group_size = block_size[-1] + compute_dtype = hp_tensor_padded.dtype + device = hp_tensor_padded.device + int_data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq( + hp_tensor_padded, + nbits=nbits, + group_size=group_size, + axis=axis, + compute_dtype=compute_dtype, + device=device, + verbose=False, + raw_output=False, + # raw_output=False is basically the 'convert to tinygemm zero_point version' option (add scale*midpoint) that's used in TilePackedTo4d + # note _choose_qparams_affine_tinygemm does this same thing + ) + int_data = int_data.to(target_dtype) + else: + assert ( + int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM + ), ( + f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}" + ) + # Calculate scale and zero_point for tinygemm + scale, zero_point = _choose_qparams_affine_tinygemm( + hp_tensor_padded, + mapping_type=MappingType.ASYMMETRIC, + block_size=tuple(block_size), + target_dtype=target_dtype, + quant_min=quant_min, + quant_max=quant_max, + scale_dtype=hp_tensor.dtype, + zero_point_dtype=hp_tensor.dtype, + ) + + # Quantize for tinygemm + int_data = _quantize_affine_tinygemm( + hp_tensor_padded, + block_size, + scale, + zero_point, + target_dtype, + quant_min=quant_min, + quant_max=quant_max, + ) + + # Convert to packed format + def quant_2d(int_data_2d): + int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to( + torch.uint8 + ) + return torch.ops.aten._convert_weight_to_int4pack( + int_data_2d.contiguous(), inner_k_tiles + ) + + if int_data.dim() == 3: # for moe quant + num_experts = int_data.shape[0] + packed_weight_list = [] + for expert in range(num_experts): + packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0)) + packed_weight = torch.cat(packed_weight_list, dim=0) + scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1) + if zero_point is not None + else None + ) + else: + assert int_data.dim() == 2 + packed_weight = quant_2d(int_data) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = ( + zero_point.reshape(int_data.shape[0], -1) + if zero_point is not None + else None + ) + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) + + return cls( + qdata=packed_weight, + scale_and_zero=scale_and_zero, + block_size=block_size, + shape=original_shape, + ) + + +implements = Int4TilePackedTo4dTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous" + assert weight_tensor.scale_and_zero.is_contiguous(), ( + "Expected scale_and_zero to be contiguous" + ) + + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + # weight is packed from padded (out_features, in_features) weight tensor + # (same dimension requirement as F.linear weight) + packed_weight = weight_tensor.qdata + scale_and_zero = weight_tensor.scale_and_zero + original_shape = weight_tensor.shape + + orig_act_size = input_tensor.size() + orig_dtype = input_tensor.dtype + + # Folds batch dimension into the first dimension + act_mat = input_tensor.reshape(-1, input_tensor.shape[-1]).to(torch.bfloat16) + pad_size = find_multiple(act_mat.shape[-1], 1024) + act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[-1] + if act_mat.numel() == 0: # handling for empty input + y = act_mat + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat, packed_weight, groupsize, scale_and_zero + ) + # remove out_feature padding + orig_out_features = original_shape[-2] + y = y[:, :orig_out_features] + + # Unfold the batch dimension + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias.to(y.dtype) + return y.to(orig_dtype) + + +@implements(aten.slice.Tensor) +def _(func, _types, args, _kwargs): + """Slice operation for tensor core tiled packed tensor""" + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + cur_shape = self.shape + + assert len(cur_shape) == 2 + assert self.qdata.dim() == 4 + # qdata has shape [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2] + n_by_8, k_by_inner_tiles, _, _ = self.qdata.shape + sz_dim1, sz_dim0, _ = self.scale_and_zero.shape + + data_len = cur_shape[dim] + assert dim in [ + 0, + 1, + ], ( + f"Int4TilePackedTo4dTensor slice: attempting to run {func}, with dim={dim}, that is not supported" + ) + + if dim == 0: + pw_len = n_by_8 + sz_len = sz_dim0 + else: + pw_len = k_by_inner_tiles + sz_len = sz_dim1 + + if pw_len == 0 or sz_len == 0: + return Int4TilePackedTo4dTensor( + self.qdata, + self.scale_and_zero, + self.block_size, + self.shape, + ) + + pw_ratio = data_len / pw_len + start_pw = int(start / pw_ratio) + end_pw = int(end / pw_ratio) + + sz_ratio = data_len / sz_len + start_sz = int(start / sz_ratio) + end_sz = int(end / sz_ratio) + + qdata = aten.slice(self.qdata, dim, start_pw, end_pw, step) + scale_and_zero = aten.slice(self.scale_and_zero, 1 - dim, start_sz, end_sz, step) + + # Calculate new shape after slicing + new_shape = list(self.shape) + new_shape[dim] = end - start + + block_size = list(self.block_size) + block_size[dim] = min(block_size[dim], new_shape[dim]) + + return Int4TilePackedTo4dTensor( + qdata, + scale_and_zero, + block_size, + new_shape, + ) + + +Int4TilePackedTo4dTensor.__module__ = "torchao.quantization" + +# Allow a model with Int4TilePackedTo4dTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Int4TilePackedTo4dTensor]) diff --git a/torchao/quantization/quantize_/workflows/intx/__init__.py b/torchao/quantization/quantize_/workflows/intx/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py new file mode 100644 index 0000000000..2c32732b74 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py @@ -0,0 +1,369 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from typing import Optional + +import torch + +from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH +from torchao.quantization.quantize_.workflows.intx.intx_packing_format import ( + IntxPackingFormat, +) +from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import ( + IntxUnpackedToInt8Tensor, + IntxUnpackedToInt8TensorActivationQuantization, +) +from torchao.utils import ( + TorchAOBaseTensor, + torch_version_at_least, +) + +__all__ = [ + "IntxOpaqueTensor", +] + +aten = torch.ops.aten + + +def _is_kernel_library_loaded(): + loaded = False + try: + torch.ops.torchao._pack_8bit_act_4bit_weight + loaded = True + except AttributeError: + pass + return loaded + + +class IntxOpaqueTensor(TorchAOBaseTensor): + """ + intx quantization with tile packed format for CPUs + + Tensor Attributes: + packed_weights: packed bytes. Only interpretable by kernel + + Non-Tensor Attributes: + bit_width: the bit width for quantization (can be 1 - 8) + block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size) + shape: the shape of the original Tensor + dtype: dtype for activations/outputs + packed_weights_has_zeros: whether zeros are present in packed_weights + packed_weights_has_bias: whether bias is present in packed_weights + intx_packing_format: the packing format for the packed data. See :class:`~torchao.quantization.quantize_.workflows.intx.intx_packing_format.IntxPackingFormat` enum for details. + """ + + tensor_data_names = ["packed_weights"] + tensor_attribute_names = [ + "bit_width", + "block_size", + "shape", + "dtype", + "packed_weights_has_zeros", + "packed_weights_has_bias", + "intx_packing_format", + ] + + def __new__( + cls, + packed_weights, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_zeros, + packed_weights_has_bias, + intx_packing_format, + ): + kwargs = {} + kwargs["device"] = packed_weights.device + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weights, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_zeros, + packed_weights_has_bias, + intx_packing_format, + ): + super().__init__() + assert packed_weights.device == torch.device("cpu") + self.packed_weights = packed_weights + self.bit_width = bit_width + self.block_size = block_size + self.packed_weights_has_zeros = packed_weights_has_zeros + self.packed_weights_has_bias = packed_weights_has_bias + self.intx_packing_format = intx_packing_format + + def _quantization_type(self): + return f"bit_width={self.bit_width}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device} intx_packing_format={self.intx_packing_format}" + + def to(self, *args, **kwargs): + raise NotImplementedError("to() is not implemented for IntxOpaqueTensor") + + @classmethod + def from_intx_unpacked_to_int8_tensor( + cls, + tensor: IntxUnpackedToInt8Tensor, + *, + bias: Optional[torch.Tensor] = None, + intx_packing_format: IntxPackingFormat = IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + ): + """ + Constructs a IntxOpaqueTensor from an IntxUnpackedToInt8Tensor. + If bias is passed, bias is packed into the tensor. + The intx_packing_format indicates how the data is packed. + """ + if isinstance(intx_packing_format, str): + intx_packing_format = IntxPackingFormat[intx_packing_format.upper()] + + assert intx_packing_format in [ + IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI, + IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI, + IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT, + ] + + # Extract data from IntxUnpackedToInt8Tensor + assert ( + tensor.activation_quantization + == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN + ) + qdata, scale, zero_point = tensor.qdata, tensor.scale, tensor.zero_point + bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype] + dtype = tensor.dtype + shape = tensor.shape + + block_size = tensor.block_size + assert len(block_size) == 2, "only 2D block_size is supported" + assert block_size[0] == 1, ( + "only per group or per channel quantization is supported" + ) + group_size = block_size[1] + is_per_channel = group_size == shape[1] + + packed_weights_has_bias = bias is not None + packed_weights_has_zeros = not torch.all(zero_point == 0.0).item() + + assert scale.dtype in [torch.bfloat16, torch.float32] + scale_is_bfloat16_or_is_rounded_to_bf16 = ( + scale.dtype == torch.bfloat16 + ) or torch.allclose(scale, scale.to(torch.bfloat16).to(torch.float32)) + + # Handle ATEN + if intx_packing_format == IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI: + assert torch_version_at_least("2.6.0"), ( + "ATEN target requires torch version > 2.6.0" + ) + assert torch.backends.kleidiai.is_available(), ( + "ATEN target requires torch.backends.kleidiai.is_available()" + ) + assert bit_width == 4, "ATEN target only supports 4-bit" + assert not packed_weights_has_zeros, "ATEN target does not support zeros" + qdata = qdata.add(8) + qdata = (qdata[::, 1::2] << 4 | qdata[::, ::2]).to(torch.uint8) + + # If per-group, convert scales to bfloat16 to call optimized kernel + if not is_per_channel: + if not scale_is_bfloat16_or_is_rounded_to_bf16: + logging.info( + f"scale has dtype {scale.dtype}, converting to torch.bfloat16" + ) + scale = scale.to(torch.bfloat16) + + packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight( + qdata, scale, bias, group_size, shape[1], shape[0] + ) + return cls( + packed_weight, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_zeros, + packed_weights_has_bias, + intx_packing_format, + ) + + # Handle TORCHAO + assert _is_kernel_library_loaded(), "TorchAO kernel library is not loaded" + packing_format_map = { + IntxPackingFormat.OPAQUE_TORCHAO_AUTO: None, + IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI: "kleidiai", + IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT: "universal", + } + assert intx_packing_format in packing_format_map, ( + f"intx_packing_format {intx_packing_format} not supported" + ) + + if not scale_is_bfloat16_or_is_rounded_to_bf16 and intx_packing_format in [ + IntxPackingFormat.OPAQUE_TORCHAO_AUTO, + IntxPackingFormat.OPAQUE_TORCHAO_KLEIDIAI, + ]: + logging.info("scale may be rounded to bf16 in the kernel") + if scale.dtype != torch.float32: + logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32") + scale = scale.to(torch.float32) + if bias is not None and bias.dtype != torch.float32: + logging.info(f"bias has dtype {bias.dtype}, converting to torch.float32") + bias = bias.to(torch.float32) + if packed_weights_has_zeros and not tensor._has_float_zero_point(): + zero_point = zero_point.to(torch.int8) + + packed_weights = getattr( + torch.ops.torchao, + f"_pack_8bit_act_{bit_width}bit_weight", + )( + qdata, + scale.reshape(-1), + zero_point.reshape(-1) if packed_weights_has_zeros else None, + group_size, + bias, + packing_format_map[intx_packing_format], + ) + return cls( + packed_weights, + bit_width, + block_size, + shape, + dtype, + packed_weights_has_zeros, + packed_weights_has_bias, + intx_packing_format, + ) + + +implements = IntxOpaqueTensor.implements + + +def _linear_impl_2d_aten(input_tensor, weight_tensor): + assert isinstance(weight_tensor, IntxOpaqueTensor) + assert weight_tensor.intx_packing_format == IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + assert weight_tensor.block_size[0] == 1 + assert weight_tensor.bit_width == 4 + group_size = weight_tensor.block_size[1] + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + + packed_weights = weight_tensor.packed_weights + + return torch.ops.aten._dyn_quant_matmul_4bit( + input_tensor, packed_weights, group_size, k, n + ) + + +def _linear_impl_2d_torchao(input_tensor, weight_tensor): + assert weight_tensor.intx_packing_format != IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + assert weight_tensor.block_size[0] == 1 + group_size = weight_tensor.block_size[1] + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + + packed_weights = weight_tensor.packed_weights + bit_width = weight_tensor.bit_width + + if weight_tensor.dtype != torch.float32: + input_tensor = input_tensor.to(torch.float32) + res = getattr(torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit_weight")( + input_tensor, + packed_weights, + group_size, + n, + k, + ) + if weight_tensor.dtype != torch.float32: + res = res.to(weight_tensor.dtype) + + return res + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + if weight_tensor.intx_packing_format == IntxPackingFormat.OPAQUE_ATEN_KLEIDIAI: + _impl_2d = _linear_impl_2d_aten + else: + _impl_2d = _linear_impl_2d_torchao + + # TODO: why was this added https://github.com/pytorch/ao/pull/2043 + if input_tensor.numel() == 0: + return input_tensor + + if input_tensor.dim() == 1: + k = input_tensor.shape[0] + input_tensor = input_tensor.reshape(1, k) + res = _impl_2d(input_tensor, weight_tensor) + res = res.reshape(-1) + elif input_tensor.dim() == 2: + res = _impl_2d(input_tensor, weight_tensor) + else: + assert input_tensor.dim() >= 3 + lead_shape = input_tensor.shape[0:-2] + m, k = input_tensor.shape[-2], input_tensor.shape[-1] + n, k_ = weight_tensor.shape + assert k_ == k + res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) + res = res.reshape(*lead_shape, m, n) + + if bias is not None: + assert not weight_tensor.packed_weights_has_bias + res = res + bias + + return res + + +@implements([torch.nn.functional.embedding, aten.embedding.default]) +def _(func, types, args, kwargs): + assert len(args) == 2 + indices, weight_tensor = ( + args[0], + args[1], + ) + assert isinstance(weight_tensor, IntxOpaqueTensor) + assert weight_tensor.intx_packing_format == IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT + packed_weights = weight_tensor.packed_weights + + assert len(weight_tensor.block_size) == 2 + assert weight_tensor.block_size[0] == 1 + group_size = weight_tensor.block_size[1] + + n, k = weight_tensor.shape + bit_width = weight_tensor.bit_width + + shape = indices.shape + out = getattr(torch.ops.torchao, f"_shared_embedding_{bit_width}bit")( + packed_weights, + group_size, + n, + k, + indices.reshape(-1), + ).reshape(*shape, -1) + return out + + +IntxOpaqueTensor.__module__ = "torchao.quantization" + +torch.serialization.add_safe_globals([IntxOpaqueTensor]) diff --git a/torchao/quantization/quantize_/workflows/intx/intx_packing_format.py b/torchao/quantization/quantize_/workflows/intx/intx_packing_format.py new file mode 100644 index 0000000000..bb16663c54 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/intx/intx_packing_format.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + +import torch + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class IntxPackingFormat(str, Enum): + """Packing format for quantized data in Tensor subclasses in torchao, represents how + the values are packed and laid out in the quantized data. + """ + + """ + Unpacked to int8 means the subbyte quantized data is stored as int8 + """ + UNPACKED_TO_INT8 = "unpacked_to_int8" + + """ + Opaque packing formats are used for tensors that does not have a predefined packing format + (that may be decided on hardware, tensor shape, library availability etc.) and it's not + needed for the rest of the system to understand the specific format that's adopted. + """ + + """ + This packs the tensor for PyTorch CPU kernels in ATen. + It does not require installing torchao C++ kernels. + """ + OPAQUE_ATEN_KLEIDIAI = "opaque_aten_kleidiai" + + """ + This packs the tensor for TorchAO CPU kernels by selecting the best available kernel + based on the quantization scheme, either using KlediAI kernels or lowbit kernels. + It requires TorchAO C++ kernels to be installed. + """ + OPAQUE_TORCHAO_AUTO = "opaque_torchao_auto" + + """ + This packs the tensor for TorchAO CPU kernels using KlediAI kernels. + It requires TorchAO C++ kernels to be installed. + """ + OPAQUE_TORCHAO_KLEIDIAI = "opaque_torchao_kleidiai" + + """ + This packs the tensor for TorchAO CPU kernels using lowbit kernels. + It requires TorchAO C++ kernels to be installed. + """ + OPAQUE_TORCHAO_LOWBIT = "opaque_torchao_lowbit" + + +torch.serialization.add_safe_globals([IntxPackingFormat]) diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py new file mode 100644 index 0000000000..87402241dd --- /dev/null +++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py @@ -0,0 +1,381 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +import enum +from typing import List, Optional, Tuple + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.quantization.quant_primitives import ( + _DTYPE_TO_QVALUE_BOUNDS, + MappingType, + choose_qparams_affine, + dequantize_affine, + quantize_affine, +) +from torchao.quantization.utils import _get_per_token_block_size +from torchao.utils import ( + TorchAOBaseTensor, + fill_defaults, +) + +__all__ = [ + "IntxUnpackedToInt8Tensor", +] + +aten = torch.ops.aten + +_FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32] + + +class IntxUnpackedToInt8TensorActivationQuantization(str, enum.Enum): + """ + This applies int8 asymmetric activation quantization per token. + """ + + INT8_ASYM_PER_TOKEN = "int8_asym_per_token" + + +class IntxUnpackedToInt8Tensor(TorchAOBaseTensor): + """ + intx quantization with unpacked format. Subbyte quantized data is represented as int8. + The range of the quantized values are restricted to the quant_min and quant_max of the target_dtype, e.g., + if target_dtype=torch.int4, qdata will be an int8 tensor with values in [-8, 7]. + Quantization is represented in a decomposed way. + This format is inteded for torch.export use cases. + + Tensor Attributes: + qdata: int data for quantization. + dtype is int8, but the range of the qdata is determined by target_dtype + Shape is the same as original Tensor: (n, k) for 2D tensor + scale: block scales for quantization + dtype is the same as the original Tensor dtype. + Shape is (n // block_size[0], k // block_size[1]) for 2D tensor + zero_point: block zero points for quantization + dtype is the same as the original Tensor dtype or int8 + Shape is (n // block_size[0], k // block_size[1]) for 2D tensor + + Non-Tensor Attributes: + target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8) + block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size) + dtype: the dtype of the dequantized Tensor + activation_quantization: Optional[IntxUnpackedToInt8TensorActivationQuantization] = None, kind of activation quantization to apply. Default is None, which means weight-only quantization + """ + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = [ + "target_dtype", + "block_size", + "dtype", + "activation_quantization", + ] + + def __new__( + cls, + qdata, + scale, + zero_point, + target_dtype, + block_size, + dtype, + activation_quantization, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + shape = qdata.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata, + scale, + zero_point, + target_dtype, + block_size, + dtype, + activation_quantization, + ): + super().__init__() + assert qdata.dtype == torch.int8, ( + f"qdata dtype must be int8, but got {qdata.dtype}" + ) + assert scale.dtype in _FLOAT_TYPES, ( + f"scale dtype must be one of {_FLOAT_TYPES}, but got {scale.dtype}" + ) + assert zero_point.dtype in _FLOAT_TYPES or zero_point.dtype == torch.int8, ( + f"zero_point dtype must be {torch.int8} or one of {_FLOAT_TYPES}, but got {zero_point.dtype}" + ) + + assert target_dtype in [ + getattr(torch, f"int{bit_width}") for bit_width in range(1, 9) + ] + + assert len(block_size) == qdata.ndim + n_blocks = [] + for i in range(len(block_size)): + assert qdata.shape[i] % block_size[i] == 0 + n_blocks.append(qdata.shape[i] // block_size[i]) + + # Assert shapes + assert scale.shape == tuple(n_blocks), ( + f"Expected scale to have shape {n_blocks} (inferred from block_size={block_size}), but got {scale.shape}" + ) + assert zero_point.shape == tuple(n_blocks), ( + f"Expected zero_point to have shape {n_blocks} (inferred from block_size={block_size}), but got {zero_point.shape}" + ) + + assert dtype in _FLOAT_TYPES, ( + f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}" + ) + + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + + self.target_dtype = target_dtype + self.block_size = block_size + self.activation_quantization = activation_quantization + + def _quantization_type(self): + return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}, activation_quantization={self.activation_quantization}" + + def _has_float_zero_point(self) -> bool: + return self.zero_point.dtype in _FLOAT_TYPES + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + dtype = kwargs.pop("dtype") + assert dtype in _FLOAT_TYPES + return IntxUnpackedToInt8Tensor( + self.qdata.to(device), + self.scale.to(device=device, dtype=dtype), + self.zero_point.to(device=device, dtype=dtype) + if self._has_float_zero_point() + else self.zero_point.to(device), + self.target_dtype, + self.block_size, + dtype, + self.activation_quantization, + ) + + @classmethod + def from_hp( + cls, + hp_tensor: torch.Tensor, + block_size: Tuple[int], + target_dtype: torch.dtype, + *, + mapping_type: MappingType = MappingType.SYMMETRIC, + activation_quantization: Optional[ + IntxUnpackedToInt8TensorActivationQuantization + ] = None, + custom_scale: Optional[torch.Tensor] = None, + custom_zero_point: Optional[torch.Tensor] = None, + ): + """ + Create an IntxUnpackedToInt8Tensor from a high-precision tensor + """ + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype] + if custom_scale is not None and custom_zero_point is not None: + scale, zero_point = custom_scale, custom_zero_point + elif custom_scale is None and custom_zero_point is None: + scale, zero_point = choose_qparams_affine( + hp_tensor, + mapping_type, + block_size, + target_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + zero_point_dtype=torch.int8, + ) + else: + raise ValueError( + "`custom_scale` and `custom_zero_point` must be both defined or both None" + ) + qdata = quantize_affine( + hp_tensor, + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + + # Reshape scale and zero_point to be compatible with block_size + # This is asserted in IntxUnpackedToInt8Tensor's __init__ + n_blocks = [] + for i in range(len(block_size)): + assert qdata.shape[i] % block_size[i] == 0 + n_blocks.append(qdata.shape[i] // block_size[i]) + scale = scale.reshape(*n_blocks) + zero_point = zero_point.reshape(*n_blocks) + + return IntxUnpackedToInt8Tensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=target_dtype, + block_size=block_size, + dtype=hp_tensor.dtype, + activation_quantization=activation_quantization, + ) + + def dequantize(self): + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.target_dtype] + return dequantize_affine( + self.qdata, + self.block_size, + self.scale, + self.zero_point, + torch.int8, + qmin, + qmax, + output_dtype=self.dtype, + ) + + +def _apply_int8_act_asym_per_token_quant_dequant(hp_tensor): + target_dtype = torch.int8 + mapping_type = MappingType.ASYMMETRIC + block_size = _get_per_token_block_size(hp_tensor) + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype] + scale, zero_point = choose_qparams_affine( + hp_tensor, + mapping_type, + block_size, + target_dtype=target_dtype, + quant_min=qmin, + quant_max=qmax, + zero_point_dtype=torch.int8, + ) + qdata = quantize_affine( + hp_tensor, + block_size, + scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + dequantized_affine = dequantize_affine( + qdata, + block_size, + scale, + zero_point, + torch.int8, + qmin, + qmax, + output_dtype=hp_tensor.dtype, + ) + return dequantized_affine + + +implements = IntxUnpackedToInt8Tensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert isinstance(weight_tensor, IntxUnpackedToInt8Tensor) + + # Apply dynamic activation quant + if weight_tensor.activation_quantization is not None: + if ( + weight_tensor.activation_quantization + == IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN + ): + input_tensor = _apply_int8_act_asym_per_token_quant_dequant(input_tensor) + else: + raise NotImplementedError( + f"Unsupported activation quantization: {weight_tensor.activation_quantization}" + ) + + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements([torch.nn.functional.embedding, aten.embedding.default]) +def _(func, types, args, kwargs): + assert len(args) == 2 + indices, weight_tensor = ( + args[0], + args[1], + ) + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.embedding(indices, weight_tensor, **kwargs) + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + + # Slicing must be compatible with the block size to make sense on the quantized tensor + # In particular both start and end must be a multiple of block_size[dim] + # Otherwise the sliced tensor cannot be represented as a IntxUnpackedToInt8Tensor + # For example, if block_size = 4, we might have: + # + # qdata: i i i i | i i i i + # scale: s s + # + # If we set start = 2 and end = 8, then the qdata slice is: + # + # qdata_slice: i i (i i | i i i i) + # + # But then the block_size for the first two qdata in the slice is 2 + # and remaining blocks have size 4. This cannot be represented + # with the metadata we store in an IntxUnpackedToInt8Tensor, which requires uniform blocking + + assert start % self.block_size[dim] == 0, ( + f"slice args are incompatible with blocking: start={start} must be divisible by block_size[dim]={self.block_size[dim]}" + ) + start_scale = start // self.block_size[dim] + + assert end % self.block_size[dim] == 0, ( + f"slice args are incompatible with blocking: end={end} must be divisible by block_size[dim]={self.block_size[dim]}" + ) + end_scale = end // self.block_size[dim] + + qdata = aten.slice.Tensor(self.qdata, dim, start, end, step) + scale = aten.slice.Tensor(self.scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor(self.zero_point, dim, start_scale, end_scale, step) + + new_block_size = [] + for i in range(qdata.ndim): + assert scale.shape[i] == zero_point.shape[i] + n_blocks = scale.shape[i] + assert qdata.shape[i] % n_blocks == 0 + new_block_size.append(qdata.shape[i] // n_blocks) + new_block_size = tuple(new_block_size) + + new = IntxUnpackedToInt8Tensor( + qdata, + scale, + zero_point, + self.target_dtype, + new_block_size, + self.dtype, + self.activation_quantization, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +IntxUnpackedToInt8Tensor.__module__ = "torchao.quantization" + +# Allow a model with IntxUnpackedToInt8Tensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals( + [IntxUnpackedToInt8Tensor, IntxUnpackedToInt8TensorActivationQuantization] +) diff --git a/torchao/quantization/smoothquant.py b/torchao/quantization/smoothquant.py index 972f0cc6ec..3420f3c8b2 100644 --- a/torchao/quantization/smoothquant.py +++ b/torchao/quantization/smoothquant.py @@ -16,8 +16,8 @@ import torch.nn.functional as F from .utils import ( + _quant_int8_dynamic_per_token_linear, dynamically_quantize_per_channel, - quant_int8_dynamic_per_token_linear, ) __all__ = [ @@ -152,7 +152,7 @@ def forward(self, X, *args, **kwargs): W_int_repr_t = ( self.W_int_repr if self.store_w_int_repr_t else self.W_int_repr.t() ) - Y = quant_int8_dynamic_per_token_linear( + Y = _quant_int8_dynamic_per_token_linear( X, W_int_repr_t, self.W_scales, self.bias, X.dtype ) return Y diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index be0533510f..caffef7b58 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -9,10 +9,10 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.quantization.utils import ( + _quant_int8_dynamic_per_token_linear, dequantize_per_channel, dynamically_quantize_per_channel, groupwise_affine_quantize_tensor, - quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) from torchao.utils import ( @@ -244,7 +244,7 @@ def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs): @staticmethod def _quantized_op(act_mat, w_qtensor, bias): - return quant_int8_dynamic_per_token_linear( + return _quant_int8_dynamic_per_token_linear( act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype ) diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index 339d46be35..52bc721f1f 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import functools -from typing import Callable, Dict +from typing import Callable, Dict, Type import torch from torchao.core.config import AOBaseConfig _QUANTIZE_CONFIG_HANDLER: Dict[ - AOBaseConfig, + Type[AOBaseConfig], Callable[[torch.nn.Module, AOBaseConfig], torch.nn.Module], ] = {} diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index c7dd92d55c..c54b539036 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -3,8 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import importlib.util -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch from torch.utils._python_dispatch import TorchDispatchMode @@ -26,17 +25,23 @@ quantize_affine, ) from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, check_cpu_version, check_xpu_version, ) +from .granularity import ( + Granularity, + PerAxis, + PerGroup, + PerRow, + PerTensor, + PerToken, +) + __all__ = [ "compute_error", - "_apply_logging_hook", - "quantize_activation_per_token_absmax", - "quant_int8_dynamic_per_token_linear", - "quant_int8_per_token_matmul", + "_quantize_activation_per_token_absmax", + "_quant_int8_dynamic_per_token_linear", "dynamically_quantize_per_channel", "dequantize_per_tensor", "dequantize_per_channel", @@ -52,8 +57,6 @@ "recommended_inductor_config_setter", ] -_lm_eval_available = importlib.util.find_spec("lm_eval") is not None - # basic SQNR def compute_error(x, y): @@ -133,7 +136,7 @@ def xpu(self): ] -def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None): +def _guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None): if dtype is not None and tensor_arg.dtype != dtype: raise ValueError( f"Expected Tensor argument {arg_name} to have dtype {dtype}, but got {tensor_arg.dtype} instead." @@ -155,7 +158,7 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]: # taken from # https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26 # and slightly modified -def quantize_activation_per_token_absmax(t): +def _quantize_activation_per_token_absmax(t): # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1] mapping_type = MappingType.SYMMETRIC block_size = list(t.shape) @@ -188,7 +191,7 @@ def quantize_activation_per_token_absmax(t): return quantized, scale -def quant_int8_dynamic_per_token_linear( +def _quant_int8_dynamic_per_token_linear( x, w_vals_int8_t, w_scales, @@ -199,8 +202,8 @@ def quant_int8_dynamic_per_token_linear( like F.linear, but with int8 dynamic quantization of activation, and a quantized weight """ - x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) - mm_out = quant_int8_per_token_matmul( + x_vals_int8, x_scales = _quantize_activation_per_token_absmax(x) + mm_out = _quant_int8_per_token_matmul( x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype ) if bias is not None: @@ -208,7 +211,7 @@ def quant_int8_dynamic_per_token_linear( return mm_out -def quant_int8_per_token_matmul( +def _quant_int8_per_token_matmul( x_vals_int8, x_scales, w_vals_int8_t, @@ -399,8 +402,8 @@ def get_groupwise_affine_qparams( def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16): - guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) - guard_dtype_size(zeros, "zeros", dtype=dtype) + _guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size()) + _guard_dtype_size(zeros, "zeros", dtype=dtype) dim = scales.dim() return ( torch.cat( @@ -454,7 +457,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_min, quant_max, ) - if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: + if w.shape[-1] > 1: if (not (check_cpu_version(int_data.device))) and ( not (check_xpu_version(int_data.device)) ): @@ -475,10 +478,8 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert groupsize > 1 assert w_int4x8.dim() == 2 # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path - if ( - TORCH_VERSION_AT_LEAST_2_5 - and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) - and not (check_cpu_version(w_int4x8.device)) + if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not ( + check_cpu_version(w_int4x8.device) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 @@ -686,3 +687,28 @@ def recommended_inductor_config_setter(): torch._inductor.config.fx_graph_cache = True torch._inductor.config.triton.unique_kernel_names = True torch.set_float32_matmul_precision("high") + + +def get_block_size( + input_shape: Tuple[int, ...], granularity: Granularity +) -> Tuple[int, ...]: + """Get the block size based on the input shape and granularity type. + + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + granularity: The granularity type of the quantization + """ + if isinstance(granularity, PerTensor): + return input_shape + elif isinstance(granularity, PerAxis): + block_size = list(input_shape) + block_size[granularity.axis] = 1 + return tuple(block_size) + elif isinstance(granularity, (PerRow, PerToken)): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert input_shape[-1] % granularity.group_size == 0, ( + f"Group size {granularity.group_size} does not divide input shape {input_shape}" + ) + return (1,) * (len(input_shape) - 1) + (granularity.group_size,) + raise ValueError(f"Unsupported Granularity: {granularity}") diff --git a/torchao/quantization/weight_tensor_linear_activation_quantization.py b/torchao/quantization/weight_tensor_linear_activation_quantization.py index 6612213bc1..c0b0a893e4 100644 --- a/torchao/quantization/weight_tensor_linear_activation_quantization.py +++ b/torchao/quantization/weight_tensor_linear_activation_quantization.py @@ -8,10 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TorchAOBaseTensor __all__ = [ "WeightTensorWithLinearActivationQuantizationMetadata", @@ -201,8 +198,7 @@ def _(func, types, args, kwargs): WeightTensorWithLinearActivationQuantizationMetadata.from_float ) -if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals( - [WeightTensorWithLinearActivationQuantizationMetadata] - ) +# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals( + [WeightTensorWithLinearActivationQuantizationMetadata] +) diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index 6971bcc84b..2c62c2738a 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -53,11 +53,10 @@ Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regress ```py from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig -from torchao.dtypes import MarlinSparseLayout # Your FP16 model model = model.cuda().half() -quantize_(model, Int4WeightOnlyConfig(layout=MarlinSparseLayout())) +quantize_(model, Int4WeightOnlyConfig(int4_packing_format="marlin_sparse")) ``` Note the existing API results in an extremely high accuracy degredation and is intended to be used in concert with an already sparsified+finetuned checkpoint where possible until we develop diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index b263b5e098..9214f8b1ef 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -50,6 +50,9 @@ def apply_fake_sparsity(model, **kwargs): class BlockSparseWeightConfig(AOBaseConfig): blocksize: int = 64 + def __post_init__(self): + torch._C._log_api_usage_once("torchao.sparsity.BlockSparseWeightConfig") + # for bc block_sparse_weight = BlockSparseWeightConfig @@ -72,7 +75,8 @@ class SemiSparseWeightConfig(AOBaseConfig): Configuration for converting the weight of linear modules to semi-structured (2:4) sparsity """ - pass + def __post_init__(self): + torch._C._log_api_usage_once("torchao.sparsity.SemiSparseWeightConfig") # for bc @@ -125,8 +129,9 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: # for int8 dynamic quantization + 2:4 sparsity from torchao.dtypes import SemiSparseLayout - m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) + m = quantize_(m, Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout), filter_fn) """ + torch._C._log_api_usage_once("torchao.sparsity.sparsify_") handler = _QUANTIZE_CONFIG_HANDLER[type(config)] _replace_with_custom_fn_if_matches_filter( model, diff --git a/torchao/sparsity/training/__init__.py b/torchao/sparsity/training/__init__.py index 3c4212101b..87ce3add4f 100644 --- a/torchao/sparsity/training/__init__.py +++ b/torchao/sparsity/training/__init__.py @@ -4,17 +4,15 @@ # LICENSE file in the root directory of this source tree. import torch +# load pointwise op support, which exists only for CUTLASS +from torch.sparse import SparseSemiStructuredTensorCUTLASS + from torchao.sparsity.training.autograd import semi_structured_sparsify from torchao.sparsity.training.pointwise_ops import CUTLASS_POINTWISE_OP_DISPATCH_TABLE -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - -# load pointwise op support, which exists only for CUTLASS -if TORCH_VERSION_AT_LEAST_2_3: - from torch.sparse import SparseSemiStructuredTensorCUTLASS - SparseSemiStructuredTensorCUTLASS._load_dispatch_table( - CUTLASS_POINTWISE_OP_DISPATCH_TABLE - ) +SparseSemiStructuredTensorCUTLASS._load_dispatch_table( + CUTLASS_POINTWISE_OP_DISPATCH_TABLE +) __all__ = [ "SemiSparseLinear", diff --git a/torchao/sparsity/training/autograd.py b/torchao/sparsity/training/autograd.py index fafbd7c3c3..40c6c98083 100644 --- a/torchao/sparsity/training/autograd.py +++ b/torchao/sparsity/training/autograd.py @@ -6,18 +6,14 @@ from enum import Enum import torch -from torch.sparse import SparseSemiStructuredTensor - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - -if TORCH_VERSION_AT_LEAST_2_3: - from torch.sparse import ( - SparseSemiStructuredTensorCUSPARSELT, - SparseSemiStructuredTensorCUTLASS, - ) - - torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT) - torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS) +from torch.sparse import ( + SparseSemiStructuredTensor, + SparseSemiStructuredTensorCUSPARSELT, + SparseSemiStructuredTensorCUTLASS, +) + +torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT) +torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS) GRADIENT_TYPE = Enum("GRADIENT_TYPE", ["DENSE", "SPARSE", "STE"]) diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index 24c0808a02..916fff6cd4 100644 --- a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -80,7 +80,7 @@ def forward(self, x_orig): new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0] y = x.permute(new_axis_list) y = torch.flatten(y, start_dim=1) - norm = torch.norm(y, dim=1) ** 2 + norm = torch.linalg.vector_norm(y, dim=1) ** 2 if self.norm.numel() == 0: self.norm.resize_(norm.shape) diff --git a/torchao/sparsity/wanda.py b/torchao/sparsity/wanda.py index 7ad12a2d55..1f430c2ba8 100644 --- a/torchao/sparsity/wanda.py +++ b/torchao/sparsity/wanda.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import warnings -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch from torch import nn @@ -48,8 +48,7 @@ def __init__( ) super().__init__(defaults=defaults) - # `typing.Dict[, ]` to avoid runtime subscripting errors. - def prepare(self, model: nn.Module, config: List[Dict]) -> None: + def prepare(self, model: nn.Module, config: list[dict]) -> None: # activation: use PerChannelNormObserver # use no-op placeholder weight observer if config is None: @@ -88,35 +87,38 @@ def update_mask( # type: ignore[override] by comparing this metric across the whole current layer. """ - # Step 1: get the tensor and the mask from the parametrizations + # Step 1: get the attributes (tensor and mask) from the parametrizations mask = getattr(module.parametrizations, tensor_name)[0].mask tensor = getattr(module.parametrizations, tensor_name).original activation_norm_per_channel = module.activation_post_process.norm - # Step 2: Calculate Wx + # Step 2: Calculate pruning criteria : '|weight| * ||activation||' pruning_metric = torch.abs(tensor) * activation_norm_per_channel - # defaults for unstructured sparsity + # Step 3 : Calculate the number of elements (weight params) block_size = pruning_metric.numel() + + # Step 4 : Define pruning boundary : N(elements) * (pruning ratio) num_specified = int(block_size * sparsity_level) - # if set to use semi-structured, ignore sparsity_level + # if set to use semi-structured, ignore sparsity_level and apply 2:4 sparsity if kwargs.get("semi_structured_block_size", None) is not None: block_size = kwargs["semi_structured_block_size"] num_specified = block_size // 2 - # get indicies to prune + # Step 5 : Flatten it for sorting and prune weights pruning_inds = pruning_metric.view(-1, block_size).argsort(dim=1)[ :, :num_specified ] - # update mask + + # Step 6 : Reshape and prune weights mask.data.view(-1, block_size).scatter_( 1, pruning_inds, torch.zeros_like(pruning_inds, dtype=mask.dtype) ) def squash_mask( self, - params_to_keep: Optional[Tuple[str, ...]] = None, - params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None, + params_to_keep: Optional[tuple[str, ...]] = None, + params_to_keep_per_layer: Optional[dict[str, tuple[str, ...]]] = None, *args, **kwargs, ): diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py index f59a1271b1..8f41a8464c 100644 --- a/torchao/testing/model_architectures.py +++ b/torchao/testing/model_architectures.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F # TODO: Refactor torchao and tests to use these models @@ -21,6 +22,27 @@ def forward(self, x): return x +class ConvWithSharedWeightInExportedModel(nn.Module): + def __init__( + self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) -> None: + super().__init__() + self.n_chunks = n_chunks + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x) -> torch.Tensor: + chunks = torch.chunk(x, self.n_chunks, dim=1) + outputs = [] + for chunk in chunks: + out = self.conv(chunk) + out = self.bn(out) + out = self.relu(out) + outputs.append(out) + return torch.cat(outputs, dim=1) + + class LNLinearActivationModel(nn.Module): def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16, activation="sigmoid"): super().__init__() @@ -177,3 +199,64 @@ def create_model_and_input_data( else: raise ValueError(f"Unknown model type: {model_type}") return model, input_data + + +# from https://github.com/meta-llama/llama-models/blob/a9c89c471f793423afd4cc3ca8671d6e56fe64cb/models/llama4/moe.py#L22 +class LlamaModelsLlama4Experts(nn.Module): + def __init__( + self, + num_local_experts: int, + dim: int, + hidden_dim: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + + self.num_local_experts = num_local_experts + self.dim = dim + + self.w1: nn.Parameter = nn.Parameter( + torch.randn( + num_local_experts, + dim, + hidden_dim, + dtype=dtype, + device=device, + ) + ) + + self.w2: nn.Parameter = nn.Parameter( + torch.randn( + num_local_experts, + hidden_dim, + dim, + dtype=dtype, + device=device, + ) + ) + + self.w3: nn.Parameter = nn.Parameter( + torch.randn( + num_local_experts, + dim, + hidden_dim, + dtype=dtype, + device=device, + ) + ) + + def forward( + self, + routed_in_egD: torch.Tensor, # noqa: N803 + ) -> torch.Tensor: + e = self.num_local_experts + D = self.dim + + x_egD = routed_in_egD.view(e, -1, D) + + middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3) + out_egD = torch.bmm(middle_out_egF, self.w2) + out_egD = out_egD.view(-1, D) + + return out_egD diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py index 5d903a4a15..f031386012 100644 --- a/torchao/testing/pt2e/utils.py +++ b/torchao/testing/pt2e/utils.py @@ -6,7 +6,6 @@ import copy import unittest -from typing import Dict import torch from torch.ao.quantization.backend_config import ( @@ -23,23 +22,16 @@ from torch.testing._internal.common_utils import TestCase from torchao.quantization.pt2e import FROM_NODE_KEY -from torchao.quantization.pt2e._numeric_debugger import _generate_debug_handle_from_node +from torchao.quantization.pt2e._numeric_debugger import _extract_node_source_debug_info from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 +from torchao.utils import torch_version_at_least -if TORCH_VERSION_AT_LEAST_2_5: - from torch.export import export_for_training - -@unittest.skipIf( - not TORCH_VERSION_AT_LEAST_2_5, - "only works for torch 2.5+ since export_for_training is only supported after 2.5", -) class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. @@ -79,7 +71,7 @@ def _test_quantizer( {0: torch.export.Dim("dim")} if i == 0 else None for i in range(len(example_inputs)) ) - m = export_for_training( + m = torch.export.export( m, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, @@ -120,7 +112,7 @@ def _test_quantizer( m_fx = _convert_to_reference_decomposed_fx( m_fx, backend_config=backend_config ) - m_fx = export_for_training( + m_fx = torch.export.export( m_fx, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, @@ -140,55 +132,68 @@ def _test_quantizer( return m -@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") +@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+") class PT2ENumericDebuggerTestCase(TestCase): """ Base test case class for PT2E numeric debugger tests containing common utility functions for numeric debugging functionality. """ - def _assert_each_node_has_debug_handle(self, model) -> None: - """Assert that each node in the model has a debug handle.""" - - def _assert_node_has_debug_handle(node): + def _assert_each_node_has_from_node_source(self, model) -> None: + def _assert_node_has_from_node_source(node): + if node.op == "placeholder" or node.op == "output": + return self.assertIn( FROM_NODE_KEY, node.meta, f"Node {node} doesn't have from_node info", ) - bfs_trace_with_node_process(model, _assert_node_has_debug_handle) + bfs_trace_with_node_process(model, _assert_node_has_from_node_source) - def _extract_debug_handles(self, model) -> Dict[str, int]: - """Extract debug handles from all nodes in the model.""" - debug_handle_map: Dict[str, int] = {} + def _extract_from_node_source(self, model) -> dict[str, any]: + from_node_source_map: dict[str, any] = {} - def _extract_debug_handles_from_node(node): - nonlocal debug_handle_map - if (dh := _generate_debug_handle_from_node(node)) is not None: - debug_handle_map[str(node)] = dh + def _extract_from_node_source_from_node(node): + nonlocal from_node_source_map + if (root_node_source := _extract_node_source_debug_info(node)) is not None: + from_node_source_map[str(node)] = ( + root_node_source.name, + root_node_source.graph_id, + ) - bfs_trace_with_node_process(model, _extract_debug_handles_from_node) - return debug_handle_map + bfs_trace_with_node_process(model, _extract_from_node_source_from_node) - def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]: - prev_decomp_op_to_debug_handle_map: dict[str, int] = {} + return from_node_source_map - def _extract_debug_handles_with_prev_decomp_op_from_node(node): - nonlocal prev_decomp_op_to_debug_handle_map - if FROM_NODE_KEY in node.meta: + def _extract_from_node_source_with_prev_decomp_op(self, model) -> dict[str, any]: + prev_decomp_op_to_from_node_source_map: dict[str, any] = {} + + def _extract_from_node_source_with_prev_decomp_op_from_node(node): + nonlocal prev_decomp_op_to_from_node_source_map + if FROM_NODE_KEY in node.meta and node.meta[FROM_NODE_KEY] is not None: prev_decomp_op = str(node.meta.get("nn_module_stack")) - debug_handle = _generate_debug_handle_from_node(node) - if prev_decomp_op not in prev_decomp_op_to_debug_handle_map: - prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle + from_node_source = _extract_node_source_debug_info(node) + if prev_decomp_op not in prev_decomp_op_to_from_node_source_map: + prev_decomp_op_to_from_node_source_map[prev_decomp_op] = ( + from_node_source + ) else: assert ( - prev_decomp_op_to_debug_handle_map[prev_decomp_op] - == debug_handle - ), f"Node {node} has different debug handle {debug_handle}" - "than previous node sharing the same decomp op {prev_decomp_op}" + prev_decomp_op_to_from_node_source_map[prev_decomp_op] + == from_node_source + ), ( + f"Node {node} has different from_node info {from_node_source}" + f"than previous node sharing the same decomp op {prev_decomp_op}" + ) bfs_trace_with_node_process( - model, _extract_debug_handles_with_prev_decomp_op_from_node + model, _extract_from_node_source_with_prev_decomp_op_from_node + ) + return prev_decomp_op_to_from_node_source_map + + def assertNodeSourcesEqual(self, node_source_1, node_source_2): + self.assertTrue( + node_source_1.name == node_source_2.name + and node_source_1.graph_id == node_source_2.graph_id ) - return prev_decomp_op_to_debug_handle_map diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py index 7ebf67d53c..acbfbb6a3e 100644 --- a/torchao/testing/training/dtensor_utils.py +++ b/torchao/testing/training/dtensor_utils.py @@ -32,11 +32,11 @@ class FeedForward(nn.Module): """MLP based model""" - def __init__(self): + def __init__(self, size): super(FeedForward, self).__init__() - self.w1 = nn.Linear(16, 32, bias=False) - self.w2 = nn.Linear(16, 32, bias=False) - self.out_proj = nn.Linear(32, 16, bias=False) + self.w1 = nn.Linear(size, size * 2, bias=False) + self.w2 = nn.Linear(size, size * 2, bias=False) + self.out_proj = nn.Linear(size * 2, size, bias=False) def forward(self, x): x = F.silu(self.w1(x)) * self.w2(x) @@ -45,9 +45,9 @@ def forward(self, x): class ToyModel(nn.Module): - def __init__(self): + def __init__(self, size): super(ToyModel, self).__init__() - self.ffn = FeedForward() + self.ffn = FeedForward(size) def forward(self, x): return self.ffn(x) @@ -56,7 +56,7 @@ def forward(self, x): def _test_lowp_mlp_tensor_parallelism_base( mesh: DeviceMesh, config: Union[Float8LinearConfig, MXLinearConfig], - size=16, + size=32, compile: bool = False, allgather_in_lowp: bool = False, ): @@ -67,7 +67,7 @@ def _test_lowp_mlp_tensor_parallelism_base( if isinstance(config, MXLinearConfig): convert_model_func = quantize_ - toy_model = ToyModel().to(device) + toy_model = ToyModel(size).to(device) toy_model_fp8 = copy.deepcopy(toy_model) convert_model_func(toy_model_fp8, config=config) @@ -151,8 +151,8 @@ def _test_lowp_mlp_tensor_parallelism_base( sp_model = torch.compile(sp_model) sp_model2 = torch.compile(sp_model2) - x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) - go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) + x_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False) + go_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False) x_fp32_tp_input = x_fp32.clone() go_fp32_tp = go_fp32.clone() x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index 286803dbf2..6c51cef0b0 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -65,8 +65,9 @@ } -def get_specs(): - gpu_name = torch.cuda.get_device_name(0) +def get_specs(gpu_name: Optional[str] = None): + if gpu_name is None: + gpu_name = torch.cuda.get_device_name(0) return gpu_name_to_specs[gpu_name] @@ -188,6 +189,7 @@ def get_tensor_memory_traffic_ovhd_s( assert mx_recipe_name in ( "mxfp8_emulated", "mxfp8_cublas", + "mxfp8_cublas_rceil", ), "unsupported" # For now, assume that we can't profitably fuse kernel 1 and kernel 2 # x_bf16 = ... @@ -213,10 +215,15 @@ def get_tensor_memory_traffic_ovhd_s( def get_individual_gemm_time_sympy( - M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name + M: sympy.Symbol, + K: sympy.Symbol, + N: sympy.Symbol, + dtype, + mx_recipe_name, + gpu_name: Optional[str] = None, ) -> sympy.Symbol: # compute bound - specs = get_specs() + specs = get_specs(gpu_name) gemm_ops = 2 * M * K * N if dtype is torch.bfloat16: peak_tops = specs["bf16_peak_tops"] @@ -234,6 +241,7 @@ def get_individual_gemm_time_sympy( assert mx_recipe_name in ( "mxfp8_emulated", "mxfp8_cublas", + "mxfp8_cublas_rceil", ), "unsupported" assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "unsupported" # adjust reads for MX scaling @@ -263,6 +271,7 @@ def get_gemm_time_sympy( dtype, float8_recipe_name: Optional[str], mx_recipe_name: Optional[str], + gpu_name: Optional[str], ): # next: add rowwise_with_gw_hp here # note: this function is currently not super accurate for small shapes: @@ -277,13 +286,13 @@ def get_gemm_time_sympy( gemm_dtype_grad_weight = torch.bfloat16 gemm_output_time_s = get_individual_gemm_time_sympy( - M, K, N, gemm_dtype_input, mx_recipe_name + M, K, N, gemm_dtype_input, mx_recipe_name, gpu_name ) gemm_grad_input_time_s = get_individual_gemm_time_sympy( - M, N, K, gemm_dtype_grad_input, mx_recipe_name + M, N, K, gemm_dtype_grad_input, mx_recipe_name, gpu_name ) gemm_grad_weight_time_s = get_individual_gemm_time_sympy( - K, M, N, gemm_dtype_grad_weight, mx_recipe_name + K, M, N, gemm_dtype_grad_weight, mx_recipe_name, gpu_name ) total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s return total @@ -296,8 +305,9 @@ def get_float8_mem_sympy( float8_recipe_name: Optional[str], mx_recipe_name: Optional[str], enable_fusion_modeling: bool, + gpu_name: Optional[str] = None, ): - specs = get_specs() + specs = get_specs(gpu_name) # there are three gemms in the fwd/bwd of a linear: # @@ -340,3 +350,80 @@ def get_float8_mem_sympy( res = sum([*fwd_fp8_input_mem, *fwd_fp8_weight_mem, *gi_fp8_grad_output_mem]) return res + + +def get_inference_tensor_memory_traffic_ovhd_s( + specs, + dim0, + dim1, + tensor_role: str, + float8_recipe_name: Optional[str], + fuse_with_prev=False, +) -> List[Union[sympy.Symbol, float]]: + """ + Inference version of `get_tensor_memory_traffic_ovhd_s`. + The only thing happening here is we quantize the activation. + """ + assert float8_recipe_name == "rowwise", "unsupported" + assert fuse_with_prev is False, "unsupported" + + # assumes input bf16, output f8 + numel = dim0 * dim1 + + res_bytes = None + + assert tensor_role == "input" + # x_bf16 = ... + # kernel 1: x_bf16 -> x_fp8 + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [ + kernel_1_rw, + ] + + # convert from bytes to seconds + res_s = [ + x / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"] + for x in res_bytes + ] + + # take max of kernel_overhead, r/w time + res_s = [sympy.Max(x, KERNEL_LAUNCH_OVERHEAD_SEC) for x in res_s] + + return res_s + + +def get_inference_float8_mem_sympy( + M, + K, + N, + float8_recipe_name: Optional[str], + gpu_name: Optional[str] = None, +): + specs = get_specs(gpu_name) + # input @ weight_t = output + # MxK @ KxN => MxN + fwd_fp8_input_mem = get_inference_tensor_memory_traffic_ovhd_s( + specs, + M, + K, + tensor_role="input", + float8_recipe_name=float8_recipe_name, + fuse_with_prev=False, + ) + res = sum([*fwd_fp8_input_mem]) + return res + + +def get_inference_gemm_time_sympy( + M: sympy.Symbol, + K: sympy.Symbol, + N: sympy.Symbol, + dtype, + float8_recipe_name: Optional[str], + gpu_name: Optional[str], +): + assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported" + # note: this function is currently not super accurate for small shapes: + # when M,K,N <= 1k,1k,1k it undercounts by around 2x + gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name) + return gemm_output_time_s diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 26a738c53b..bb9c2ca8dc 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -17,9 +17,16 @@ import torchao from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx -from torchao.quantization import int8_weight_only, quantize_ +from torchao.quantization import Int8WeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, get_compute_capability +from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, +) +from torchao.testing.model_architectures import LlamaModelsLlama4Experts +from torchao.utils import ( + DummyModule, + get_compute_capability, +) """ How to use: @@ -324,7 +331,7 @@ class TorchAOTensorParallelTestCase(DTensorTestBase): COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor - QUANT_METHOD_FN = staticmethod(int8_weight_only) + QUANT_METHOD_FN = staticmethod(Int8WeightOnlyConfig) QUANT_METHOD_KWARGS = {} @staticmethod @@ -412,19 +419,199 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist(up_dist(input_dtensor)) - if not TORCH_VERSION_AT_LEAST_2_6: - # Need torch 2.6 to support compiled tensor parallelism - return - up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) dn_compiled(y_up) +class TorchAOIntegrationTestCase(common_utils.TestCase): + def _test_slice_and_copy_similar_to_vllm(self, config): + # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly + # the test is similar to the linked code, but with some hardcoded arguments + # and does not use tensor parallelism + + dtype = torch.bfloat16 + device = "cuda" + l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype) + quantize_(l, config) + + # high level, we do a narrow for both param.data and the loaded_weights + # and do inplace copy_ to copy from the loaded_weights into param.data + + # simulate loaded_weight + dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) + # making the weight different + dummy_l.weight = torch.nn.Parameter( + dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype), + requires_grad=False, + ) + quantize_(dummy_l, config) + + output_dim = 0 + shard_size = 512 + for tp_rank in [0, 1]: + start_idx = tp_rank * shard_size + param = l.weight + param_data = param.data + param_data = param_data.narrow(output_dim, start_idx, shard_size) + orig_value = param_data.qdata[0][0] + loaded_weight = dummy_l.weight + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0] + assert not torch.equal(orig_value, loaded_weight.qdata[0][0]) + param_data.copy_(loaded_weight) + # making sure param.data is updated to loaded_weight + assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0]) + if hasattr(param_data, "scale"): + assert torch.equal(param_data.scale, loaded_weight.scale) + if hasattr(param_data, "zero_point"): + assert torch.equal(param_data.zero_point, loaded_weight.zero_point) + if hasattr(param_data, "scale_and_zero"): + assert torch.equal( + param_data.scale_and_zero, loaded_weight.scale_and_zero + ) + + def _test_moe_weight_reshape_ops(self, config): + """This is testing the op call sequence in saving and loading quantization + checkpoints in llama-models for llama4 + (https://github.com/meta-llama/llama-models/tree/main/models/llama4) + """ + # only per row quantization is supported for bmm + dtype = torch.bfloat16 + device = "cuda" + + def _quantize_experts(model, config): + for _, module in model.named_modules(): + if not isinstance(module, LlamaModelsLlama4Experts): + continue + + expert_module = module + for weight_name in ["w1", "w2", "w3"]: + weight = getattr(expert_module, weight_name) + config_handler = _QUANTIZE_CONFIG_HANDLER[type(config)] + dummy_mod = DummyModule(weight) + quant_mod = config_handler(dummy_mod, config) + setattr(expert_module, weight_name, quant_mod.weight) + + batch_size = 4 + num_experts = 2 + input_dim = 64 + dim = 128 + hidden_dim = 256 + + moe1 = LlamaModelsLlama4Experts(num_experts, dim, hidden_dim, dtype, device) + moe2 = LlamaModelsLlama4Experts(num_experts, dim, hidden_dim, dtype, device) + moe_combined = LlamaModelsLlama4Experts( + num_experts, dim, 2 * hidden_dim, dtype, device + ) + input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device) + + moes = [moe1, moe2] + + for moe in moes: + moe(input) + + # need to transpose before quantizing + moe.w1 = torch.nn.Parameter( + moe.w1.transpose(1, 2).contiguous(), requires_grad=False + ) + moe.w2 = torch.nn.Parameter( + moe.w2.transpose(1, 2).contiguous(), requires_grad=False + ) + moe.w3 = torch.nn.Parameter( + moe.w3.transpose(1, 2).contiguous(), requires_grad=False + ) + + _quantize_experts(moe, config) + + before = moe(input) + + # transposing for resharding support since only 2D resharding is supported + new_last_dim = moe.w1.shape[-2] + moe.w1 = torch.nn.Parameter( + moe.w1.transpose(1, 2).reshape(-1, new_last_dim).contiguous(), + requires_grad=False, + ) + new_last_dim = moe.w2.shape[-2] + moe.w2 = torch.nn.Parameter( + moe.w2.transpose(1, 2).reshape(-1, new_last_dim).contiguous(), + requires_grad=False, + ) + new_last_dim = moe.w3.shape[-2] + moe.w3 = torch.nn.Parameter( + moe.w3.transpose(1, 2).reshape(-1, new_last_dim).contiguous(), + requires_grad=False, + ) + + moe.w1 = torch.nn.Parameter( + moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0), + requires_grad=False, + ) + moe.w2 = torch.nn.Parameter( + moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0), + requires_grad=False, + ) + moe.w3 = torch.nn.Parameter( + moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0), + requires_grad=False, + ) + + # transpose again to recover the original weights + moe.w1 = torch.nn.Parameter( + moe.w1.transpose(1, 2).contiguous(), requires_grad=False + ) + moe.w2 = torch.nn.Parameter( + moe.w2.transpose(1, 2).contiguous(), requires_grad=False + ) + moe.w3 = torch.nn.Parameter( + moe.w3.transpose(1, 2).contiguous(), requires_grad=False + ) + + after = moe(input) + self.assertEqual(before, after) + + state_dicts = [moe1.state_dict(), moe2.state_dict()] + # align the scale parameter so they can be concatenated + for key in ["w1", "w2", "w3"]: + weights = [st[key] for st in state_dicts] + for i in range(1, len(weights)): + weights[i].scale = weights[0].scale + if hasattr(weights[i], "zero_point"): + weights[i].zero_point = weights[0].zero_point + + def process_key(key: str) -> torch.Tensor: + tensors = [s[key] for s in state_dicts] + # Note: we have a hacky implementation for cat in user codebase + # since it is not implemented correctly before + if key == "w2": + return torch.cat(tensors, dim=-1) + else: + return torch.cat(tensors, dim=-2) + + new_state_dict = {} + for key in ["w1", "w2", "w3"]: + new_state_dict[key] = process_key(key) + + moe_combined.w1 = torch.nn.Parameter( + moe_combined.w1.transpose(1, 2), requires_grad=False + ) + moe_combined.w2 = torch.nn.Parameter( + moe_combined.w2.transpose(1, 2), requires_grad=False + ) + moe_combined.w3 = torch.nn.Parameter( + moe_combined.w3.transpose(1, 2), requires_grad=False + ) + moe_combined.load_state_dict(new_state_dict, assign=True) + # make sure it runs + moe_combined(input) + + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) common_utils.instantiate_parametrized_tests(TorchAOTensorParallelTestCase) + if __name__ == "__main__": unittest.main() diff --git a/torchao/utils.py b/torchao/utils.py index 1a12fb0668..ae72919c06 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -8,13 +8,15 @@ import itertools import re import time +import warnings from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable +from typing import Any, Callable, Optional, Type import torch import torch.nn.utils.parametrize as parametrize +from torch.utils._python_dispatch import return_and_correct_aliasing __all__ = [ "benchmark_model", @@ -27,25 +29,27 @@ "get_model_size_in_bytes", "unwrap_tensor_subclass", "TorchAOBaseTensor", + "is_MI300", + "is_sm_at_least_89", + "is_sm_at_least_90", + "is_package_at_least", + "DummyModule", + # Deprecated "TORCH_VERSION_AT_LEAST_2_2", "TORCH_VERSION_AT_LEAST_2_3", "TORCH_VERSION_AT_LEAST_2_4", "TORCH_VERSION_AT_LEAST_2_5", "TORCH_VERSION_AT_LEAST_2_6", "TORCH_VERSION_AT_LEAST_2_7", - # Needs to be deprecated in the future "TORCH_VERSION_AFTER_2_2", "TORCH_VERSION_AFTER_2_3", "TORCH_VERSION_AFTER_2_4", "TORCH_VERSION_AFTER_2_5", - "is_MI300", - "is_sm_at_least_89", - "is_sm_at_least_90", - "is_package_at_least", ] # Referenced from: https://github.com/pytorch/pytorch/blob/9105d54c6b37099575c0059ef274c86c4dc80c57/torch/ao/quantization/utils.py#L711 +@functools.cache def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: """ Returns the unique device for a module, or None if no device is found. @@ -138,9 +142,8 @@ def get_available_devices(): devices.append("cuda") elif torch.xpu.is_available(): devices.append("xpu") - if TORCH_VERSION_AT_LEAST_2_5: - if torch.mps.is_available(): - devices.append("mps") + if torch.mps.is_available(): + devices.append("mps") return devices @@ -201,7 +204,7 @@ def _the_op_that_needs_to_be_preserved(...) # after this, `_the_op_that_needs_to_be_preserved` will be preserved as # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after - # torch.export.export / torch._export.export_for_training + # torch.export.export """ from torch._inductor.decomposition import register_decomposition @@ -213,29 +216,31 @@ def _the_op_that_needs_to_be_preserved(...) ) def decorator(fn): - if TORCH_VERSION_AT_LEAST_2_5: - from torch._library.infer_schema import infer_schema + from torch._library.infer_schema import infer_schema - # expecting fn.__name__ starts with `_` and we want to take the rest - # to be the name of the custom op - assert fn.__name__[0] == "_", ( - f"Expecting function name starts with `_`, got {fn.__name__}" - ) - assert not any(c in fn.__name__ for c in ".<>"), ( - f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" - ) - op_name = fn.__name__[1:] - schema = op_name + infer_schema(fn, mutates_args={}) - lib.define(schema) - lib.impl(op_name, fn, dispatch_key) - - lib_namespace = lib.ns - op = getattr(getattr(torch.ops, lib_namespace), op_name) - if inductor_decomposed: - register_decomposition([op])(fn) - return op - else: - return fn + assert not any(c in fn.__name__ for c in ".<>"), ( + f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + ) + op_name = fn.__name__ + if op_name[0] == "_": + op_name = op_name[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, dispatch_key) + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + if inductor_decomposed: + register_decomposition([op])(fn) + return op + + return decorator + + +def _register_meta_op(lib, op_name): + def decorator(fn): + op = lib.impl(op_name, fn, "Meta") + return op return decorator @@ -344,36 +349,107 @@ def _is_float8_type(dtype: torch.dtype) -> bool: def parse_version(version_string): - # Extract just the X.Y.Z part from the version string - match = re.match(r"(\d+\.\d+\.\d+)", version_string) + """ + Parse version string representing pre-release with -1 + + Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0] + """ + # Check for pre-release indicators + is_prerelease = bool(re.search(r"(git|dev)", version_string)) + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string) if match: - version = match.group(1) - return [int(x) for x in version.split(".")] + major, minor, patch = map(int, match.groups()) + if is_prerelease: + patch = -1 + return [major, minor, patch] else: raise ValueError(f"Invalid version string format: {version_string}") -def compare_versions(v1, v2): - v1_parts = parse_version(v1) - v2_parts = parse_version(v2) - return (v1_parts > v2_parts) - (v1_parts < v2_parts) - - def is_fbcode(): return not hasattr(torch.version, "git_version") def torch_version_at_least(min_version): - return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 + if is_fbcode(): + return True + # Parser for local identifiers + return parse_version(torch.__version__) >= parse_version(min_version) -TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0") -TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0") -TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") -TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") -TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") -TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0") -TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0") + +def _deprecated_torch_version_at_least(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log + a deprecation warning if the variable is used. + """ + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper( + torch_version_at_least(version_str), + deprecation_msg, + ) + + +def _deprecated_torch_version_after(version_str: str) -> str: + """ + Wrapper for existing TORCH_VERSION_AFTER* variables that will log + a deprecation warning if the variable is used. + """ + bool_value = is_fbcode() or version("torch") >= version_str + version_str_var_name = "_".join(version_str.split(".")[:2]) + deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" + return _BoolDeprecationWrapper(bool_value, deprecation_msg) + + +class _BoolDeprecationWrapper: + """ + A deprecation wrapper that logs a warning when the given bool value is accessed. + """ + + def __init__(self, bool_value: bool, msg: str): + self.bool_value = bool_value + self.msg = msg + + def __bool__(self): + warnings.warn(self.msg) + return self.bool_value + + def __eq__(self, other): + return bool(self) == bool(other) + + +# Deprecated, use `torch_version_at_least` directly instead +TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0") +TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0") +TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") +TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") +TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0") +TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0") +TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0") +TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev") +TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev") +TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev") +TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") + + +class _ConfigDeprecationWrapper: + """ + A deprecation wrapper that directs users from a deprecated "config function" + (e.g. `int4_weight_only`) to the replacement config class. + """ + + def __init__(self, deprecated_name: str, config_cls: Type): + self.deprecated_name = deprecated_name + self.config_cls = config_cls + + def __call__(self, *args, **kwargs): + warnings.warn( + f"`{self.deprecated_name}` is deprecated and will be removed in a future release. " + f"Please use `{self.config_cls.__name__}` instead. Example usage:\n" + f" quantize_(model, {self.config_cls.__name__}(...))" + ) + return self.config_cls(*args, **kwargs) """ @@ -401,6 +477,9 @@ def _(func, types, args, kwargs): if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"): cls._ATEN_OP_OR_TORCH_FN_TABLE = {} + if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE: + cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {} + if not isinstance(aten_ops_or_torch_fns, (list, tuple)): aten_ops_or_torch_fns = [aten_ops_or_torch_fns] @@ -411,12 +490,154 @@ def decorator(func): def wrapper(f, types, args, kwargs): return func(f, types, args, kwargs) - cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper + cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper return func return decorator +def _implements_common_tensor_ops(cls): + implements = cls.implements + aten = torch.ops.aten + + @implements( + [ + torch.Tensor.contiguous, + ] + ) + def _(func, types, args, kwargs): + return args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)) + + @implements( + [ + aten.detach.default, + aten.clone.default, + aten.alias.default, + aten.contiguous.default, + ] + ) + def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + + def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: + _tensor_shape_match = all( + getattr(self, t_name).shape == getattr(src, t_name).shape + for t_name in self.tensor_data_names + ) + _optional_tensor_shape_match = True + if hasattr(self, "optional_tensor_data_names"): + # either both are None or both are not Tensors and the shape match + _optional_tensor_shape_match = all( + getattr(self, t_name).shape == getattr(src, t_name).shape + if getattr(self, t_name) is not None + else getattr(src, t_name) is None + for t_name in self.optional_tensor_data_names + ) + + _attr_match = all( + getattr(self, a_name) == getattr(src, a_name) + for a_name in self.tensor_attribute_names + ) + + _optional_attr_match = True + if hasattr(self, "optional_tensor_attribute_names"): + _optional_attr_match = all( + getattr(self, a_name) == getattr(src, a_name) + for a_name in self.optional_tensor_attribute_names + ) + + return ( + type(self) == type(src) + and self.shape == src.shape + and _tensor_shape_match + and _optional_tensor_shape_match + and _attr_match + and _optional_attr_match + ) + + @implements(aten.copy_.default) + def _(func, types, args, kwargs): + self = args[0] + src = args[1] + if _same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + return + raise ValueError( + f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" + ) + + @implements(aten._to_copy.default) + def _(func, types, args, kwargs): + self = args[0] + if hasattr(self, "tensor_data_names") and hasattr( + self, "tensor_attribute_names" + ): + kwargs = self._get_to_kwargs(*args[1:], **kwargs) + device = kwargs.pop("device") + tensors = [ + getattr(self, name).to(device) for name in self.tensor_data_names + ] + optional_tensors = [] + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + maybe_tensor = getattr(self, tensor_data_name) + if maybe_tensor is not None: + optional_tensors.append(maybe_tensor.to(device)) + else: + optional_tensors.append(None) + + # change device + tensor_attributes = [ + getattr(self, attr_name) if attr_name != "device" else device + for attr_name in self.tensor_attribute_names + ] + optional_tensor_attributes = [] + if hasattr(self, "optional_tensor_attribute_names"): + optional_tensor_attributes = [ + getattr(self, attr_name) if attr_name != "device" else device + for attr_name in self.optional_tensor_attribute_names + ] + + t = self.__class__( + *tensors, + *tensor_attributes, + *optional_tensors, + *optional_tensor_attributes, + ) + return return_and_correct_aliasing(func, args, kwargs, t) + + raise NotImplementedError( + "Subclasses must implement `aten._to_copy.default` or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it" + ) + + +def _torchao_base_tensor__setstate__(self, state): + assert hasattr(self, "tensor_data_names") and hasattr( + self, "tensor_attribute_names" + ) + torch._utils._set_obj_state(self, state) + for optional_tensor_data_name in getattr(self, "optional_tensor_data_names", []): + if optional_tensor_data_name not in self.__dict__ and not hasattr( + self, optional_tensor_data_name + ): + setattr(self, optional_tensor_data_name, None) + + for optional_tensor_attribute_name in getattr( + self, "optional_tensor_attribute_names", [] + ): + if optional_tensor_attribute_name not in self.__dict__ and not hasattr( + self, optional_tensor_attribute_name + ): + setattr(self, optional_tensor_attribute_name, None) + + def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None): """Use this util function for a common `__torch_function__` implementation that dispatches to ops/functions registered with `_implements` @@ -428,9 +649,10 @@ class MyTensor(torch.Tensor): kwargs = {} if kwargs is None else kwargs if ( hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") - and func in cls._ATEN_OP_OR_TORCH_FN_TABLE + and cls in cls._ATEN_OP_OR_TORCH_FN_TABLE + and func in cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] ): - return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs) + return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs) with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @@ -446,9 +668,10 @@ class MyTensor(torch.Tensor): """ if ( hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") - and func in cls._ATEN_OP_OR_TORCH_FN_TABLE + and cls in cls._ATEN_OP_OR_TORCH_FN_TABLE + and func in cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] ): - return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs) + return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs) arg_types = tuple(type(arg) for arg in args) kwarg_types = {k: type(arg) for k, arg in kwargs.items()} @@ -479,9 +702,8 @@ def decorator(tensor_impl_class): tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] = ( tensor_impl_class.from_plain ) - if TORCH_VERSION_AT_LEAST_2_5: - # Allow serialization to work for models uses this tensor impl subclass - torch.serialization.add_safe_globals([layout_class, tensor_impl_class]) + # Allow serialization to work for models uses this tensor impl subclass + torch.serialization.add_safe_globals([layout_class, tensor_impl_class]) return tensor_impl_class return decorator @@ -566,26 +788,233 @@ class PlainAQTTensorImpl(...): tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) + class variables to define to simplify implmentation of tensor subclasses: + `tensor_data_names` (List[str]): list of names of all requires tensor_data, order should match + the `__init__` list of tensor subclass + `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes, + order should match the `__init__` list of tensor subclass, following all the `tensor_data_names` arguments + `optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional Tensor data attributes, when defined, this will be a list of names of Tensors that can be optional + `optional_tensor_attribute_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional non-Tensor attributes, when defined, this will be a list of names of attributes that can be optional + Note: Argument order in __init__ and __new__ should match exaclty with tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present) + + + If `tensor_data_names` (torch.Tensor data attribute names) and `tensor_attribute_names` (non-torch.Tensor attribute names) are defined, there are some additional + functions that will be added, this includes: + `__tensor_flatten__`: flattens a subclassed tensor instance, returns a tuple, first element is tensor data names for valid tensor data, + second element is a dict from attribute_name to non-Tensor attributes + `__tensor_unflatten__`: takes a tensor_data_dict (a map from tensor name to Tensor), and list of non-tensor attributes, returns a new instance of the subclassed tensor + `_apply_fn_to_data`: takes a function (Tensor -> Tensor), applies function to all tensor data and + recreate a new subclassed Tensor with the transformed tensor data + `__repr__`: the string representation of the subclassed tensor instance + `_same_metadata`: returns whether the metadata is the same between two instances of cls + `__setstate__`: when loading a serialized tensor subclass checkpoints, it sets the new + optional tensor and tensor attribute that is saved in the old checkpoint to None, + to maintain BC of old checkpoints when we add new optional tensor data or attributes to + the tensor subclass + torch ops: torch.Tensor.contiguous + aten ops: aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to) + + Example: + class MyTensor(torch.Tensor): + tensor_data_names = ["a", "b"] + tensor_attribute_names = ["c", "d"] + optional_tensor_data_names = ["e", "f"] + optional_tensor_attribute_names = ["g", "h"] + + + def __new__( + cls, + a: Tensor, + b: Tensor, + c: int, + d: str, + e: Optional[Tensor] = None, + f: Optional[Tensor] = None, + g: Optional[int] = None, + h: Optional[int] = None, + ): + pass + + def __init__( + self, + a: Tensor, + b: Tensor, + c: int, + d: str + e: Optional[Tensor] = None, + f: Optional[Tensor] = None, + g: Optional[int] = None, + h: Optional[int] = None, + ): + pass + """ + @classmethod + def __init_subclass__(cls, **kwargs): + if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"): + cls._ATEN_OP_OR_TORCH_FN_TABLE = {} + + if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE: + cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {} + + # define the common ops and __set_state__ for BC + # if the tensor_data_names and tensor_attribute_names are defined + if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"): + cls._implements_common_tensor_ops() + cls.__setstate__ = _torchao_base_tensor__setstate__ + + # inherit the torch function and dispatch implementations from direct parent classes + # e.g. for `class C(B, A)`, C.__bases__ == (B, A) + for parent in cls.__bases__: + if parent in cls._ATEN_OP_OR_TORCH_FN_TABLE: + cls._ATEN_OP_OR_TORCH_FN_TABLE[cls].update( + cls._ATEN_OP_OR_TORCH_FN_TABLE[parent] + ) + implements = classmethod(_implements) + _implements_common_tensor_ops = classmethod(_implements_common_tensor_ops) __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) __torch_function__ = classmethod(_dispatch__torch_function__) register_layout = classmethod(_register_layout) get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor) _get_to_kwargs = _get_to_kwargs + def __init__(self, *args, **kwargs): + torch._C._log_api_usage_once(str(type(self))) + def __tensor_flatten__(self): - raise NotImplementedError("Subclasses must implement __tensor_flatten__") + if hasattr(self, "tensor_data_names") and hasattr( + self, "tensor_attribute_names" + ): + tensor_data_names = self.tensor_data_names.copy() + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + maybe_tensor = getattr(self, tensor_data_name) + if maybe_tensor is not None: + tensor_data_names.append(tensor_data_name) + + attr_dict = { + attr: getattr(self, attr) for attr in self.tensor_attribute_names + } + if hasattr(self, "optional_tensor_attribute_names"): + attr_dict = attr_dict | { + attr: getattr(self, attr) + for attr in self.optional_tensor_attribute_names + } + + return tensor_data_names, attr_dict + + raise NotImplementedError( + "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it" + ) @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - raise NotImplementedError("Subclasses must implement __tensor_unflatten__") + if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"): + required_tensors = [ + tensor_data_dict[name] for name in cls.tensor_data_names + ] + optional_tensor_dict = {} + if hasattr(cls, "optional_tensor_data_names"): + optional_tensor_dict = { + tensor_data_name: tensor_data_dict.get(tensor_data_name, None) + for tensor_data_name in cls.optional_tensor_data_names + } + + required_attributes = [ + tensor_attributes[name] for name in cls.tensor_attribute_names + ] + optional_attribute_dict = {} + if hasattr(cls, "optional_tensor_attribute_names"): + optional_attribute_dict = { + name: tensor_attributes[name] + for name in cls.optional_tensor_attribute_names + } + + return cls( + *required_tensors, + *required_attributes, + **optional_tensor_dict, + **optional_attribute_dict, + ) + + raise NotImplementedError( + "Subclasses should implement __tensor_unflatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it" + ) + + def _apply_fn_to_data(self, fn): + if hasattr(self, "tensor_data_names") and hasattr( + self, "tensor_attribute_names" + ): + required_tensors = [ + fn(getattr(self, attr)) for attr in self.tensor_data_names + ] + optional_tensor_dict = {} + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + maybe_tensor = getattr(self, tensor_data_name) + if maybe_tensor is not None: + optional_tensor_dict[tensor_data_name] = fn(maybe_tensor) + else: + optional_tensor_dict[tensor_data_name] = None + + required_attributes = [ + getattr(self, attr) for attr in self.tensor_attribute_names + ] + optional_attribute_dict = {} + if hasattr(self, "optional_tensor_attribute_names"): + optional_attribute_dict = { + attr_name: getattr(self, attr_name) + for attr_name in self.optional_tensor_attribute_names + } + + return self.__class__( + *required_tensors, + *required_attributes, + **optional_tensor_dict, + **optional_attribute_dict, + ) + + raise NotImplementedError( + "Subclasses should implement _apply_fn_to_data or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it" + ) def __repr__(self): - raise NotImplementedError("Subclasses must implement __repr__") + if hasattr(self, "tensor_data_names") and hasattr( + self, "tensor_attribute_names" + ): + repr_str = "" + # required tensor data + repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}" + for tensor_data_name in self.tensor_data_names[1:]: + repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}" + + # required attributes + for tensor_attribute_name in self.tensor_attribute_names: + repr_str += ( + f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" + ) + + # optional tensor data + if hasattr(self, "optional_tensor_data_names"): + for tensor_data_name in self.optional_tensor_data_names: + repr_str += ( + f", {tensor_data_name}={getattr(self, tensor_data_name)}" + ) + + # optional tensor attributes + if hasattr(self, "optional_tensor_attribute_names"): + for tensor_attribute_name in self.optional_tensor_attribute_names: + repr_str += f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" + + return f"{self.__class__.__name__}({repr_str})" + + raise NotImplementedError( + "Subclasses must implement __repr__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it" + ) def get_layout(self): if not hasattr(self, "_layout"): @@ -618,11 +1047,6 @@ def fill_defaults(args, n, defaults_tail): return r -## Deprecated, will be deleted in the future -def _torch_version_at_least(min_version): - return is_fbcode() or version("torch") >= min_version - - # Supported AMD GPU Models and their LLVM gfx Codes: # # | AMD GPU Model | LLVM gfx Code | @@ -696,19 +1120,17 @@ def is_sm_at_least_100(): def check_cpu_version(device, version="2.6.0"): if isinstance(device, torch.device): device = device.type - return device == "cpu" and compare_versions(torch.__version__, version) >= 0 + return device == "cpu" and torch_version_at_least(version) def check_xpu_version(device, version="2.8.0"): if isinstance(device, torch.device): device = device.type - return device == "xpu" and compare_versions(torch.__version__, version) >= 0 + return device == "xpu" and torch_version_at_least(version) -TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") -TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") -TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") -TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev") +def ceil_div(a, b): + return (a + b - 1) // b def is_package_at_least(package_name: str, min_version: str): @@ -717,3 +1139,28 @@ def is_package_at_least(package_name: str, min_version: str): return False return version(package_name) >= min_version + + +def _is_fbgemm_genai_gpu_available(): + # TODO: use is_package_at_least("fbgemm_gpu", "1.2.0") when + # https://github.com/pytorch/FBGEMM/issues/4198 is fixed + if importlib.util.find_spec("fbgemm_gpu") is None: + return False + + import fbgemm_gpu.experimental.gen_ai # noqa: F401 + + if not is_fbcode() and fbgemm_gpu.__version__ < "1.2.0": + return False + + return True + + +class DummyModule(torch.nn.Module): + """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a + DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor. + """ + + def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): + super().__init__() + self.weight = weight + self.bias = bias diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index 19ce9e872e..2e36626fed 100644 --- a/tutorials/calibration_flow/awq_like.py +++ b/tutorials/calibration_flow/awq_like.py @@ -121,9 +121,12 @@ def weight_quant_func(weight): weight, weight_scale, weight_zero_point, block_size, target_dtype ) elif target_dtype == torch.float8_e4m3fn: + scale_2d = ( + weight_scale.view(1, -1) if weight_scale.dim() == 1 else weight_scale + ) return to_affine_quantized_floatx_static( weight, - weight_scale, + scale_2d, block_size, target_dtype, Float8Layout(mm_config=None), diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index df824e506f..43affcdf3f 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -48,11 +48,11 @@ LinearActivationQuantizedTensor, MappingType, PerTensor, - _fake_quantize_affine, quantize_, to_linear_activation_quantized, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.quantization.quant_primitives import _fake_quantize_affine from torchao.quantization.transform_module import ( register_quantize_module_handler, ) diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py index faaa9b1ae9..bc999b49d4 100644 --- a/tutorials/quantize_vit/run_vit_b_quant.py +++ b/tutorials/quantize_vit/run_vit_b_quant.py @@ -24,11 +24,11 @@ # for torch 2.4+ from torchao.quantization.quant_api import ( - int8_dynamic_activation_int8_weight, + Int8DynamicActivationInt8WeightConfig, quantize_, ) -quantize_(model, int8_dynamic_activation_int8_weight()) +quantize_(model, Int8DynamicActivationInt8WeightConfig()) ## Quantization code - end ## compilation configs @@ -37,12 +37,6 @@ torch._inductor.config.use_mixed_mm = True ## compilation configs end -# temporary workaround for the API to work with torch.compile -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass - -if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - # temporary workaround to recover the perf with quantized model under torch.compile torch.backends.mha.set_fastpath_enabled(False) diff --git a/version.txt b/version.txt index ac454c6a1f..a803cc227f 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.12.0 +0.14.0