From 434f6adec8aa0803ec9a2415560292d1e7488a45 Mon Sep 17 00:00:00 2001 From: sudhu2k Date: Fri, 15 May 2026 20:36:47 +0000 Subject: [PATCH 1/2] fix(loader): load TE core with RTLD_LOCAL to stop rocroller symbol leak Switch libtransformer_engine.so from RTLD_GLOBAL to RTLD_LOCAL and link the torch/jax extensions against it explicitly via DT_NEEDED. This prevents librocroller.so symbols from interposing with HIP and fixes `free(): invalid size` in hipModuleLoad when TE is imported before MORI's shmem init on ROCm. --- build_tools/build_ext.py | 10 ++++++++++ build_tools/jax.py | 26 +++++++++++++++++++++++++- build_tools/pytorch.py | 21 +++++++++++++++++++++ transformer_engine/common/__init__.py | 12 ++++++++++-- 4 files changed, 66 insertions(+), 3 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 31bb4392d..7072e8c8a 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -115,6 +115,7 @@ class _CMakeBuildExtension(extension_cls): def run(self) -> None: # Build CMake extensions + cmake_install_dirs: List[Path] = [] for ext in self.extensions: package_path = Path(self.get_ext_fullpath(ext.name)) install_dir = package_path.resolve().parent @@ -135,12 +136,21 @@ def run(self) -> None: build_dir=build_dir, install_dir=install_dir, ) + cmake_install_dirs.append(install_dir) # Build non-CMake extensions as usual all_extensions = self.extensions self.extensions = [ ext for ext in self.extensions if not isinstance(ext, CMakeExtension) ] + # Make CMake-installed shared libraries (e.g. libtransformer_engine.so) + # discoverable when linking framework extensions that declare them as + # NEEDED dependencies. 
+ for ext in self.extensions: + for cmake_install_dir in cmake_install_dirs: + cmake_install_dir_str = str(cmake_install_dir) + if cmake_install_dir_str not in ext.library_dirs: + ext.library_dirs.append(cmake_install_dir_str) super().run() self.extensions = all_extensions diff --git a/build_tools/jax.py b/build_tools/jax.py index 62dc4336e..d7aa2abb6 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -123,6 +123,29 @@ def setup_jax_extension( if rocm_build(): cxx_flags.extend(["-D__HIP_PLATFORM_AMD__", "-DUSE_ROCM"]) + # Link against the TE core library so the jax extension resolves core + # symbols via the ELF NEEDED graph rather than via RTLD_GLOBAL. This + # avoids transitively exposing librocroller.so symbols, which interpose + # with HIP's internal helpers and cause hipModuleLoad to abort with + # `free(): invalid size`. + # + # The CMake build of the core library runs before this extension is + # linked (see `_CMakeBuildExtension.run` in build_ext.py), and the + # CMake install directory is injected into `library_dirs` there so the + # linker can find `libtransformer_engine.so` even on a clean build. We + # additionally try to resolve a previously-built copy here so that + # incremental builds and tooling that links this extension in isolation + # still work. 
+ libraries = ["nccl"] if not rocm_build() else [] + libraries.append("transformer_engine") + library_dirs: List[str] = [] + try: + from transformer_engine.common import _get_shared_object_file + core_lib_path = Path(_get_shared_object_file("core")) + library_dirs.append(str(core_lib_path.parent)) + except (ImportError, FileNotFoundError): + pass + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension @@ -131,5 +154,6 @@ def setup_jax_extension( sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, - libraries=["nccl"] if not rocm_build() else [], + libraries=libraries, + library_dirs=library_dirs, ) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index c797ffad7..c4036e7fb 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -115,6 +115,27 @@ def setup_pytorch_extension( libraries.append("mpi") cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"]) + # Link against the TE core library so the torch extension resolves core + # symbols via the ELF NEEDED graph rather than via RTLD_GLOBAL. This + # avoids transitively exposing librocroller.so symbols, which interpose + # with HIP's internal helpers and cause hipModuleLoad to abort with + # `free(): invalid size`. + # + # The CMake build of the core library runs before this extension is + # linked (see `_CMakeBuildExtension.run` in build_ext.py), and the + # CMake install directory is injected into `library_dirs` there so the + # linker can find `libtransformer_engine.so` even on a clean build. We + # additionally try to resolve a previously-built copy here so that + # incremental builds and tooling that links this extension in isolation + # still work. 
@functools.lru_cache(maxsize=None)
def _load_core_library():
    """Open the Transformer Engine core shared library and return its handle.

    The library is opened with ``RTLD_LOCAL`` rather than ``RTLD_GLOBAL`` so
    that its transitive dependencies (notably ``librocroller.so`` on ROCm)
    are not promoted into the global symbol namespace. There they would
    interpose with HIP runtime helpers and make ``hipModuleLoad`` abort with
    ``free(): invalid size``. Because the symbols are no longer globally
    visible, framework extensions link against this library explicitly via
    ELF ``NEEDED`` entries.

    The result is memoized by ``functools.lru_cache``, so repeated calls
    return the same ``ctypes.CDLL`` object.
    """
    core_library_path = _get_shared_object_file("core")
    core_library = ctypes.CDLL(core_library_path, mode=ctypes.RTLD_LOCAL)
    return core_library
The previous import eagerly loaded framework extensions and asserted they were installed, which broke builds of the framework extension itself (e.g. `transformer_engine_rocm_torch.tar.gz`). --- build_tools/build_ext.py | 2 +- build_tools/jax.py | 18 ++++-------------- build_tools/pytorch.py | 18 ++++-------------- build_tools/utils.py | 27 +++++++++++++++++++++++++++ 4 files changed, 36 insertions(+), 29 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 7072e8c8a..6cb7bdc98 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -145,7 +145,7 @@ def run(self) -> None: ] # Make CMake-installed shared libraries (e.g. libtransformer_engine.so) # discoverable when linking framework extensions that declare them as - # NEEDED dependencies. + # NEEDED dependencies on a fresh full-tree build. for ext in self.extensions: for cmake_install_dir in cmake_install_dirs: cmake_install_dir_str = str(cmake_install_dir) diff --git a/build_tools/jax.py b/build_tools/jax.py index d7aa2abb6..9144e4b44 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -12,6 +12,7 @@ from .utils import rocm_build, rocm_path from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled +from .utils import installed_te_core_lib_dir from typing import List @@ -128,23 +129,12 @@ def setup_jax_extension( # avoids transitively exposing librocroller.so symbols, which interpose # with HIP's internal helpers and cause hipModuleLoad to abort with # `free(): invalid size`. - # - # The CMake build of the core library runs before this extension is - # linked (see `_CMakeBuildExtension.run` in build_ext.py), and the - # CMake install directory is injected into `library_dirs` there so the - # linker can find `libtransformer_engine.so` even on a clean build. We - # additionally try to resolve a previously-built copy here so that - # incremental builds and tooling that links this extension in isolation - # still work. 
libraries = ["nccl"] if not rocm_build() else [] libraries.append("transformer_engine") library_dirs: List[str] = [] - try: - from transformer_engine.common import _get_shared_object_file - core_lib_path = Path(_get_shared_object_file("core")) - library_dirs.append(str(core_lib_path.parent)) - except (ImportError, FileNotFoundError): - pass + core_lib_dir = installed_te_core_lib_dir() + if core_lib_dir is not None: + library_dirs.append(str(core_lib_dir)) # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index c4036e7fb..69063cee3 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -17,6 +17,7 @@ cuda_version, get_cuda_include_dirs, debug_build_enabled, + installed_te_core_lib_dir, ) from typing import List @@ -120,21 +121,10 @@ def setup_pytorch_extension( # avoids transitively exposing librocroller.so symbols, which interpose # with HIP's internal helpers and cause hipModuleLoad to abort with # `free(): invalid size`. - # - # The CMake build of the core library runs before this extension is - # linked (see `_CMakeBuildExtension.run` in build_ext.py), and the - # CMake install directory is injected into `library_dirs` there so the - # linker can find `libtransformer_engine.so` even on a clean build. We - # additionally try to resolve a previously-built copy here so that - # incremental builds and tooling that links this extension in isolation - # still work. 
def installed_te_core_lib_dir(search_roots: Optional[List[Path]] = None) -> Optional[Path]:
    """Locate the directory holding an already-installed libtransformer_engine.so.

    Framework extensions declare an explicit DT_NEEDED link against the core
    library, so the linker needs a directory to find it in. This helper is
    used at build time when the core library was installed separately from
    the framework extension (e.g. when building the
    ``transformer_engine_*_torch`` sdist against a pre-installed
    ``transformer_engine_*`` core package).

    Importing ``transformer_engine.common`` is intentionally avoided here
    because that module eagerly loads framework extensions, which may not
    exist yet during this very build.

    Args:
        search_roots: Site-packages style directories to search. When omitted,
            the interpreter's ``purelib`` and ``platlib`` directories are
            used, in that order.

    Returns:
        The first directory containing a file matching
        ``libtransformer_engine*.so*``, or ``None`` if no root has one.
    """
    if search_roots is None:
        import sysconfig

        install_paths = sysconfig.get_paths()
        # A shared object is platform-specific, so installers place it under
        # ``platlib``; that is a different directory from ``purelib`` on some
        # layouts (e.g. ``lib`` vs ``lib64``). Search both.
        search_roots = [Path(install_paths[key]) for key in ("purelib", "platlib")]

    visited = set()
    for root in search_roots:
        root = Path(root)
        if root in visited:
            # purelib and platlib are usually the same directory.
            continue
        visited.add(root)

        package_dir = root / "transformer_engine"
        for candidate in (package_dir, package_dir / "wheel_lib"):
            if candidate.is_dir() and any(candidate.glob("libtransformer_engine*.so*")):
                return candidate
    return None