Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ jobs:
name: "XLA Linux X86 GPU ONEAPI",
repo: "openxla/xla",
},
{
pool: "linux-x86-n2-16",
container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest",
name: "XLA Linux x86 GPU ROCm",
repo: "openxla/xla",
},
{
pool: "linux-x86-g2-16-l4-1gpu",
container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest",
Expand Down
16 changes: 8 additions & 8 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ bazel_dep(name = "rules_ml_toolchain")
# echo "sha256-${HASH}"
archive_override(
module_name = "rules_ml_toolchain",
integrity = "sha256-C0L2k6YMYFDYfbHgoOrrhKs/VBkfzglNhjNPrtyAfaA=",
strip_prefix = "rules_ml_toolchain-398d613aea7a4c294da49b79a6d6f3f8732bd84c",
urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/398d613aea7a4c294da49b79a6d6f3f8732bd84c.tar.gz"],
integrity = "sha256-Z+S9+Shsiy/58H5Pp6F8Nf1Qvh7QCmWKN+j1LdZN92w=",
strip_prefix = "rules_ml_toolchain-e0c44e92c4e03de3436ccd01ab7db54be7cdc9f1",
urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/e0c44e92c4e03de3436ccd01ab7db54be7cdc9f1.tar.gz"],
)

# TODO: Upstream the patch?
Expand Down Expand Up @@ -275,18 +275,18 @@ override_repo(
nvshmem_redist = use_extension("@rules_ml_toolchain//extensions:nvshmem_redist.bzl", "nvshmem_redist_ext")
use_repo(nvshmem_redist, "nvidia_nvshmem")

register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64")

register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda")

register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_sycl")

register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64")

register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64_cuda")

### Other local config repos
hipcc_configure = use_extension("@rules_ml_toolchain//extensions:hipcc_configure.bzl", "hipcc_configure_ext")

local_clang_configure = use_extension("@rules_ml_toolchain//extensions:local_clang_configure.bzl", "local_clang_configure_ext")

use_repo(local_clang_configure, "local_config_clang")

use_repo(hipcc_configure, "config_rocm_hipcc")

rocm_configure = use_extension("//third_party/extensions:rocm_configure.bzl", "rocm_configure_ext")
Expand Down
8 changes: 7 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64")

register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64_cuda")


# Initialize hermetic Python
load("//third_party/py:python_init_rules.bzl", "python_init_rules")

Expand Down Expand Up @@ -109,6 +108,13 @@ load(

cuda_configure(name = "local_config_cuda")

load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")

rocm_configure(
name = "local_config_rocm",
rocm_dist = "@config_rocm_hipcc//rocm:rocm_dist",
)

load(
"@rules_ml_toolchain//gpu/nccl:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
Expand Down
52 changes: 52 additions & 0 deletions build_tools/ci/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ class BuildType(enum.Enum):
XLA_LINUX_X86_GPU_L4_GITHUB_ACTIONS = enum.auto()
XLA_LINUX_X86_GPU_8X_H100_GITHUB_ACTIONS = enum.auto()
XLA_LINUX_X86_GPU_ONEAPI_GITHUB_ACTIONS = enum.auto()
XLA_LINUX_X86_GPU_ROCM_GITHUB_ACTIONS = enum.auto()
XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS = enum.auto()

# Presubmit builds for regression testing.
XLA_LINUX_ARM64_CPU_48_VCPU_PRESUBMIT_GITHUB_ACTIONS = enum.auto()
Expand Down Expand Up @@ -513,6 +515,56 @@ def nvidia_gpu_build_with_compute_capability(
subcommand="build",
)

# ROCm builds - hermetic LLVM
# Tag filters mirror build_tools/rocm/rocm_tag_filters.sh, plus the positive "gpu" tag.
rocm_tag_filter = (
"-no_gpu",
"-requires-gpu-intel",
"-requires-gpu-nvidia",
"-cuda-only",
"-oneapi-only",
"-requires-gpu-sm60",
"-requires-gpu-sm60-only",
"-requires-gpu-sm70",
"-requires-gpu-sm70-only",
"-requires-gpu-sm80",
"-requires-gpu-sm80-only",
"-requires-gpu-sm86",
"-requires-gpu-sm86-only",
"-requires-gpu-sm89",
"-requires-gpu-sm89-only",
"-requires-gpu-sm90",
"-requires-gpu-sm90-only",
"-skip_rocprofiler_sdk",
"-no_oss",
"-oss_excluded",
"-oss_serial",
"gpu",
)

Build(
type_=BuildType.XLA_LINUX_X86_GPU_ROCM_GITHUB_ACTIONS,
repo="openxla/xla",
configs=("warnings", "rbe_linux_cpu", "rocm_clang_hermetic"),
target_patterns=_XLA_DEFAULT_TARGET_PATTERNS,
build_tag_filters=rocm_tag_filter,
test_tag_filters=rocm_tag_filter,
options={**_DEFAULT_BAZEL_OPTIONS, "//xla/tsl:ci_build": True},
subcommand="build",
)

# ROCm builds - hermetic LLVM with local sysroot
Build(
type_=BuildType.XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS,
repo="openxla/xla",
configs=("warnings", "rbe_linux_cpu", "rocm_clang_hermetic_local_sysroot"),
target_patterns=_XLA_DEFAULT_TARGET_PATTERNS,
build_tag_filters=rocm_tag_filter,
test_tag_filters=rocm_tag_filter,
options={**_DEFAULT_BAZEL_OPTIONS, "//xla/tsl:ci_build": True},
subcommand="build",
)

Build(
type_=BuildType.XLA_LINUX_X86_CPU_128_VCPU_PRESUBMIT_GITHUB_ACTIONS,
repo="openxla/xla",
Expand Down
36 changes: 25 additions & 11 deletions tensorflow.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -278,28 +278,42 @@ common:asan --copt -g
common:asan --copt -fno-omit-frame-pointer
common:asan --linkopt -fsanitize=address

common:rocm_base --config=clang_local
common:rocm_base --copt=-Wno-gnu-offsetof-extensions
common:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain
common:rocm_base --copt=-D__HIP_PLATFORM_AMD__
common:rocm_base --define=using_rocm_hipcc=true
common:rocm_base --define=tensorflow_mkldnn_contraction_kernel=0
common:rocm_base --repo_env TF_NEED_ROCM=1
common:rocm_base --define=using_rocm_hipcc=true
common:rocm_base --define=tensorflow_mkldnn_contraction_kernel=0


common:rocm_clang_local --config=rocm_base
common:rocm_clang_local --config=clang_local
common:rocm_clang_local --crosstool_top=@local_config_rocm//crosstool:toolchain
common:rocm_clang_local --config=rocm_base
common:rocm_clang_local --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
common:rocm_clang_local --action_env=HIPCC_COMPILE_FLAGS_APPEND="--offload-compress"

common:rocm_clang_official --config=rocm_base
common:rocm_clang_official --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
common:rocm_clang_official --action_env=HIPCC_COMPILE_FLAGS_APPEND="--offload-compress"
common:rocm_clang_official --action_env=TF_ROCM_CLANG="1"
common:rocm_clang_official --linkopt="-fuse-ld=lld"
common:rocm_clang_official --host_linkopt="-fuse-ld=lld"
# ROCm with hermetic toolchain from rules_ml_toolchain
common:rocm_clang_hermetic --config=rocm_base
common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_cuda=False
common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_sycl=False
common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_hermetic_cc=True
common:rocm_clang_hermetic --repo_env=SYSROOT_DIST=linux_glibc_2_35
common:rocm_clang_hermetic --repo_env=TF_ROCM_AMDGPU_TARGETS=gfx90a,gfx942
common:rocm_clang_hermetic --strategy=CppLink=local
common:rocm_clang_hermetic --host_platform=@rules_ml_toolchain//common:linux_x86_64
common:rocm_clang_hermetic --extra_execution_platforms=@rules_ml_toolchain//common:linux_x86_64
common:rocm_clang_hermetic --platforms=@rules_ml_toolchain//common:linux_x86_64

common:rocm --config=rocm_clang_official
common:rocm --config=rocm_clang_hermetic

common:rocm_ci --config=rocm
common:rocm_ci --@local_config_rocm//rocm:rocm_path_type=hermetic

common:rocm_ci_hermetic --dynamic_mode=off
common:rocm_ci_hermetic --config=rocm_clang_official
common:rocm_ci_hermetic --repo_env="ROCM_DISTRO_VERSION=rocm_7.10.0_gfx90X"
common:rocm_ci_hermetic --config=rocm_clang_hermetic
common:rocm_ci_hermetic --repo_env="ROCM_DISTRO_VERSION=rocm_7.12.0_gfx94X"
common:rocm_ci_hermetic --@local_config_rocm//rocm:rocm_path_type=hermetic

# This config option is used for SYCL as GPU backend.
Expand Down
8 changes: 7 additions & 1 deletion third_party/extensions/rocm_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")

# Implementation function for `rocm_configure_ext`: calls `rocm_configure` to
# create the `local_config_rocm` repo, passing the
# `@config_rocm_hipcc//rocm:rocm_dist` label as `rocm_dist`.
# The `mctx` argument is not used.
def _rocm_configure_ext_impl(mctx):
rocm_configure(
name = "local_config_rocm",
rocm_dist = "@config_rocm_hipcc//rocm:rocm_dist",
)

rocm_configure_ext = module_extension(
implementation = lambda mctx: rocm_configure(name = "local_config_rocm"),
implementation = _rocm_configure_ext_impl,
)
21 changes: 15 additions & 6 deletions third_party/gpus/crosstool/BUILD.rocm.tpl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# This file is expanded from a template by cuda_configure.bzl
# Update cuda_configure.bzl#verify_build_defines when adding new variables.
# This file is expanded from a template by rocm_configure.bzl
# Update rocm_configure.bzl#verify_build_defines when adding new variables.

load(":cc_toolchain_config.bzl", "cc_toolchain_config")
load("@local_config_clang//:clang.bzl", "local_clang")

# Local clang configuration for non-hermetic toolchain
_LOCAL_CLANG = local_clang()

licenses(["restricted"])

Expand Down Expand Up @@ -62,16 +66,18 @@ cc_toolchain_config(
target_libc = "local",
abi_version = "local",
abi_libc_version = "local",
cxx_builtin_include_directories = [%{cxx_builtin_include_directories}],
host_compiler_path = "%{host_compiler_path}",
host_compiler_prefix = "%{host_compiler_prefix}",
# Include directories detected from local clang + ROCm includes
cxx_builtin_include_directories = _LOCAL_CLANG.include_directories + [%{cxx_builtin_include_directories}],
host_compiler_path = "clang/bin/crosstool_wrapper_driver_is_not_gcc",
host_compiler_prefix = "/usr/bin",
compile_flags = [
"-U_FORTIFY_SOURCE",
"-fstack-protector",
"-Wall",
"-Wunused-but-set-parameter",
"-Wno-free-nonheap-object",
"-fno-omit-frame-pointer",
"-no-canonical-prefixes",
],
opt_compile_flags = [
"-g0",
Expand All @@ -84,9 +90,10 @@ cc_toolchain_config(
dbg_compile_flags = ["-g"],
cxx_flags = ["-std=c++17"],
link_flags = [
"-fuse-ld=gold",
"-fuse-ld=lld",
"-Wl,-no-as-needed",
"-Wl,-z,relro,-z,now",
"-Wl,--allow-shlib-undefined",
],
link_libs = [
"-lstdc++",
Expand All @@ -103,6 +110,8 @@ cc_toolchain_config(
coverage_compile_flags = ["--coverage"],
coverage_link_flags = ["--coverage"],
supports_start_end_lib = True,
# Compiler path from local_clang(); the toolchain config exports it to
# compile/link actions as the HOST_COMPILER env var.
clang_compiler_path = _LOCAL_CLANG.compiler_path,
)

filegroup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ import re
import sys
import shlex

# Template values set by rocm_configure.bzl.
CPU_COMPILER = ('%{cpu_compiler}')
HOST_COMPILER_PATH = ('%{host_compiler_path}')
# Template values set by rocm_configure.bzl or environment
CPU_COMPILER = os.environ.get('HOST_COMPILER', '/usr/bin/clang')

HIPCC_PATH = '%{rocm_root}/bin/hipcc'
HIPCC_ENV = '%{hipcc_env}'
Expand Down
39 changes: 39 additions & 0 deletions third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

load(
"@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl",
"env_entry",
"env_set",
"feature",
"feature_set",
"flag_group",
Expand Down Expand Up @@ -93,6 +95,23 @@ def _impl(ctx):
enabled = True,
)

# Set environment variables from action_env dict
env_entries = []
if ctx.attr.action_env:
for key, value in ctx.attr.action_env.items():
env_entries.append(env_entry(key = key, value = value))

compiler_env_feature = feature(
name = "compiler_env",
enabled = True,
env_sets = [
env_set(
actions = all_compile_actions + all_link_actions,
env_entries = env_entries,
),
] if env_entries else [],
)

default_compile_flags_feature = feature(
name = "default_compile_flags",
enabled = True,
Expand Down Expand Up @@ -1072,6 +1091,22 @@ def _impl(ctx):
],
)

# Exports the `clang_compiler_path` attribute to all compile and link actions
# as the HOST_COMPILER environment variable. When the attribute is empty the
# feature is disabled and carries no env_sets.
clang_compiler_path_feature = feature(
name = "clang-compiler-path",
enabled = ctx.attr.clang_compiler_path != "",
env_sets = [
env_set(
actions = all_compile_actions + all_link_actions,
env_entries = [
env_entry(
key = "HOST_COMPILER",
value = ctx.attr.clang_compiler_path,
),
],
),
] if ctx.attr.clang_compiler_path else [],
)

features = [
dependency_file_feature,
random_seed_feature,
Expand Down Expand Up @@ -1100,6 +1135,8 @@ def _impl(ctx):
strip_debug_symbols_feature,
coverage_feature,
supports_pic_feature,
compiler_env_feature,
clang_compiler_path_feature,
] + (
[
supports_start_end_lib_feature,
Expand Down Expand Up @@ -1164,6 +1201,8 @@ cc_toolchain_config = rule(
"host_compiler_path": attr.string(),
"host_compiler_prefix": attr.string(),
"linker_bin_path": attr.string(),
"action_env": attr.string_dict(),
"clang_compiler_path": attr.string(),
},
provides = [CcToolchainConfigInfo],
)
Expand Down
13 changes: 12 additions & 1 deletion third_party/gpus/rocm/BUILD.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_lib_import")
load("@config_rocm_hipcc//rocm:build_defs.bzl", "hipcc_config")
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_lib_import", "rocm_version_number")

licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like

Expand Down Expand Up @@ -65,6 +66,7 @@ cc_library(

cc_library(
name = "rocm_config",
defines = ["TENSORFLOW_USE_ROCM=1"],
visibility = ["//visibility:public"],
deps = select({
":build_hermetic": [
Expand Down Expand Up @@ -110,6 +112,9 @@ cc_library(
# These must live in a cc_library (not a toolchain feature) because
# cc_library linkopts propagate transitively through CcInfo to the
# final linking target, whereas toolchain features do not.
# Get lib_paths from hipcc_config for multiple ROCm paths support
_ROCM_LIB_PATHS = hipcc_config().lib_paths

cc_library(
name = "rocm_rpath",
linkopts = select({
Expand All @@ -130,6 +135,12 @@ cc_library(
visibility = ["//visibility:public"],
)

filegroup(
name = "amdgcn_bitcode",
srcs = glob(["%{rocm_root}/amdgcn/bitcode/*.bc"]),
visibility = ["//visibility:public"],
)

alias(
name = "hip",
actual = ":hip_runtime",
Expand Down
Loading
Loading