Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 147 additions & 30 deletions olive/passes/pytorch/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import logging
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable

Expand All @@ -14,6 +15,7 @@
from olive.common.quant.hf_utils import (
OliveHfQuantizationConfig,
OliveHfQuantizationMethod,
OliveHfQuantizationOverrideConfig,
replace_matching_submodules,
tie_quant_word_embeddings,
)
Expand Down Expand Up @@ -78,6 +80,106 @@ def get_quantizer_config(allow_embeds: bool = False) -> dict[str, PassConfigPara
}


def get_qkv_quantization_groups(wrapper: ModelWrapper, module_names: set[str] | None = None) -> list[tuple[str, ...]]:
"""Get attention input projection groups that must share quantization settings.

Names are resolved from ``wrapper.model.named_modules()`` to stay correct for any layer
container (``ModuleList``, ``ModuleDict``, custom containers) and for unpacked QKV
submodules. When ``module_names`` is provided, attention inputs not in the set are
dropped from the group. Groups with fewer than two members are skipped.
"""
module_to_name = {id(module): name for name, module in wrapper.model.named_modules()}
qkv_groups = []
for layer_wrapper in wrapper.get_layer_wrappers():
attn_inputs, _ = layer_wrapper.get_attention_inputs()
group = tuple(
name
for name in (module_to_name.get(id(module)) for module in attn_inputs)
if name is not None and (module_names is None or name in module_names)
)
if len(group) > 1:
qkv_groups.append(group)
return qkv_groups


def _quant_config_rank(qargs: dict[str, int | bool]) -> tuple[int, int, int]:
"""Rank quantization configs by precision; higher rank means more precise.

Ordering: higher ``bits`` wins; among equal bits, smaller positive ``group_size`` wins;
per-channel (``-1``) wins over per-tensor (``0``) but loses to positive group sizes.
``symmetric`` is intentionally not part of the ordering since it is a representation
choice rather than a strict precision axis.
"""
bits = qargs["bits"].value if hasattr(qargs["bits"], "value") else qargs["bits"]
group_size = qargs["group_size"]
if group_size > 0:
group_size_rank = (2, -group_size)
elif group_size == -1:
group_size_rank = (1, 0)
else:
group_size_rank = (0, 0)
return bits, *group_size_rank


def normalize_qkv_quant_config(
wrapper: ModelWrapper,
qcfg: OliveHfQuantizationConfig,
locked_modules: set[str] | None = None,
) -> OliveHfQuantizationConfig:
"""Promote split QKV projection overrides to one shared quantization config.

Groups span all attention input projections of a layer regardless of whether the current
pass quantizes them; follow-up passes (e.g. RTN after AutoClip) will pick up the shared
settings via the recorded overrides so downstream QKV fusion remains valid.

``locked_modules`` are modules whose overrides must not be rewritten -- typically the
pre-existing overrides of an already-quantized checkpoint. For a group containing a
locked member, the shared config is forced to that locked member's config; if multiple
locked members of one group disagree, the group is left untouched.
"""
locked_modules = locked_modules or set()
for group in get_qkv_quantization_groups(wrapper):
group_qargs = {name: qcfg.get_qlinear_init_args(name) for name in group}
if len({tuple(qargs.items()) for qargs in group_qargs.values()}) == 1:
continue

locked_in_group = [name for name in group if name in locked_modules]
locked_configs = {tuple(group_qargs[name].items()) for name in locked_in_group}
if len(locked_configs) > 1:
logger.debug(
"QKV group %s contains already-quantized members with conflicting configs; "
"skipping (downstream QKV fusion may be inhibited).",
group,
)
continue
promoted_qargs = (
group_qargs[locked_in_group[0]] if locked_in_group else max(group_qargs.values(), key=_quant_config_rank)
)

logger.debug("Promoting QKV group %s to shared quantization config %s", group, promoted_qargs)
for name in group:
if name in locked_modules:
continue
override = {k: v for k, v in promoted_qargs.items() if getattr(qcfg, k) != v}
if override:
qcfg.overrides[name] = OliveHfQuantizationOverrideConfig(**override)
else:
qcfg.overrides.pop(name, None)

return qcfg


def _collect_excluded_attn_inputs(wrapper: ModelWrapper) -> set[torch.nn.Module]:
excluded: set[torch.nn.Module] = set()
for layer_wrapper in wrapper.get_layer_wrappers():
attn_inputs, _ = layer_wrapper.get_attention_inputs()
if len(attn_inputs) == 1:
excluded.add(attn_inputs[0])
else:
excluded.update(attn_inputs[:2])
return excluded


