Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 165 additions & 3 deletions backends/vulkan/patterns/quantized_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from executorch.exir.dialects._ops import ops as exir_ops


embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype
embedding_target = exir_ops.edge.aten.embedding.default
torchao_dequantize_affine_target = exir_ops.edge.torchao.dequantize_affine.default


class QuantizedEmbeddingMatch(PatternMatch):
def __init__(self, node: torch.fx.Node) -> None:
self.anchor_node = node
Expand Down Expand Up @@ -68,9 +73,6 @@ def __init__(self, node: torch.fx.Node) -> None:
self.match_found = True


embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype


def _detect_tied_linear_weight(
ep: ExportedProgram,
weight_node: torch.fx.Node,
Expand Down Expand Up @@ -175,3 +177,163 @@ def replace_quantized_embedding_patterns(

embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"]
match.anchor_node.replace_all_uses_with(embedding_q4gsw_node)


class TorchAOQuantizedEmbeddingMatch(PatternMatch):
"""Matches a torchao 4-bit weight-only quantized embedding and rewrites it
as a single et_vk.embedding_q4gsw.default node.

The recognized graph shape is a split torchao.dequantize_affine ->
aten.embedding, whose weight is unpacked int8 [vocab, embed_dim] with values
in [-8, 7]. This requires symmetric 4-bit signed quantization (quant_min=-8,
quant_max=7, zero_point=0) and per-row groupwise blocks (block_size=[1, G]),
which the runtime shader assumes via a fixed subtract-8 offset.
"""

def __init__(self, node: torch.fx.Node) -> None:
self.anchor_node = node
self.match_found = False
self.all_nodes = [node]

# aten.embedding.default args: (weight, indices, *)
dequant_node = node.args[0]
self.indices_node = node.args[1]

if not isinstance(dequant_node, torch.fx.Node):
return
if dequant_node.target != torchao_dequantize_affine_target:
return

self.all_nodes.append(dequant_node)

# torchao.dequantize_affine args:
# (input, block_size, scale, zero_point, input_dtype, quant_min,
# quant_max, ...)
block_size = dequant_node.args[1]
quant_min = dequant_node.args[5] if len(dequant_node.args) > 5 else None
quant_max = dequant_node.args[6] if len(dequant_node.args) > 6 else None

# The shader hardcodes the 4-bit signed offset (subtract 8), which
# corresponds to quant_min=-8, quant_max=7, zero_point=0.
if quant_min != -8 or quant_max != 7:
return

# block_size must be per-row groupwise: [1, group_size]
if not isinstance(block_size, (list, tuple)) or len(block_size) != 2:
return
if block_size[0] != 1:
return
self.group_size = int(block_size[1])

# Trace weight (args[0]), scales (args[2]) and zero_point (args[3]) to
# their placeholders. The symmetric (zero_point == 0) requirement is
# verified on the real tensor in the replacement function, where the
# ExportedProgram is available; checking the fake meta tensor here would
# trigger a data-dependent guard error.
weight_node, arg_chain = utils.trace_args_until_placeholder(
dequant_node.args[0]
)
if weight_node is None:
return
self.weight_node = weight_node
self.all_nodes.extend(arg_chain)

scales_node, arg_chain = utils.trace_args_until_placeholder(
dequant_node.args[2]
)
if scales_node is None:
return
self.scales_node = scales_node
self.all_nodes.extend(arg_chain)

self.zero_point_node, arg_chain = utils.trace_args_until_placeholder(
dequant_node.args[3]
)
self.all_nodes.extend(arg_chain)

self.match_found = True


@register_pattern_detector("torchao_quantized_embedding")
def find_torchao_quantized_embedding_patterns(
node: torch.fx.Node,
) -> Optional[TorchAOQuantizedEmbeddingMatch]:
if node.target != embedding_target:
return None

matched_pattern = TorchAOQuantizedEmbeddingMatch(node)
if matched_pattern.match_found:
return matched_pattern
return None


@register_pattern_replacement("torchao_quantized_embedding")
def replace_torchao_quantized_embedding_patterns(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: TorchAOQuantizedEmbeddingMatch,
):
weight_tensor = get_param_tensor(ep, match.weight_node)
assert weight_tensor is not None

# The weight repack mutates the state dict entry in place, so it must run
# exactly once per backing storage; a second repack of the already-packed
# weight would corrupt it. The repack
# (align_width_and_update_state_dict -> update_program_state_dict) locates
# the entry to overwrite by the param/buffer FQN that backs the placeholder,
# so the idempotency guard keys on that same FQN (via
# utils.register_param_mutation). This dedups not only one placeholder
# shared by multiple call sites, but also distinct placeholder nodes that
# resolve to the same state dict storage (whose per-node meta would otherwise
# diverge). Distinct weights (distinct FQNs) still each pack once. The guard
# also raises if the same weight is later re-mutated with a different tag
# (i.e. an incompatible packing format), surfacing corruption loudly.
if utils.register_param_mutation(ep, match.weight_node, "embedding_q4gsw"):
# The shader applies a fixed signed-4-bit offset (subtract 8), which
# assumes symmetric quantization (zero_point == 0). Verify on the real
# tensor.
if match.zero_point_node is not None:
zero_point_tensor = get_param_tensor(ep, match.zero_point_node)
if zero_point_tensor is not None:
assert torch.all(
zero_point_tensor == 0
), "embedding_q4gsw requires symmetric quantization (zero_point == 0)"

# Repack the unpacked int8 weight [vocab, embed_dim] (values in [-8, 7])
# into the flat 4-bit packed format [vocab, embed_dim / 2] that the
# non-linear embedding_q4gsw path expects. Packing convention (must
# match the runtime shader and embedding_q4gsw_impl):
# packed_byte = (even_val + 8) << 4 | (odd_val + 8)
# i.e. the even-index value goes in the high nibble, odd-index in the
# low.
unpacked_u8 = weight_tensor.to(torch.uint8) + 8
packed_weight = (unpacked_u8[:, ::2] << 4 | unpacked_u8[:, 1::2]).to(
torch.uint8
)

# Update the weight placeholder's state dict entry and fake-tensor meta
# to the repacked tensor. align_to=1 with force_update just forces the
# update; the packed width (embed_dim / 2) is already a multiple of 4.
utils.align_width_and_update_state_dict(
ep, match.weight_node, packed_weight, align_to=1, force_update=True
)

# Scales are symmetric per-group with layout [vocab, num_groups], matching
# the scale layout embedding_q4gsw expects (no transpose).
group_size = match.group_size

with graph_module.graph.inserting_before(match.anchor_node):
embedding_q4gsw_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.et_vk.embedding_q4gsw.default,
args=(
match.weight_node,
match.scales_node,
group_size,
match.indices_node,
False,
),
)

embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"]
match.anchor_node.replace_all_uses_with(embedding_q4gsw_node)
Loading
Loading