From 5c05f891d4489a62305575b887501f1a4c4f55db Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 5 Dec 2024 08:24:59 -0800 Subject: [PATCH] [Executorch][custom ops] Change lib loading logic to account for package dir Pull Request resolved: https://github.com/pytorch/executorch/pull/7038 Just looking at the location of the source file. In this case custom_ops.py, can, and does, yield to wrong location depending on where you import custom_ops from. If you are importing custom_ops from another source file inside extension folder, e.g. builder.py that is in extensions/llm/export, then, I think, custom_ops gets resolved to the one installed in site-packages or pip package. But if this is imported from say examples/models/llama/source_transformations/quantized_kv_cache.py (Like in the in next PR), then it seems to resolve to the source location. In one of the CI this is /pytorch/executorch. Now depending on which directory your filepath resolves to, you will search for lib in that. This of course does not work when filepath resolves to source location. This PR changes that to resolve to package location. ghstack-source-id: 256711930 //unit-test-arm broken in trunk @bypass-github-export-checks @exported-using-ghexport Differential Revision: [D66385480](https://our.internmc.facebook.com/intern/diff/D66385480/) --- extension/llm/custom_ops/custom_ops.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 26dac551a30..3570e34d192 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -11,7 +11,6 @@ # pyre-unsafe import logging -from pathlib import Path import torch @@ -23,7 +22,17 @@ op2 = torch.ops.llama.fast_hadamard_transform.default assert op2 is not None except: - libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*")) + import glob + + import executorch + + executorch_package_path = executorch.__path__[0] + logging.info(f"Looking for libcustom_ops_aot_lib.so in {executorch_package_path}") + libs = list( + glob.glob( + f"{executorch_package_path}/**/libcustom_ops_aot_lib.*", recursive=True + ) + ) assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" logging.info(f"Loading custom ops library: {libs[0]}") torch.ops.load_library(libs[0])