Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def q8ta_linear(

out = torch.nn.functional.linear(x, weights)
if bias is not None:
out = out + bias
out = out + bias[: out.shape[-1]]

if activation == "relu":
out = torch.nn.functional.relu(out)
Expand Down Expand Up @@ -455,7 +455,7 @@ def q8ta_linear_gemv(

out = torch.nn.functional.linear(x, weights)
if bias is not None:
out = out + bias
out = out + bias[: out.shape[-1]]

if activation == "relu":
out = torch.nn.functional.relu(out)
Expand Down
162 changes: 114 additions & 48 deletions backends/vulkan/patterns/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901
self.match_found = False
self.all_nodes = [self.anchor_node]

# addmm(bias, mat1, mat2) has a different arg layout than
# mm(mat1, mat2) and linear(input, weight, bias?)
is_addmm = self.anchor_node.target == exir_ops.edge.aten.addmm.default
weight_arg_idx = 2 if is_addmm else 1
input_arg_idx = 1 if is_addmm else 0

const_node, arg_chain = utils.trace_args_until_placeholder(
self.anchor_node.args[1]
self.anchor_node.args[weight_arg_idx]
)

# mat2 is not a constant tensor - no match
Expand Down Expand Up @@ -84,26 +90,64 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901
# Identify output node
self.output_node = self.anchor_node

# The implementation has a limitation that output channels must be a
# multiple of 4. This is to ensure that data loads are aligned well with
# texel boundaries. If this is not true, then don't match the pattern.
out_channels = self.output_node.meta["val"].shape[-1]
if out_channels % 4 != 0:
return
# Identify primary input node of the anchor. Due to decomposition of aten.linear
# there may be a view_copy node between the original input tensor to the linear
# op and the actual linear op node.
anchor_primary_input_node = self.anchor_node.args[input_arg_idx]
assert isinstance(anchor_primary_input_node, torch.fx.Node)

# Skip potential view_copy between dq and linear
if utils.is_view_copy_node(anchor_primary_input_node):
self.all_nodes.append(anchor_primary_input_node)
anchor_primary_input_node = anchor_primary_input_node.args[
0
] # pyre-ignore[16]
assert isinstance(anchor_primary_input_node, torch.fx.Node)

# By default, assume that the input tensor is not quantized in any way
self.quantize_input_node = None
self.dequantize_input_node = None
self.pattern_input_node = anchor_primary_input_node

self.input_scales_node = None
self.input_zeros_node = None

scales_arg_idx = 1
zeros_arg_idx = 2

# Identify input node
(
self.fp_input_node,
self.quantize_input_node,
dq_node,
) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
assert self.fp_input_node is not None
self.all_nodes.append(self.fp_input_node)
# If the primary input node comes from a dequantize node, that implies the input
# input tensor is quantized (either statically or dynamically).
if utils.is_dequant_node(anchor_primary_input_node):
# Assume that this is a static quantization pattern; the input to the
# pattern is a statically quantized int8 tensor.
self.dequantize_input_node = anchor_primary_input_node
self.all_nodes.append(self.dequantize_input_node)
input_to_dq_node = self.dequantize_input_node.args[0]
self.pattern_input_node = input_to_dq_node

# torchao dequantize has a slightly different function schema
if (
self.dequantize_input_node.target
== exir_ops.edge.torchao.dequantize_affine.default
):
scales_arg_idx = 2
zeros_arg_idx = 3

self.input_scales_node = self.dequantize_input_node.args[scales_arg_idx]
self.input_zeros_node = self.dequantize_input_node.args[zeros_arg_idx]

# Check for dynamic quantization: input scales are dynamically
# computed via a choose_qparams op
if utils.is_quant_node(input_to_dq_node) and utils.is_dynamic_qscale(
self.input_scales_node
):
self.quantize_input_node = input_to_dq_node
self.pattern_input_node = self.quantize_input_node.args[0]

# The implementation has a limitation that input channels must be a
# multiple of 4. This is to ensure that data loads are aligned well with
# texel boundaries. If this is not true, then don't match the pattern.
in_channels = self.fp_input_node.meta["val"].shape[-1]
in_channels = self.pattern_input_node.meta["val"].shape[-1]
if in_channels % 4 != 0:
return

Expand All @@ -124,32 +168,10 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901
self.all_nodes.extend(arg_chain)

# If input is not quantized, then we are done
if self.quantize_input_node is None:
if self.dequantize_input_node is None:
self.match_found = True
return

scales_arg_idx = 1
zeros_arg_idx = 2

# torchao op has a slightly different function schema
if (
self.quantize_input_node.target
== exir_ops.edge.torchao.quantize_affine.default
):
scales_arg_idx = 2
zeros_arg_idx = 3

self.input_scales_node = self.quantize_input_node.args[scales_arg_idx]
self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx]

assert dq_node is not None
self.all_nodes.extend(
[
self.quantize_input_node,
dq_node,
]
)

# Check if the output is also quantized (q → dq → linear → q pattern)
# Also handle fused linear+relu (q → dq → linear → relu → q pattern)
self.quantize_output_node = None
Expand All @@ -172,7 +194,7 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901
self.match_found = True

def is_weight_only_quantized(self) -> bool:
return self.quantize_input_node is None
return self.dequantize_input_node is None

def has_output_quantization(self) -> bool:
return (
Expand Down Expand Up @@ -204,15 +226,15 @@ def is_weight_perchannel_quantized(self) -> bool:
return scales_shape[0] == weight_shape[-2]

def is_input_static_per_tensor_quantized(self) -> bool:
if self.quantize_input_node is None:
if self.dequantize_input_node is None:
return False

# For static quantization per tensor quantization, the scales and zeros
# are scalars.
return isinstance(self.input_scales_node, float)

def is_input_dynamic_perchannel_quantized(self) -> bool:
if self.quantize_input_node is None:
if self.dequantize_input_node is None:
return False

if not isinstance(self.input_scales_node, torch.fx.Node):
Expand All @@ -228,7 +250,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool:
return False

scales_shape = self.input_scales_node.meta["val"].shape
input_shape = self.fp_input_node.meta["val"].shape
input_shape = self.pattern_input_node.meta["val"].shape

return input_shape[-2] == scales_shape[-1]

Expand Down Expand Up @@ -366,7 +388,7 @@ def make_linear_q4gsw_op(
"call_function",
exir_ops.edge.et_vk.linear_q4gsw.default,
args=(
match.fp_input_node,
match.pattern_input_node,
match.weight_node,
match.weight_scales_node,
group_size,
Expand Down Expand Up @@ -430,7 +452,7 @@ def make_linear_dq8ca_q4gsw_op(
"call_function",
exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default,
args=(
match.fp_input_node,
match.pattern_input_node,
match.input_scales_node,
match.input_zeros_node,
match.weight_node,
Expand All @@ -450,12 +472,34 @@ def make_linear_q8ta_q8csw_custom_op(
match: QuantizedLinearMatch,
weight_tensor: torch.Tensor,
):
# Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB
weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node)
assert weight_scales_tensor is not None
utils.align_width_and_update_state_dict(
ep, match.weight_scales_node, weight_scales_tensor
)

# Pad bias to multiple of 4 if present
if match.bias_node is not None:
bias_tensor = get_param_tensor(ep, match.bias_node)
if bias_tensor is not None:
utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)

first_graph_node = list(graph_module.graph.nodes)[0]
with graph_module.graph.inserting_before(first_graph_node):
weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
# Pre-compute the weight sums which are needed to apply activation zero point
# when using integer accumulation.
sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous()

# Pad weight sums to align OC to multiple of 4
oc = sum_per_output_channel.shape[0]
if oc % 4 != 0:
num_padding = 4 - (oc % 4)
sum_per_output_channel = F.pad(
sum_per_output_channel, (0, num_padding)
).contiguous()

sums_name = weight_tensor_name + "_sums"
# Sanitize the name
sums_name = sums_name.replace(".", "_")
Expand All @@ -473,7 +517,7 @@ def make_linear_q8ta_q8csw_custom_op(
"call_function",
exir_ops.edge.et_vk.linear_q8ta_q8csw.default,
args=(
match.fp_input_node,
match.pattern_input_node,
match.input_scales_node,
match.input_zeros_node,
match.weight_node,
Expand All @@ -492,10 +536,32 @@ def make_q8ta_linear_custom_op(
match: QuantizedLinearMatch,
weight_tensor: torch.Tensor,
):
# Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB
weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node)
assert weight_scales_tensor is not None
utils.align_width_and_update_state_dict(
ep, match.weight_scales_node, weight_scales_tensor
)

# Pad bias to multiple of 4 if present
if match.bias_node is not None:
bias_tensor = get_param_tensor(ep, match.bias_node)
if bias_tensor is not None:
utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)

first_graph_node = list(graph_module.graph.nodes)[0]
with graph_module.graph.inserting_before(first_graph_node):
weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous()

# Pad weight sums to align OC to multiple of 4
oc = sum_per_output_channel.shape[0]
if oc % 4 != 0:
num_padding = 4 - (oc % 4)
sum_per_output_channel = F.pad(
sum_per_output_channel, (0, num_padding)
).contiguous()

sums_name = weight_tensor_name + "_sums"
sums_name = sums_name.replace(".", "_")

Expand All @@ -508,7 +574,7 @@ def make_q8ta_linear_custom_op(
)

# Use gemv variant when batch size is 1
input_shape = match.fp_input_node.meta["val"].shape
input_shape = match.pattern_input_node.meta["val"].shape
batch_size = input_shape[-2] if len(input_shape) >= 2 else 1
if batch_size == 1:
op_target = exir_ops.edge.et_vk.q8ta_linear_gemv.default
Expand All @@ -520,7 +586,7 @@ def make_q8ta_linear_custom_op(
"call_function",
op_target,
args=(
match.quantize_input_node,
match.pattern_input_node,
match.input_scales_node,
match.input_zeros_node,
match.weight_node,
Expand Down
Loading
Loading