From 42764edd99eae1c6f1f8e0c98318e37f6d829469 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 22 Nov 2024 12:59:14 -0800 Subject: [PATCH] [Executorch][custom ops] Change lib loading logic to account for package dir Just looking at the location of the source file. In this case custom_ops.py, can, and does, yield to wrong location depending on where you import custom_ops from. If you are importing custom_ops from another source file inside extension folder, e.g. builder.py that is in extensions/llm/export, then, I think, custom_ops gets resolved to the one installed in site-packages or pip package. But if this is imported from say examples/models/llama/source_transformations/quantized_kv_cache.py (Like in the in next PR), then it seems to resolve to the source location. In one of the CI this is /pytorch/executorch. Now depending on which directory your filepath resolves to, you will search for lib in that. This of course does not work when filepath resolves to source location. This PR changes that to resolve to package location. Differential Revision: [D66385480](https://our.internmc.facebook.com/intern/diff/D66385480/) [ghstack-poisoned] --- extension/llm/custom_ops/custom_ops.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 26dac551a30..c739793cc26 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -23,7 +23,17 @@ op2 = torch.ops.llama.fast_hadamard_transform.default assert op2 is not None except: - libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*")) + import glob + + import executorch + + executorch_package_path = executorch.__path__[0] + logging.info(f"Looking for libcustom_ops_aot_lib.so in {executorch_package_path }") + libs = list( + glob.glob( + f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True + ) + ) assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" logging.info(f"Loading custom ops library: {libs[0]}") torch.ops.load_library(libs[0])