From 0cd2c8213185b56ba68cb252df4e225c5256bc60 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 20 Nov 2024 15:54:32 -0800 Subject: [PATCH 1/3] [Executorch][BE] Rename sdpa_with_kv_cache.py to custom_ops.py Because now we have more than sdpa_with_kv_cache in it Differential Revision: [D66269486](https://our.internmc.facebook.com/intern/diff/D66269486/) [ghstack-poisoned] --- examples/models/llama/runner/native.py | 2 +- examples/models/llama/source_transformation/sdpa.py | 2 +- examples/models/llava/test/test_llava.py | 2 +- examples/models/llava/test/test_pte.py | 2 +- extension/llm/README.md | 2 +- extension/llm/custom_ops/__init__.py | 1 + .../llm/custom_ops/{sdpa_with_kv_cache.py => custom_ops.py} | 0 extension/llm/custom_ops/test_sdpa_with_kv_cache.py | 2 +- 8 files changed, 7 insertions(+), 6 deletions(-) rename extension/llm/custom_ops/{sdpa_with_kv_cache.py => custom_ops.py} (100%) diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index 62757506f3b..447394a85cc 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -23,7 +23,7 @@ from executorch.examples.models.llama.runner.generation import LlamaRunner # Note: import this after portable_lib -from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip +import executorch.extension.llm.custom_ops # noqa # usort: skip from executorch.kernels import quantized # noqa diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index f8362648f32..44541f6eaac 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -99,7 +99,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module): def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: - from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa + import executorch.extension.llm.custom_ops # noqa _replace_sdpa_with_custom_op(module) return module diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py index 2e50bcecf49..f12c94ae4e2 100644 --- a/examples/models/llava/test/test_llava.py +++ b/examples/models/llava/test/test_llava.py @@ -18,7 +18,7 @@ from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) -from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip +import executorch.extension.llm.custom_ops # noqa # usort: skip from executorch.kernels import quantized # noqa # usort: skip logging.basicConfig(level=logging.INFO) diff --git a/examples/models/llava/test/test_pte.py b/examples/models/llava/test/test_pte.py index 003b2b56755..80da91c2664 100644 --- a/examples/models/llava/test/test_pte.py +++ b/examples/models/llava/test/test_pte.py @@ -14,7 +14,7 @@ from PIL import Image # Custom ops has to be loaded after portable_lib. -from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip +import executorch.extension.llm.custom_ops # noqa # usort: skip from executorch.kernels import quantized # noqa # usort: skip FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" diff --git a/extension/llm/README.md b/extension/llm/README.md index ad504966824..0f71088eea1 100644 --- a/extension/llm/README.md +++ b/extension/llm/README.md @@ -38,7 +38,7 @@ A sampler class in C++ to sample the logistics given some hyperparameters. ## custom_ops Contains custom op, such as: - custom sdpa: implements CPU flash attention and avoids copies by taking the kv cache as one of its arguments. - - _sdpa_with_kv_cache.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration. + - _custom_ops.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration. - _op_sdpa.cpp_: the optimized operator implementation and registration of _sdpa_with_kv_cache.out_. ## runner diff --git a/extension/llm/custom_ops/__init__.py b/extension/llm/custom_ops/__init__.py index e69de29bb2d..4dbceac8d67 100644 --- a/extension/llm/custom_ops/__init__.py +++ b/extension/llm/custom_ops/__init__.py @@ -0,0 +1 @@ +from .custom_ops import * diff --git a/extension/llm/custom_ops/sdpa_with_kv_cache.py b/extension/llm/custom_ops/custom_ops.py similarity index 100% rename from extension/llm/custom_ops/sdpa_with_kv_cache.py rename to extension/llm/custom_ops/custom_ops.py diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py index bfd64cb8975..9c8029c7b70 100644 --- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py +++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py @@ -11,7 +11,7 @@ import torch import torch.nn.functional as F -from .sdpa_with_kv_cache import custom_ops_lib # noqa +from .custom_ops import custom_ops_lib # noqa def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len): From 1527cdc70099cf96327bf15ba7fe923b04a372f2 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 20 Nov 2024 16:31:40 -0800 Subject: [PATCH 2/3] Update on "[Executorch][BE] Rename sdpa_with_kv_cache.py to custom_ops.py" Because now we have more than sdpa_with_kv_cache in it Differential Revision: [D66269486](https://our.internmc.facebook.com/intern/diff/D66269486/) [ghstack-poisoned] --- extension/llm/custom_ops/targets.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl index bb59f48a279..e3e8b30520f 100644 --- a/extension/llm/custom_ops/targets.bzl +++ b/extension/llm/custom_ops/targets.bzl @@ -81,7 +81,7 @@ def define_common_targets(): runtime.python_library( name = "custom_ops_aot_py", srcs = [ - "sdpa_with_kv_cache.py", + "custom_ops.py", ], visibility = [ "//executorch/...", From c8b282098764d9b38a7bd1fdad3854d598e3751e Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 21 Nov 2024 06:43:53 -0800 Subject: [PATCH 3/3] Update on "[Executorch][BE] Rename sdpa_with_kv_cache.py to custom_ops.py" Because now we have more than sdpa_with_kv_cache in it Differential Revision: [D66269486](https://our.internmc.facebook.com/intern/diff/D66269486/) [ghstack-poisoned]