def prepare_model(
model: HfModelHandler,
config: type[BasePassConfig],
Expand All @@ -99,42 +201,65 @@ def prepare_model(
if existing_qcfg := getattr(model.get_hf_model_config(), "quantization_config", None):
if not allow_quantized:
raise ValueError("Model is already quantized. Cannot quantize again using this pass.")
if not isinstance(existing_qcfg, dict):
existing_qcfg = existing_qcfg.to_dict()
# Always work on a fresh copy: the underlying HF config holds the original object
# (dict or dataclass) and we mutate ``existing_qcfg`` heavily below.
existing_qcfg = deepcopy(existing_qcfg) if isinstance(existing_qcfg, dict) else existing_qcfg.to_dict()
if existing_qcfg.get("quant_method", None) != OliveHfQuantizationMethod.OLIVE:
raise ValueError("Model has an existing quantization configuration that is not compatible with this pass.")

wrapper = ModelWrapper.from_model(load_hf_base_model(model))
wrapper.model.eval()

qcfg = get_quant_config(model, config)
excluded_attn_inputs = _collect_excluded_attn_inputs(wrapper) if exclude_attn_inputs else set()

fresh_qcfg = normalize_qkv_quant_config(wrapper, get_quant_config(model, config))

originally_tied_embeddings = wrapper.config.tie_word_embeddings
if qcfg.lm_head or qcfg.embeds:
if fresh_qcfg.lm_head or fresh_qcfg.embeds:
wrapper.maybe_untie_word_embeddings()

lm_head_name = wrapper.get_lm_head()[1]
embeds_name = wrapper.get_embeds()[1][0]
new_qargs: dict[str, dict[str, int | bool]] = {}

excluded_attn_inputs: set[torch.nn.Module] = set()
if exclude_attn_inputs:
for layer_wrapper in wrapper.get_layer_wrappers():
attn_inputs, _ = layer_wrapper.get_attention_inputs()
if len(attn_inputs) == 1:
excluded_attn_inputs.add(attn_inputs[0])
else:
excluded_attn_inputs.update(attn_inputs[:2])

def should_quantize(module: torch.nn.Module, name: str) -> bool:
if module in excluded_attn_inputs:
return False
if isinstance(module, torch.nn.Linear):
return name != lm_head_name or qcfg.lm_head
if qcfg.embeds and isinstance(module, torch.nn.Embedding):
return name != lm_head_name or fresh_qcfg.lm_head
if fresh_qcfg.embeds and isinstance(module, torch.nn.Embedding):
return name == embeds_name
return False

# Pre-existing quantized weights are immutable. If we're merging with an existing
# checkpoint, build the final qcfg first (merge fresh into existing, then renormalize
# QKV with already-quantized modules locked) so that the quant_info we attach below
# uses the same settings the on-disk fusion will require. Every module that is already
# a QuantLinear/QuantEmbedding after load is on-disk-immutable, including those that
# used the existing config's defaults (no explicit override entry).
on_disk_overrides: set[str] = set()
already_quantized: set[str] = set()
if existing_qcfg:
on_disk_overrides = set((existing_qcfg.get("overrides") or {}).keys())
already_quantized = {
name for name, module in wrapper.model.named_modules() if isinstance(module, (QuantLinear, QuantEmbedding))
}
fresh_names = {name for name, module in wrapper.model.named_modules() if should_quantize(module, name)}
Comment thread
jambayk marked this conversation as resolved.
merged = existing_qcfg
merged["overrides"] = existing_qcfg.get("overrides") or {}
for name in fresh_names:
qargs = fresh_qcfg.get_qlinear_init_args(name)
override = {k: v for k, v in qargs.items() if merged[k] != v}
if override:
merged["overrides"][name] = override
merged["lm_head"] |= fresh_qcfg.lm_head
merged["embeds"] |= fresh_qcfg.embeds
qcfg = OliveHfQuantizationConfig(**merged)
qcfg = normalize_qkv_quant_config(wrapper, qcfg, locked_modules=already_quantized)
else:
qcfg = fresh_qcfg

new_qargs: dict[str, dict[str, int | bool]] = {}

def add_quant_info(module: torch.nn.Module, name: str) -> torch.nn.Module:
# TODO(jambayk): validate that the module and config are compatible
qargs = qcfg.get_qlinear_init_args(name)
Expand All @@ -144,23 +269,15 @@ def add_quant_info(module: torch.nn.Module, name: str) -> torch.nn.Module:

replace_matching_submodules(wrapper.model, should_quantize, add_quant_info, description="Preparing model")

# remove overrides for modules not being quantized
# Drop overrides for modules that won't be quantized this pass. Pre-existing (on-disk)
# overrides are preserved verbatim since they describe already-quantized weights.
# QKV-group overrides for modules excluded from this pass are not kept: when the
# follow-up pass runs, the quantized members in the group will be locked and pull the
# remaining members back into the shared config via ``normalize_qkv_quant_config``.
for name in list(qcfg.overrides or {}):
if name not in new_qargs:
if name not in new_qargs and name not in on_disk_overrides:
qcfg.overrides.pop(name)

# merge the new_quant_settings into the existing quant_config
if existing_qcfg:
merged_qcfg_dict = existing_qcfg
merged_qcfg_dict["overrides"] = existing_qcfg.get("overrides") or {}
for name, qargs in new_qargs.items():
override = {k: v for k, v in qargs.items() if merged_qcfg_dict[k] != v}
if override:
merged_qcfg_dict["overrides"][name] = override
merged_qcfg_dict["lm_head"] |= qcfg.lm_head
merged_qcfg_dict["embeds"] |= qcfg.embeds
qcfg = OliveHfQuantizationConfig(**merged_qcfg_dict)

word_embeddings_eligible_for_tieing = (
originally_tied_embeddings
and embeds_name in new_qargs
Expand Down
Loading
Loading