Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions build_tools/build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class _CMakeBuildExtension(extension_cls):

def run(self) -> None:
# Build CMake extensions
cmake_install_dirs: List[Path] = []
for ext in self.extensions:
package_path = Path(self.get_ext_fullpath(ext.name))
install_dir = package_path.resolve().parent
Expand All @@ -135,12 +136,21 @@ def run(self) -> None:
build_dir=build_dir,
install_dir=install_dir,
)
cmake_install_dirs.append(install_dir)

# Build non-CMake extensions as usual
all_extensions = self.extensions
self.extensions = [
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
]
# Make CMake-installed shared libraries (e.g. libtransformer_engine.so)
# discoverable when linking framework extensions that declare them as
# NEEDED dependencies on a fresh full-tree build.
for ext in self.extensions:
for cmake_install_dir in cmake_install_dirs:
cmake_install_dir_str = str(cmake_install_dir)
if cmake_install_dir_str not in ext.library_dirs:
ext.library_dirs.append(cmake_install_dir_str)
super().run()
self.extensions = all_extensions

Expand Down
16 changes: 15 additions & 1 deletion build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .utils import rocm_build, rocm_path
from .utils import all_files_in_dir, get_cuda_include_dirs, debug_build_enabled
from .utils import installed_te_core_lib_dir
from typing import List


Expand Down Expand Up @@ -123,6 +124,18 @@ def setup_jax_extension(
if rocm_build():
cxx_flags.extend(["-D__HIP_PLATFORM_AMD__", "-DUSE_ROCM"])

# Link against the TE core library so the jax extension resolves core
# symbols via the ELF NEEDED graph rather than via RTLD_GLOBAL. This
# avoids transitively exposing librocroller.so symbols, which interpose
# with HIP's internal helpers and cause hipModuleLoad to abort with
# `free(): invalid size`.
libraries = ["nccl"] if not rocm_build() else []
libraries.append("transformer_engine")
library_dirs: List[str] = []
core_lib_dir = installed_te_core_lib_dir()
if core_lib_dir is not None:
library_dirs.append(str(core_lib_dir))

# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

Expand All @@ -131,5 +144,6 @@ def setup_jax_extension(
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
libraries=["nccl"] if not rocm_build() else [],
libraries=libraries,
library_dirs=library_dirs,
)
11 changes: 11 additions & 0 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
cuda_version,
get_cuda_include_dirs,
debug_build_enabled,
installed_te_core_lib_dir,
)
from typing import List

Expand Down Expand Up @@ -115,6 +116,16 @@ def setup_pytorch_extension(
libraries.append("mpi")
cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"])

# Link against the TE core library so the torch extension resolves core
# symbols via the ELF NEEDED graph rather than via RTLD_GLOBAL. This
# avoids transitively exposing librocroller.so symbols, which interpose
# with HIP's internal helpers and cause hipModuleLoad to abort with
# `free(): invalid size`.
libraries.append("transformer_engine")
core_lib_dir = installed_te_core_lib_dir()
if core_lib_dir is not None:
library_dirs.append(str(core_lib_dir))

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
Expand Down
27 changes: 27 additions & 0 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,33 @@ def remove_dups(_list: List):
return list(set(_list))


def installed_te_core_lib_dir(
    search_roots: Optional[List[Path]] = None,
) -> Optional[Path]:
    """Locate an already-installed libtransformer_engine.so.

    Searches Python's site-packages for the core library so that
    framework extensions can declare an explicit DT_NEEDED link against
    it. Used at build time when the core library is being installed
    separately from the framework extension (e.g. when building the
    ``transformer_engine_*_torch`` sdist against a pre-installed
    ``transformer_engine_*`` core package).

    Importing ``transformer_engine.common`` is intentionally avoided
    here because that module eagerly loads framework extensions, which
    may not exist yet during this very build.

    Parameters
    ----------
    search_roots:
        Directories to look in for a ``transformer_engine`` package.
        When ``None`` (the default), the interpreter's ``purelib`` and
        ``platlib`` installation directories are searched, in that
        order.

    Returns
    -------
    The first directory that contains a file matching
    ``libtransformer_engine*.so*``, or ``None`` if no search root holds
    one. Within each root, ``transformer_engine`` is checked before
    ``transformer_engine/wheel_lib``.
    """
    if search_roots is None:
        import sysconfig

        install_paths = sysconfig.get_paths()
        # A package that ships a shared library is platform-specific and is
        # installed under ``platlib``. That is usually the same directory as
        # ``purelib``, but not always (e.g. lib vs. lib64 layouts), so both
        # are searched. ``purelib`` stays first to keep existing precedence.
        search_roots = [Path(install_paths[key]) for key in ("purelib", "platlib")]

    visited = set()
    for root in search_roots:
        root = Path(root)
        if root in visited:
            # purelib and platlib are commonly identical; scan each root once.
            continue
        visited.add(root)

        package_dir = root / "transformer_engine"
        for candidate in (package_dir, package_dir / "wheel_lib"):
            if candidate.is_dir() and any(candidate.glob("libtransformer_engine*.so*")):
                return candidate
    return None


def found_cmake() -> bool:
""" "Check if valid CMake is available

Expand Down
12 changes: 10 additions & 2 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,16 @@ def is_fp8_fnuz():

@functools.lru_cache(maxsize=None)
def _load_core_library():
    """Load shared library with Transformer Engine C extensions.

    Loaded with RTLD_LOCAL so that transitive dependencies (notably
    librocroller.so on ROCm) are not promoted into the global symbol
    namespace, where they would interpose with HIP runtime helpers and
    cause hipModuleLoad to abort with `free(): invalid size`. Framework
    extensions therefore link against this library explicitly via ELF
    NEEDED rather than relying on global symbol visibility.

    The result is memoized by ``functools.lru_cache``, so repeated calls
    return the same ``ctypes.CDLL`` handle instead of loading again.
    """
    return ctypes.CDLL(_get_shared_object_file("core"), mode=ctypes.RTLD_LOCAL)


if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
Expand Down
Loading