diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 31bb4392d..6cb7bdc98 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -115,6 +115,7 @@ class _CMakeBuildExtension(extension_cls): def run(self) -> None: # Build CMake extensions + cmake_install_dirs: List[Path] = [] for ext in self.extensions: package_path = Path(self.get_ext_fullpath(ext.name)) install_dir = package_path.resolve().parent @@ -135,12 +136,21 @@ def run(self) -> None: build_dir=build_dir, install_dir=install_dir, ) + cmake_install_dirs.append(install_dir) # Build non-CMake extensions as usual all_extensions = self.extensions self.extensions = [ ext for ext in self.extensions if not isinstance(ext, CMakeExtension) ] + # Make CMake-installed shared libraries (e.g. libtransformer_engine.so) + # discoverable when linking framework extensions that declare them as + # NEEDED dependencies on a fresh full-tree build. + for ext in self.extensions: + for cmake_install_dir in cmake_install_dirs: + cmake_install_dir_str = str(cmake_install_dir) + if cmake_install_dir_str not in ext.library_dirs: + ext.library_dirs.append(cmake_install_dir_str) super().run() self.extensions = all_extensions diff --git a/build_tools/jax.py b/build_tools/jax.py index 62dc4336e..9144e4b44 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -12,6 +12,7 @@ from .utils import rocm_build, rocm_path from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled +from .utils import installed_te_core_lib_dir from typing import List @@ -123,6 +124,18 @@ def setup_jax_extension( if rocm_build(): cxx_flags.extend(["-D__HIP_PLATFORM_AMD__", "-DUSE_ROCM"]) + # Link against the TE core library so the jax extension resolves core + # symbols via the ELF NEEDED graph rather than via RTLD_GLOBAL. This + # avoids transitively exposing librocroller.so symbols, which interpose + # with HIP's internal helpers and cause hipModuleLoad to abort with + # `free(): invalid size`. + libraries = ["nccl"] if not rocm_build() else [] + libraries.append("transformer_engine") + library_dirs: List[str] = [] + core_lib_dir = installed_te_core_lib_dir() + if core_lib_dir is not None: + library_dirs.append(str(core_lib_dir)) + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension @@ -131,5 +144,6 @@ def setup_jax_extension( sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, - libraries=["nccl"] if not rocm_build() else [], + libraries=libraries, + library_dirs=library_dirs, ) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index c797ffad7..69063cee3 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -17,6 +17,7 @@ cuda_version, get_cuda_include_dirs, debug_build_enabled, + installed_te_core_lib_dir, ) from typing import List @@ -115,6 +116,16 @@ def setup_pytorch_extension( libraries.append("mpi") cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"]) + # Link against the TE core library so the torch extension resolves core + # symbols via the ELF NEEDED graph rather than via RTLD_GLOBAL. This + # avoids transitively exposing librocroller.so symbols, which interpose + # with HIP's internal helpers and cause hipModuleLoad to abort with + # `free(): invalid size`. + libraries.append("transformer_engine") + core_lib_dir = installed_te_core_lib_dir() + if core_lib_dir is not None: + library_dirs.append(str(core_lib_dir)) + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] diff --git a/build_tools/utils.py b/build_tools/utils.py index e250238e6..b6241194a 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -82,6 +82,33 @@ def remove_dups(_list: List): return list(set(_list)) +def installed_te_core_lib_dir() -> Optional[Path]: + """Locate an already-installed libtransformer_engine.so. + + Searches Python's site-packages for the core library so that + framework extensions can declare an explicit DT_NEEDED link against + it. Used at build time when the core library is being installed + separately from the framework extension (e.g. when building the + ``transformer_engine_*_torch`` sdist against a pre-installed + ``transformer_engine_*`` core package). + + Importing ``transformer_engine.common`` is intentionally avoided + here because that module eagerly loads framework extensions, which + may not exist yet during this very build. + """ + import sysconfig + + purelib = Path(sysconfig.get_paths()["purelib"]) + candidate_dirs = ( + purelib / "transformer_engine", + purelib / "transformer_engine" / "wheel_lib", + ) + for candidate in candidate_dirs: + if candidate.is_dir() and any(candidate.glob("libtransformer_engine*.so*")): + return candidate + return None + + def found_cmake() -> bool: """ "Check if valid CMake is available diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 9dbf998e5..0438fcc40 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -409,8 +409,16 @@ def is_fp8_fnuz(): @functools.lru_cache(maxsize=None) def _load_core_library(): - """Load shared library with Transformer Engine C extensions""" - return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_GLOBAL) + """Load shared library with Transformer Engine C extensions. + + Loaded with RTLD_LOCAL so that transitive dependencies (notably + librocroller.so on ROCm) are not promoted into the global symbol + namespace, where they would interpose with HIP runtime helpers and + cause hipModuleLoad to abort with `free(): invalid size`. Framework + extensions therefore link against this library explicitly via ELF + NEEDED rather than relying on global symbol visibility. + """ + return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_LOCAL) if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):