Skip to content

Commit dceb1a1

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] enabling double support for quantization and dequantization ops"
With the added double support in the layout template, this diff is enabling it as input/output for dequantization. Since there are limitations with how 64bit can be supported, the expectation is that IO be downgraded to 32bit Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/) [ghstack-poisoned]
2 parents ed60d2d + 6dab7cc commit dceb1a1

17 files changed

Lines changed: 483 additions & 91 deletions

File tree

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -807,14 +807,19 @@ TEST(
807807

808808
TEST(
809809
VulkanDequantizePerTensorTest,
810-
test_vulkan_dequantize_per_tensor_int32_to_double) {
810+
test_vulkan_dequantize_per_tensor_int8_to_double) {
811+
if (!vkcompute::api::context()
812+
->adapter_ptr()
813+
->has_full_int8_buffers_support()) {
814+
GTEST_SKIP();
815+
}
811816
test_vulkan_dequantize_per_tensor(
812-
{2, 4, 3}, // input sizes
813-
0.0001, // scale
814-
100, // zero_point
815-
-2147483648, // quant_min
816-
2147483647, // quant_max
817-
at::kInt, // input dtype
817+
{2, 3}, // input sizes
818+
0.05, // scale
819+
10, // zero_point
820+
-128, // quant_min
821+
127, // quant_max
822+
at::kChar, // input dtype
818823
at::kDouble); // output dtype
819824
}
820825

@@ -1316,16 +1321,21 @@ TEST(
13161321

13171322
TEST(
13181323
VulkanDequantizePerTokenTest,
1319-
test_vulkan_dequantize_per_token_int32_to_double) {
1320-
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1321-
std::vector<int> zero_points = {100, -100, 50, -50};
1324+
test_vulkan_dequantize_per_token_int8_to_double) {
1325+
if (!vkcompute::api::context()
1326+
->adapter_ptr()
1327+
->has_full_int8_buffers_support()) {
1328+
GTEST_SKIP();
1329+
}
1330+
std::vector<float> scales = {0.05, 0.001};
1331+
std::vector<int> zero_points = {10, -5};
13221332

13231333
test_vulkan_dequantize_per_token(
1324-
{2, 2, 8}, // input sizes (2*2=4 tokens)
1334+
{2, 2}, // input sizes (2 tokens)
13251335
scales,
13261336
zero_points,
1327-
-2147483648, // quant_min
1328-
2147483647, // quant_max
1329-
at::kInt, // input dtype
1337+
-128, // quant_min
1338+
127, // quant_max
1339+
at::kChar, // input dtype
13301340
at::kDouble); // output dtype
13311341
}

examples/models/llama/TARGETS

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ runtime.python_binary(
8585
":export_library",
8686
"//caffe2:torch",
8787
"//executorch/extension/pybindings:aten_lib",
88+
"//executorch/extension/llm/export:export_llm_lib",
8889
],
8990
)
9091

@@ -133,8 +134,6 @@ runtime.python_library(
133134
name = "export_library",
134135
srcs = [
135136
"export_llama.py",
136-
"export_llama_args.py",
137-
"export_llama_hydra.py",
138137
"export_llama_lib.py",
139138
"model.py",
140139
],

examples/models/llama/config/llm_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class BaseConfig:
8686
checkpoint_dir: Optional[str] = None
8787
tokenizer_path: Optional[str] = None
8888
metadata: Optional[str] = None
89-
use_lora: int = int
89+
use_lora: int = 0
9090
fairseq2: bool = False
9191
preq_mode: Optional[PreqMode] = None
9292
preq_group_size: int = 32
@@ -214,7 +214,7 @@ class ExportConfig:
214214

215215
max_seq_length: int = 128
216216
max_context_length: int = 128
217-
output_dir: Optional[str] = None
217+
output_dir: str = "."
218218
output_name: Optional[str] = None
219219
so_library: Optional[str] = None
220220
export_only: bool = False

examples/models/llama/export_llama.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717

1818
import torch
1919

20+
from executorch.examples.models.llama.export_llama_lib import (
21+
build_args_parser,
22+
export_llama,
23+
)
24+
2025
sys.setrecursionlimit(4096)
2126

2227

@@ -39,15 +44,12 @@ def main() -> None:
3944
sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
4045
print(f"running with {sys.argv}")
4146
runpy.run_module(
42-
"executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
47+
"executorch.extension.llm.export.export_llm", run_name="__main__"
4348
)
4449
else:
45-
# Use the legacy version of the export_llama script which uses argsparse.
46-
from executorch.examples.models.llama.export_llama_args import (
47-
main as export_llama_args_main,
48-
)
49-
50-
export_llama_args_main(remaining_args)
50+
parser = build_args_parser()
51+
remaining_args = parser.parse_args(remaining_args)
52+
export_llama(remaining_args)
5153

5254

5355
if __name__ == "__main__":

examples/models/llama/export_llama_args.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

examples/models/llama/export_llama_hydra.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

extension/llm/export/TARGETS

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,41 @@ runtime.python_library(
4747
],
4848
)
4949

50+
runtime.python_binary(
51+
name = "export_llm",
52+
srcs = [
53+
"export_llm.py",
54+
],
55+
main_function = "executorch.extension.llm.export.export_llm.main",
56+
preload_deps = [
57+
"//executorch/extension/llm/custom_ops:model_sharding_py",
58+
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
59+
"//executorch/kernels/quantized:aot_lib",
60+
],
61+
deps = [
62+
"fbsource//third-party/pypi/hydra-core:hydra-core",
63+
"fbsource//third-party/pypi/omegaconf:omegaconf",
64+
"//executorch/examples/models/llama:export_library",
65+
"//executorch/extension/pybindings:aten_lib",
66+
],
67+
)
68+
69+
runtime.python_library(
70+
name = "export_llm_lib",
71+
srcs = [
72+
"export_llm.py",
73+
],
74+
deps = [
75+
"fbsource//third-party/pypi/hydra-core:hydra-core",
76+
"fbsource//third-party/pypi/omegaconf:omegaconf",
77+
"//executorch/examples/models/llama:export_library",
78+
],
79+
visibility = [
80+
"//executorch/examples/...",
81+
"//executorch/extension/llm/...",
82+
],
83+
)
84+
5085
runtime.python_test(
5186
name = "export_passes_test",
5287
srcs = [

extension/llm/export/builder.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ def __init__(
133133
self.output_dir = "."
134134
self._saved_pte_filename = None
135135

136+
def __post_init__(self):
137+
"""
138+
Post init function to update metadata based on dynamic shape
139+
"""
140+
dynamic_shape = self._get_dynamic_shape()
141+
if dynamic_shape is not None:
142+
token_dim = dynamic_shape[0][1]
143+
if self.verbose:
144+
logging.info(
145+
f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
146+
)
147+
self.metadata["get_max_seq_len"] = token_dim.max
148+
136149
def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
137150
"""
138151
Set the directory where the .pte file will be saved.
@@ -180,14 +193,19 @@ def _get_dynamic_shape(self) -> Any:
180193
if self.dynamic_shapes:
181194
return self.dynamic_shapes
182195

183-
dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
184196
if self.enable_dynamic_shape:
185197
if not self.use_kv_cache:
186198
# Only one input argument: tokens
187-
self.dynamic_shapes = ({1: dim},)
199+
# Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
200+
self.dynamic_shapes = (
201+
{1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
202+
)
188203
else:
189204
# Two input arguments: tokens and input_pos but input_pos is static shape
190-
self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
205+
self.dynamic_shapes = (
206+
{1: torch.export.Dim("token_dim", max=self.max_seq_len)},
207+
{"input_pos": {0: 1}},
208+
)
191209
else:
192210
# Two input arguments: tokens and input_pos but both are of static shape
193211
self.dynamic_shapes = None

extension/llm/export/export_llm.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Export an LLM with ExecuTorch. Currently follows the following steps:
9+
1. Instantiate our custom PyTorch transformer definition from examples/llama/models/llama_transformer.py.
10+
2. Load weights into the model.
11+
3. Apply source transformations/TorchAO quantization.
12+
4. Export model to intermediate IRs.
13+
5. Graph transformations/PT2E quantization.
14+
6. Partition graph and delegate to backend(s).
15+
7. Export to final ExecuTorch .pte format.
16+
17+
Example usage using full CLI arguments:
18+
python -m extension.llm.export.export_llm \
19+
base.model_class="llama3" \
20+
model.use_sdpa_with_kv_cache=True \
21+
model.use_kv_cache=True \
22+
debug.verbose=True \
23+
backend.xnnpack.enabled=True \
24+
backend.xnnpack.extended_ops=True \
25+
quantization.qmode="8da4w"
26+
"""
27+
28+
import hydra
29+
30+
from executorch.examples.models.llama.config.llm_config import LlmConfig
31+
from executorch.examples.models.llama.export_llama_lib import export_llama
32+
from hydra.core.config_store import ConfigStore
33+
from omegaconf import OmegaConf
34+
35+
cs = ConfigStore.instance()
36+
cs.store(name="llm_config", node=LlmConfig)
37+
38+
39+
@hydra.main(version_base=None, config_path=None, config_name="llm_config")
40+
def main(llm_config: LlmConfig) -> None:
41+
export_llama(OmegaConf.to_object(llm_config))
42+
43+
44+
if __name__ == "__main__":
45+
main()

extension/llm/export/test/test_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
8888
# Check first element (tokens dimension)
8989
self.assertIsInstance(result[0], dict)
9090
self.assertIn(1, result[0])
91-
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
91+
self.assertEqual(result[0][1].max, self.max_seq_len)
9292

9393
# Check second element (input_pos dimension)
9494
self.assertIsInstance(result[1], dict)

0 commit comments

Comments
 (0)