diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 26645856d27cf..57546d16bfa4c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,6 +66,12 @@ jobs: name: "XLA Linux X86 GPU ONEAPI", repo: "openxla/xla", }, + { + pool: "linux-x86-n2-16", + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest", + name: "XLA Linux x86 GPU ROCm", + repo: "openxla/xla", + }, { pool: "linux-x86-g2-16-l4-1gpu", container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest", diff --git a/MODULE.bazel b/MODULE.bazel index 1d200a52ae8ed..344be7248ae6a 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -46,9 +46,9 @@ bazel_dep(name = "rules_ml_toolchain") # echo "sha256-${HASH}" archive_override( module_name = "rules_ml_toolchain", - integrity = "sha256-C0L2k6YMYFDYfbHgoOrrhKs/VBkfzglNhjNPrtyAfaA=", - strip_prefix = "rules_ml_toolchain-398d613aea7a4c294da49b79a6d6f3f8732bd84c", - urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/398d613aea7a4c294da49b79a6d6f3f8732bd84c.tar.gz"], + integrity = "sha256-Z+S9+Shsiy/58H5Pp6F8Nf1Qvh7QCmWKN+j1LdZN92w=", + strip_prefix = "rules_ml_toolchain-e0c44e92c4e03de3436ccd01ab7db54be7cdc9f1", + urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/e0c44e92c4e03de3436ccd01ab7db54be7cdc9f1.tar.gz"], ) # TODO: Upstream the patch? @@ -275,18 +275,18 @@ override_repo( nvshmem_redist = use_extension("@rules_ml_toolchain//extensions:nvshmem_redist.bzl", "nvshmem_redist_ext") use_repo(nvshmem_redist, "nvidia_nvshmem") -register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64") - register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda") register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_sycl") register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64") -register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64_cuda") - -### Other local config repos hipcc_configure = use_extension("@rules_ml_toolchain//extensions:hipcc_configure.bzl", "hipcc_configure_ext") + +local_clang_configure = use_extension("@rules_ml_toolchain//extensions:local_clang_configure.bzl", "local_clang_configure_ext") + +use_repo(local_clang_configure, "local_config_clang") + use_repo(hipcc_configure, "config_rocm_hipcc") rocm_configure = use_extension("//third_party/extensions:rocm_configure.bzl", "rocm_configure_ext") diff --git a/WORKSPACE b/WORKSPACE index 56d1677d159f4..08c1df2a86591 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -31,7 +31,6 @@ register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64") register_toolchains("@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64_cuda") - # Initialize hermetic Python load("//third_party/py:python_init_rules.bzl", "python_init_rules") @@ -109,6 +108,13 @@ load( cuda_configure(name = "local_config_cuda") +load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") + +rocm_configure( + name = "local_config_rocm", + rocm_dist = "@config_rocm_hipcc//rocm:rocm_dist", +) + load( "@rules_ml_toolchain//gpu/nccl:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", diff --git a/build_tools/ci/build.py b/build_tools/ci/build.py index ad44ce14bfd92..e0058e2106f0e 100755 --- a/build_tools/ci/build.py +++ b/build_tools/ci/build.py @@ -110,6 +110,8 @@ class BuildType(enum.Enum): XLA_LINUX_X86_GPU_L4_GITHUB_ACTIONS = enum.auto() XLA_LINUX_X86_GPU_8X_H100_GITHUB_ACTIONS = enum.auto() XLA_LINUX_X86_GPU_ONEAPI_GITHUB_ACTIONS = enum.auto() + XLA_LINUX_X86_GPU_ROCM_GITHUB_ACTIONS = enum.auto() + XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS = enum.auto() # Presubmit builds for regression testing. XLA_LINUX_ARM64_CPU_48_VCPU_PRESUBMIT_GITHUB_ACTIONS = enum.auto() @@ -513,6 +515,56 @@ def nvidia_gpu_build_with_compute_capability( subcommand="build", ) +# ROCm builds - hermetic LLVM +# Tag filters from build_tools/rocm/rocm_tag_filters.sh + gpu +rocm_tag_filter = ( + "-no_gpu", + "-requires-gpu-intel", + "-requires-gpu-nvidia", + "-cuda-only", + "-oneapi-only", + "-requires-gpu-sm60", + "-requires-gpu-sm60-only", + "-requires-gpu-sm70", + "-requires-gpu-sm70-only", + "-requires-gpu-sm80", + "-requires-gpu-sm80-only", + "-requires-gpu-sm86", + "-requires-gpu-sm86-only", + "-requires-gpu-sm89", + "-requires-gpu-sm89-only", + "-requires-gpu-sm90", + "-requires-gpu-sm90-only", + "-skip_rocprofiler_sdk", + "-no_oss", + "-oss_excluded", + "-oss_serial", + "gpu", +) + +Build( + type_=BuildType.XLA_LINUX_X86_GPU_ROCM_GITHUB_ACTIONS, + repo="openxla/xla", + configs=("warnings", "rbe_linux_cpu", "rocm_clang_hermetic"), + target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, + build_tag_filters=rocm_tag_filter, + test_tag_filters=rocm_tag_filter, + options={**_DEFAULT_BAZEL_OPTIONS, "//xla/tsl:ci_build": True}, + subcommand="build", +) + +# ROCm builds - hermetic LLVM with local sysroot +Build( + type_=BuildType.XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS, + repo="openxla/xla", + configs=("warnings", "rbe_linux_cpu", "rocm_clang_hermetic_local_sysroot"), + target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, + build_tag_filters=rocm_tag_filter, + test_tag_filters=rocm_tag_filter, + options={**_DEFAULT_BAZEL_OPTIONS, "//xla/tsl:ci_build": True}, + subcommand="build", +) + Build( type_=BuildType.XLA_LINUX_X86_CPU_128_VCPU_PRESUBMIT_GITHUB_ACTIONS, repo="openxla/xla", diff --git a/tensorflow.bazelrc b/tensorflow.bazelrc index de3272783ad01..630b8b4f4304d 100644 --- a/tensorflow.bazelrc +++ b/tensorflow.bazelrc @@ -278,28 +278,42 @@ common:asan --copt -g common:asan --copt -fno-omit-frame-pointer common:asan --linkopt -fsanitize=address -common:rocm_base --config=clang_local common:rocm_base --copt=-Wno-gnu-offsetof-extensions -common:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain +common:rocm_base --copt=-D__HIP_PLATFORM_AMD__ common:rocm_base --define=using_rocm_hipcc=true common:rocm_base --define=tensorflow_mkldnn_contraction_kernel=0 common:rocm_base --repo_env TF_NEED_ROCM=1 +common:rocm_base --define=using_rocm_hipcc=true +common:rocm_base --define=tensorflow_mkldnn_contraction_kernel=0 + + +common:rocm_clang_local --config=rocm_base +common:rocm_clang_local --config=clang_local +common:rocm_clang_local --crosstool_top=@local_config_rocm//crosstool:toolchain +common:rocm_clang_local --config=rocm_base +common:rocm_clang_local --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" +common:rocm_clang_local --action_env=HIPCC_COMPILE_FLAGS_APPEND="--offload-compress" -common:rocm_clang_official --config=rocm_base -common:rocm_clang_official --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -common:rocm_clang_official --action_env=HIPCC_COMPILE_FLAGS_APPEND="--offload-compress" -common:rocm_clang_official --action_env=TF_ROCM_CLANG="1" -common:rocm_clang_official --linkopt="-fuse-ld=lld" -common:rocm_clang_official --host_linkopt="-fuse-ld=lld" +# ROCm with hermetic toolchain from rules_ml_toolchain +common:rocm_clang_hermetic --config=rocm_base +common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_cuda=False +common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_sycl=False +common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_hermetic_cc=True +common:rocm_clang_hermetic --repo_env=SYSROOT_DIST=linux_glibc_2_35 +common:rocm_clang_hermetic --repo_env=TF_ROCM_AMDGPU_TARGETS=gfx90a,gfx942 +common:rocm_clang_hermetic --strategy=CppLink=local +common:rocm_clang_hermetic --host_platform=@rules_ml_toolchain//common:linux_x86_64 +common:rocm_clang_hermetic --extra_execution_platforms=@rules_ml_toolchain//common:linux_x86_64 +common:rocm_clang_hermetic --platforms=@rules_ml_toolchain//common:linux_x86_64 -common:rocm --config=rocm_clang_official +common:rocm --config=rocm_clang_hermetic common:rocm_ci --config=rocm common:rocm_ci --@local_config_rocm//rocm:rocm_path_type=hermetic common:rocm_ci_hermetic --dynamic_mode=off -common:rocm_ci_hermetic --config=rocm_clang_official -common:rocm_ci_hermetic --repo_env="ROCM_DISTRO_VERSION=rocm_7.10.0_gfx90X" +common:rocm_ci_hermetic --config=rocm_clang_hermetic +common:rocm_ci_hermetic --repo_env="ROCM_DISTRO_VERSION=rocm_7.12.0_gfx94X" common:rocm_ci_hermetic --@local_config_rocm//rocm:rocm_path_type=hermetic # This config option is used for SYCL as GPU backend. diff --git a/third_party/extensions/rocm_configure.bzl b/third_party/extensions/rocm_configure.bzl index cf755c8bc307f..857eccc96a3f8 100644 --- a/third_party/extensions/rocm_configure.bzl +++ b/third_party/extensions/rocm_configure.bzl @@ -2,6 +2,12 @@ load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") +def _rocm_configure_ext_impl(mctx): + rocm_configure( + name = "local_config_rocm", + rocm_dist = "@config_rocm_hipcc//rocm:rocm_dist", + ) + rocm_configure_ext = module_extension( - implementation = lambda mctx: rocm_configure(name = "local_config_rocm"), + implementation = _rocm_configure_ext_impl, ) diff --git a/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/gpus/crosstool/BUILD.rocm.tpl index 392b29f3c8b54..23120be051c47 100644 --- a/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -1,7 +1,11 @@ -# This file is expanded from a template by cuda_configure.bzl -# Update cuda_configure.bzl#verify_build_defines when adding new variables. +# This file is expanded from a template by rocm_configure.bzl +# Update rocm_configure.bzl#verify_build_defines when adding new variables. load(":cc_toolchain_config.bzl", "cc_toolchain_config") +load("@local_config_clang//:clang.bzl", "local_clang") + +# Local clang configuration for non-hermetic toolchain +_LOCAL_CLANG = local_clang() licenses(["restricted"]) @@ -62,9 +66,10 @@ cc_toolchain_config( target_libc = "local", abi_version = "local", abi_libc_version = "local", - cxx_builtin_include_directories = [%{cxx_builtin_include_directories}], - host_compiler_path = "%{host_compiler_path}", - host_compiler_prefix = "%{host_compiler_prefix}", + # Include directories detected from local clang + ROCm includes + cxx_builtin_include_directories = _LOCAL_CLANG.include_directories + [%{cxx_builtin_include_directories}], + host_compiler_path = "clang/bin/crosstool_wrapper_driver_is_not_gcc", + host_compiler_prefix = "/usr/bin", compile_flags = [ "-U_FORTIFY_SOURCE", "-fstack-protector", @@ -72,6 +77,7 @@ cc_toolchain_config( "-Wunused-but-set-parameter", "-Wno-free-nonheap-object", "-fno-omit-frame-pointer", + "-no-canonical-prefixes", ], opt_compile_flags = [ "-g0", @@ -84,9 +90,10 @@ cc_toolchain_config( dbg_compile_flags = ["-g"], cxx_flags = ["-std=c++17"], link_flags = [ - "-fuse-ld=gold", + "-fuse-ld=lld", "-Wl,-no-as-needed", "-Wl,-z,relro,-z,now", + "-Wl,--allow-shlib-undefined", ], link_libs = [ "-lstdc++", @@ -103,6 +110,8 @@ cc_toolchain_config( coverage_compile_flags = ["--coverage"], coverage_link_flags = ["--coverage"], supports_start_end_lib = True, + # Compiler path from local_clang_info(), sets CLANG_COMPILER_PATH env var + clang_compiler_path = _LOCAL_CLANG.compiler_path, ) filegroup( diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 9a1fa830d14c8..8f0535a640b8e 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -22,9 +22,8 @@ import re import sys import shlex -# Template values set by rocm_configure.bzl. -CPU_COMPILER = ('%{cpu_compiler}') -HOST_COMPILER_PATH = ('%{host_compiler_path}') +# Template values set by rocm_configure.bzl or environment +CPU_COMPILER = os.environ.get('HOST_COMPILER', '/usr/bin/clang') HIPCC_PATH = '%{rocm_root}/bin/hipcc' HIPCC_ENV = '%{hipcc_env}' diff --git a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl index 26c6bcdff5cf9..37fb69a78d8f2 100644 --- a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl +++ b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl @@ -2,6 +2,8 @@ load( "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "env_entry", + "env_set", "feature", "feature_set", "flag_group", @@ -93,6 +95,23 @@ def _impl(ctx): enabled = True, ) + # Set environment variables from action_env dict + env_entries = [] + if ctx.attr.action_env: + for key, value in ctx.attr.action_env.items(): + env_entries.append(env_entry(key = key, value = value)) + + compiler_env_feature = feature( + name = "compiler_env", + enabled = True, + env_sets = [ + env_set( + actions = all_compile_actions + all_link_actions, + env_entries = env_entries, + ), + ] if env_entries else [], + ) + default_compile_flags_feature = feature( name = "default_compile_flags", enabled = True, @@ -1072,6 +1091,22 @@ def _impl(ctx): ], ) + clang_compiler_path_feature = feature( + name = "clang-compiler-path", + enabled = ctx.attr.clang_compiler_path != "", + env_sets = [ + env_set( + actions = all_compile_actions + all_link_actions, + env_entries = [ + env_entry( + key = "HOST_COMPILER", + value = ctx.attr.clang_compiler_path, + ), + ], + ), + ] if ctx.attr.clang_compiler_path else [], + ) + features = [ dependency_file_feature, random_seed_feature, @@ -1100,6 +1135,8 @@ def _impl(ctx): strip_debug_symbols_feature, coverage_feature, supports_pic_feature, + compiler_env_feature, + clang_compiler_path_feature, ] + ( [ supports_start_end_lib_feature, @@ -1164,6 +1201,8 @@ cc_toolchain_config = rule( "host_compiler_path": attr.string(), "host_compiler_prefix": attr.string(), "linker_bin_path": attr.string(), + "action_env": attr.string_dict(), + "clang_compiler_path": attr.string(), }, provides = [CcToolchainConfigInfo], ) diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index 56f726ee9e5c1..a5981c655bb32 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -1,6 +1,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") -load("@local_config_rocm//rocm:build_defs.bzl", "rocm_lib_import") +load("@config_rocm_hipcc//rocm:build_defs.bzl", "hipcc_config") +load("@local_config_rocm//rocm:build_defs.bzl", "rocm_lib_import", "rocm_version_number") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like @@ -65,6 +66,7 @@ cc_library( cc_library( name = "rocm_config", + defines = ["TENSORFLOW_USE_ROCM=1"], visibility = ["//visibility:public"], deps = select({ ":build_hermetic": [ @@ -110,6 +112,9 @@ cc_library( # These must live in a cc_library (not a toolchain feature) because # cc_library linkopts propagate transitively through CcInfo to the # final linking target, whereas toolchain features do not. +# Get lib_paths from hipcc_config for multiple ROCm paths support +_ROCM_LIB_PATHS = hipcc_config().lib_paths + cc_library( name = "rocm_rpath", linkopts = select({ @@ -130,6 +135,12 @@ cc_library( visibility = ["//visibility:public"], ) +filegroup( + name = "amdgcn_bitcode", + srcs = glob(["%{rocm_root}/amdgcn/bitcode/*.bc"]), + visibility = ["//visibility:public"], +) + alias( name = "hip", actual = ":hip_runtime", diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index 3cd2a18e9203a..0858477078e6e 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -2,6 +2,9 @@ load("@rules_cc//cc:cc_import.bzl", "cc_import") load("@rules_cc//cc:cc_library.bzl", "cc_library") # Macros for building ROCm code. +# rocm_library is loaded and wrapped below +load("@rules_ml_toolchain//cc/rocm:rocm_library.bzl", _rocm_library_impl = "rocm_library") + def if_rocm(if_true, if_false = []): """Shorthand for select()'ing on whether we're building with ROCm. @@ -14,24 +17,6 @@ def if_rocm(if_true, if_false = []): "//conditions:default": if_false, }) -def rocm_default_copts(): - """Default options for all ROCm compilations.""" - return if_rocm(["-x", "rocm"] + %{rocm_extra_copts}) - -def rocm_copts(opts = []): - """Gets the appropriate set of copts for (maybe) ROCm compilation. - - If we're doing ROCm compilation, returns copts for our particular ROCm - compiler. If we're not doing ROCm compilation, returns an empty list. - - """ - return rocm_default_copts() + select({ - "//conditions:default": [], - "@local_config_rocm//rocm:using_hipcc": ([ - "", - ]), - }) + if_rocm_is_configured(opts) - def rocm_gpu_architectures(): """Returns a list of supported GPU architectures.""" return %{rocm_gpu_architectures} @@ -68,19 +53,22 @@ def is_rocm_configured(): """ return %{rocm_is_configured} -def rocm_hipblaslt(): - return %{rocm_is_configured} and %{rocm_hipblaslt} +# rocm_library is now defined in @rules_ml_toolchain//cc/rocm:rocm_library.bzl +# It's loaded at the top with alias _rocm_library_impl and wrapped below +def rocm_library(name, srcs = [], hdrs = [], copts = [], deps = [], **kwargs): + """Wrapper for rocm_library that adds local_config_rocm headers.""" + if "@local_config_rocm//rocm:rocm_headers" not in deps: + deps = deps + ["@local_config_rocm//rocm:rocm_headers"] -def if_rocm_hipblaslt(x): - if %{rocm_is_configured} and (%{rocm_hipblaslt} == "True"): - return select({"//conditions:default": x}) - return select({"//conditions:default": []}) + _rocm_library_impl( + name = name, + srcs = srcs, + hdrs = hdrs, + copts = copts, + deps = deps, + **kwargs + ) -def rocm_library(copts = [], deps = [], **kwargs): - """Wrapper over cc_library which adds default ROCm options.""" - if "@local_config_rocm//rocm:rocm_headers" not in deps: - deps.append("@local_config_rocm//rocm:rocm_headers") - cc_library(copts = rocm_default_copts() + copts, deps = deps, **kwargs) def get_rbe_amdgpu_pool(is_single_gpu = False): return "%{single_gpu_rbe_pool}" if is_single_gpu else "%{multi_gpu_rbe_pool}" diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 404bf533877ad..fe364120252e0 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -7,19 +7,12 @@ * `TF_SYSROOT`: The sysroot to use when compiling. * `CLANG_COMPILER_PATH`: The clang compiler path that will be used for host code compilation if TF_ROCM_CLANG is 1. - * `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`. * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. * `TF_ROCM_RBE_DOCKER_IMAGE`: Docker image to be used in rbe worker to execute the action * `TF_ROCM_RBE_SINGLE_GPU_POOL`: The name of the rbe pool used to execute single gpu tests * `TF_ROCM_RBE_MULTI_GPU_POOL`: The name of the rbe pool used to execute multi gpu tests """ -load("@bazel_skylib//lib:paths.bzl", "paths") -load( - "//third_party/gpus/rocm:rocm_redist.bzl", - "create_rocm_distro", - "rocm_redist", -) load( "//third_party/remote_config:common.bzl", "config_repo_label", @@ -31,12 +24,9 @@ load( "get_host_environ", "get_python_bin", "realpath", - "relative_to", - "which", ) load( ":compiler_common_tools.bzl", - "get_cxx_inc_directories", "to_list_of_strings", ) load( @@ -48,20 +38,11 @@ load( "enable_sycl", ) -_CLANG_COMPILER_PATH = "CLANG_COMPILER_PATH" -_TF_SYSROOT = "TF_SYSROOT" -_ROCM_TOOLKIT_PATH = "ROCM_PATH" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" _DISTRIBUTION_PATH = "rocm/rocm_dist" -_ROCM_DISTRO_VERSION = "ROCM_DISTRO_VERSION" -_ROCM_DISTRO_URL = "ROCM_DISTRO_URL" -_ROCM_DISTRO_HASH = "ROCM_DISTRO_HASH" -_ROCM_DISTRO_LINKS = "ROCM_DISTRO_LINKS" _TMPDIR = "TMPDIR" -_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" -_TF_ROCM_MULTIPLE_PATHS = "TF_ROCM_MULTIPLE_PATHS" _TF_ROCM_RBE_DOCKER_IMAGE = "TF_ROCM_RBE_DOCKER_IMAGE" _TF_ROCM_RBE_POOL = "TF_ROCM_RBE_POOL" _TF_ROCM_RBE_SINGLE_GPU_POOL = "TF_ROCM_RBE_SINGLE_GPU_POOL" @@ -72,7 +53,18 @@ _DEFAULT_TF_ROCM_RBE_MULTI_GPU_POOL = "linux_x64_multigpu" # rocm/tensorflow-build:latest-jammy-python3.11-rocm7.0.2 _DEFAULT_TF_ROCM_RBE_DOCKER_IMAGE = "rocm/tensorflow-build@sha256:a2672ff2510b369b4a5f034272a518dc93c2e492894e3befaeef19649632ccaa" -_LLVM_PATH = "LLVM_PATH" + +def auto_configure_fail(msg): + """Output failure message when rocm configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg)) + +def auto_configure_warning(msg): + """Output warning message during auto configuration.""" + yellow = "\033[1;33m" + no_color = "\033[0m" + print("\n%sAuto-Configuration Warning:%s %s\n" % (yellow, no_color, msg)) def verify_build_defines(params): """Verify all variables that crosstool/BUILD.rocm.tpl expects are substituted. @@ -82,10 +74,7 @@ def verify_build_defines(params): """ missing = [] for param in [ - "cxx_builtin_include_directories", - "extra_no_canonical_prefixes_flags", "host_compiler_path", - "host_compiler_prefix", "linker_bin_path", "unfiltered_compile_flags", ]: @@ -94,78 +83,51 @@ def verify_build_defines(params): if missing: auto_configure_fail( - "BUILD.rocm.tpl template is missing these variables: " + - str(missing) + - ".\nWe only got: " + - str(params) + - ".", + "Missing template parameters: %s" % missing, ) -def find_cc(repository_ctx): - """Find the C++ compiler.""" - - target_cc_name = "clang" - cc_name = target_cc_name - - cc_name_from_env = get_host_environ(repository_ctx, _CLANG_COMPILER_PATH) - if cc_name_from_env: - cc_name = cc_name_from_env - if cc_name.startswith("/"): - # Absolute path, maybe we should make this supported by our which function. - return cc_name - cc = which(repository_ctx, cc_name) - if cc == None: - fail(("Cannot find {}, either correct your path or set the {}" + - " environment variable").format(target_cc_name, _CLANG_COMPILER_PATH)) - return cc - -def auto_configure_fail(msg): - """Output failure message when rocm configuration fails.""" - red = "\033[0;31m" - no_color = "\033[0m" - fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg)) - -def auto_configure_warning(msg): - """Output warning message during auto configuration.""" - yellow = "\033[1;33m" - no_color = "\033[0m" - print("\n%sAuto-Configuration Warning:%s %s\n" % (yellow, no_color, msg)) - -# END cc_configure common functions (see TODO above). - def _rocm_include_path(repository_ctx, rocm_config, bash_bin): """Generates the entries for rocm inc dirs based on rocm_config. Args: repository_ctx: The repository context. - rocm_config: The path to the gcc host compiler. + rocm_config: The ROCm config struct. bash_bin: path to the bash interpreter. Returns: - A string containing the Starlark string for each of the hipcc - compiler include directories, which can be added to the CROSSTOOL - file. + A list of the ROCm compiler include directories. """ inc_dirs = [] # Add HIP-Clang headers (relative to rocm root) - rocm_path = repository_ctx.path(rocm_config.rocm_toolkit_path) - clang_path = rocm_path.get_child("llvm/bin/clang") - resource_dir_result = execute(repository_ctx, [str(clang_path), "-print-resource-dir"]) + inc_dirs.append(str(repository_ctx.path(rocm_config.rocm_toolkit_path)) + "/include") + inc_dirs.append(str(repository_ctx.path(rocm_config.rocm_toolkit_path)) + "/lib/llvm/lib/clang/18/include") - if resource_dir_result.return_code: - auto_configure_fail("Failed to run hipcc -print-resource-dir: %s" % err_out(resource_dir_result)) - - resource_dir_abs = resource_dir_result.stdout.strip() + return inc_dirs - resource_dir_rel = relative_to(repository_ctx, str(rocm_path.realpath), resource_dir_abs, bash_bin) +def _hipcc_env(repository_ctx): + """Returns the environment variable string for hipcc. - resource_dir = str(rocm_path.get_child(resource_dir_rel)) + Args: + repository_ctx: The repository context. - inc_dirs.append(resource_dir + "/include") - inc_dirs.append(resource_dir + "/share") + Returns: + A string containing environment variables for hipcc. + """ + hipcc_env = "" + for name in [ + "HIP_CLANG_PATH", + "DEVICE_LIB_PATH", + "HIP_VDI_HOME", + "HIPCC_VERBOSE", + "HIPCC_COMPILE_FLAGS_APPEND", + "HIPCC_LINK_FLAGS_APPEND", + ]: + if get_host_environ(repository_ctx, name): + hipcc_env = (hipcc_env + " " + name + "=" + + get_host_environ(repository_ctx, name)) - return inc_dirs + return hipcc_env.strip() def _enable_rocm(repository_ctx): enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM") @@ -192,42 +154,6 @@ def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin): auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target) return amdgpu_targets -def _hipcc_env(repository_ctx): - """Returns the environment variable string for hipcc. - - Args: - repository_ctx: The repository context. - - Returns: - A string containing environment variables for hipcc. - """ - hipcc_env = "" - for name in [ - "HIP_CLANG_PATH", - "DEVICE_LIB_PATH", - "HIP_VDI_HOME", - "HIPCC_VERBOSE", - "HIPCC_COMPILE_FLAGS_APPEND", - "HIPPCC_LINK_FLAGS_APPEND", - "HCC_AMDGPU_TARGET", - "HIP_PLATFORM", - ]: - env_value = get_host_environ(repository_ctx, name) - if env_value: - hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";") - return hipcc_env.strip() - -def _crosstool_verbose(repository_ctx): - """Returns the environment variable value CROSSTOOL_VERBOSE. - - Args: - repository_ctx: The repository context. - - Returns: - A string containing value of environment variable CROSSTOOL_VERBOSE. - """ - return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0") - def _lib_name(lib, version = "", static = False): """Constructs the name of a library on Linux. @@ -451,7 +377,7 @@ def _create_dummy_repository(repository_ctx): repository_ctx, "rocm:rocm_config.h", { - "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, + "%{rocm_toolkit_path}": "/opt/rocm", "%{hipblaslt_flag}": "0", }, "rocm/rocm_config/rocm_config.h", @@ -473,108 +399,31 @@ def _norm_path(path): path = path[:-1] return path -def _flag_enabled(repository_ctx, flag_name): - return get_host_environ(repository_ctx, flag_name) == "1" - -def _use_rocm_clang(repository_ctx): - # Returns the flag if we need to use clang for the host. - return _flag_enabled(repository_ctx, "TF_ROCM_CLANG") - -def _tf_sysroot(repository_ctx): - return get_host_environ(repository_ctx, _TF_SYSROOT, "") - def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): amdgpu_target_flags = ["--offload-arch=" + amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) -def _canonical_path(p): - parts = [x for x in p.split("/") if x != ""] - return paths.join(*parts) - -def _get_file_name(url): - last_slash_index = url.rfind("/") - return url[last_slash_index + 1:] - -def _download_package(repository_ctx, pkg): - file_name = _get_file_name(pkg["url"]) - - print("Downloading {}".format(pkg["url"])) - repository_ctx.report_progress("Downloading and extracting {}, expected hash is {}".format(pkg["url"], pkg["sha256"])) # buildifier: disable=print - repository_ctx.download_and_extract( - url = pkg["url"], - output = _DISTRIBUTION_PATH, - sha256 = pkg["sha256"], - type = "zip" if pkg["url"].endswith(".whl") else "", - ) - - if pkg.get("sub_package", None): - repository_ctx.report_progress("Extracting {}".format(pkg["sub_package"])) # buildifier: disable=print - repository_ctx.extract( - archive = "{}/{}".format(_DISTRIBUTION_PATH, pkg["sub_package"]), - output = _DISTRIBUTION_PATH, - ) - - repository_ctx.delete(file_name) - def _remove_root_dir(path, root_dir): if path.startswith(root_dir + "/"): return path[len(root_dir) + 1:] return path -def _setup_rocm_distro_dir_impl(repository_ctx, rocm_distro): - repository_ctx.file("rocm/.index") - for pkg in rocm_distro.packages: - _download_package(repository_ctx, pkg) - - for entry in rocm_distro.required_softlinks: - repository_ctx.symlink( - "{}/{}".format(_DISTRIBUTION_PATH, entry.target), - "{}/{}".format(_DISTRIBUTION_PATH, entry.link), - ) - bash_bin = get_bash_bin(repository_ctx) - return _get_rocm_config(repository_ctx, bash_bin, _canonical_path("{}/{}".format(_DISTRIBUTION_PATH, rocm_distro.rocm_root)), "") - def _setup_rocm_distro_dir(repository_ctx): """Sets up the rocm hermetic installation directory to be used in hermetic build""" bash_bin = get_bash_bin(repository_ctx) - rocm_distro_url = repository_ctx.os.environ.get(_ROCM_DISTRO_URL) - if rocm_distro_url: - rocm_distro_hash = repository_ctx.os.environ.get(_ROCM_DISTRO_HASH) - if not rocm_distro_hash: - fail("{} environment variable is required".format(_ROCM_DISTRO_HASH)) - rocm_distro_links = repository_ctx.os.environ.get(_ROCM_DISTRO_LINKS, "") - rocm_distro = create_rocm_distro(rocm_distro_url, rocm_distro_hash, rocm_distro_links) - return _setup_rocm_distro_dir_impl(repository_ctx, rocm_distro) - - rocm_distro = repository_ctx.os.environ.get(_ROCM_DISTRO_VERSION) - if rocm_distro: - return _setup_rocm_distro_dir_impl(repository_ctx, rocm_redist[rocm_distro]) - - multiple_paths = repository_ctx.os.environ.get(_TF_ROCM_MULTIPLE_PATHS) - if multiple_paths: - paths_list = multiple_paths.split(":") - for rocm_custom_path in paths_list: - cmd = "find " + rocm_custom_path + "/* \\( -type f -o -type l \\)" - result = execute(repository_ctx, [bash_bin, "-c", cmd]).stdout.strip().split("\n") - for file_path in result: - relative_path = file_path[len(rocm_custom_path):] - symlink_path = _DISTRIBUTION_PATH + relative_path - if files_exist(repository_ctx, [symlink_path], bash_bin)[0]: - fail("File already present: " + relative_path) - else: - repository_ctx.symlink(file_path, symlink_path) - llvm_path = repository_ctx.os.environ.get(_LLVM_PATH) - if llvm_path: - repository_ctx.symlink(llvm_path, _DISTRIBUTION_PATH + "/llvm") - repository_ctx.symlink(llvm_path, _DISTRIBUTION_PATH + "/lib/llvm") - repository_ctx.symlink(llvm_path + "/amdgcn", _DISTRIBUTION_PATH + "/amdgcn") - return _get_rocm_config(repository_ctx, bash_bin, _DISTRIBUTION_PATH, _DISTRIBUTION_PATH) - else: - rocm_path = repository_ctx.os.environ.get(_ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) - repository_ctx.report_progress("Using local rocm installation {}".format(rocm_path)) # buildifier: disable=print - repository_ctx.symlink(rocm_path, _DISTRIBUTION_PATH) - return _get_rocm_config(repository_ctx, bash_bin, _DISTRIBUTION_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + + # Use ROCm dist directory from hipcc_configure repository + rocm_dist_label = repository_ctx.attr.rocm_dist + if not rocm_dist_label: + fail("rocm_dist attribute is required. " + + "Set it to @config_rocm_hipcc//rocm:rocm_dist") + + # Directly get the path to rocm_dist directory (exported via exports_files) + hipcc_rocm_path = repository_ctx.path(rocm_dist_label) + repository_ctx.report_progress("Using ROCm from: {}".format(hipcc_rocm_path)) + repository_ctx.symlink(hipcc_rocm_path, _DISTRIBUTION_PATH) + return _get_rocm_config(repository_ctx, bash_bin, _DISTRIBUTION_PATH, str(hipcc_rocm_path)) def _create_local_rocm_repository(repository_ctx): """Creates the repository containing files set up to build with ROCm.""" @@ -634,19 +483,6 @@ def _create_local_rocm_repository(repository_ctx): "%{rocm_repo_name}": repository_ctx.name, } - is_rocm_clang = _use_rocm_clang(repository_ctx) - tf_sysroot = _tf_sysroot(repository_ctx) - - multiple_paths = repository_ctx.os.environ.get(_TF_ROCM_MULTIPLE_PATHS) - if multiple_paths: - paths_list = multiple_paths.split(":") - rocm_lib_paths = [] - for rocm_custom_path in paths_list: - lib_path = rocm_custom_path + "/lib/" - if files_exist(repository_ctx, [lib_path], bash_bin)[0] and not lib_path in rocm_lib_paths: - rocm_lib_paths.append(lib_path) - repository_dict["%{rocm_lib_paths}"] = ":".join(rocm_lib_paths) - repository_ctx.template( "rocm/BUILD", tpl_paths["rocm:BUILD"], @@ -654,27 +490,13 @@ def _create_local_rocm_repository(repository_ctx): ) # Set up crosstool/ - cc = find_cc(repository_ctx) + rocm_defines = {} + rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/usr/bin" - host_compiler_includes = get_cxx_inc_directories( - repository_ctx, - cc, - tf_sysroot, + rocm_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( + _rocm_include_path(repository_ctx, rocm_config, bash_bin), ) - # host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) - - rocm_defines = {} - rocm_defines["%{builtin_sysroot}"] = tf_sysroot - rocm_defines["%{compiler}"] = "clang" - host_compiler_prefix = "/usr/bin" - rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix - rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + host_compiler_prefix - rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "" - rocm_defines["%{unfiltered_compile_flags}"] = "" - rocm_defines["%{rocm_hipcc_files}"] = "[]" - rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-no-canonical-prefixes\"" - rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", "-D__HIP_PLATFORM_AMD__", @@ -682,12 +504,9 @@ def _create_local_rocm_repository(repository_ctx): "-DUSE_ROCM", ]) + # Use wrapper as the host compiler path for the toolchain rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" - rocm_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( - host_compiler_includes + _rocm_include_path(repository_ctx, rocm_config, bash_bin), - ) - verify_build_defines(rocm_defines) # Only expand template variables in the BUILD file @@ -709,17 +528,10 @@ def _create_local_rocm_repository(repository_ctx): "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"], { - "%{cpu_compiler}": str(cc), - "%{compiler_is_clang}": "True", - "%{rocm_root}": "external/" + repository_ctx.name + "/" + str(rocm_config.rocm_toolkit_path), + "%{rocm_root}": rocm_toolkit_path, "%{hipcc_env}": _hipcc_env(repository_ctx), "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_library}": "amdhip64", - "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), - "%{gcc_host_compiler_path}": str(cc), - "%{rocm_amdgpu_targets}": ",".join( - ["\"%s\"" % c for c in rocm_config.amdgpu_targets], - ), + "%{crosstool_verbose}": "0", "%{tmpdir}": get_host_environ( repository_ctx, _TMPDIR, @@ -822,19 +634,12 @@ def _rocm_autoconf_impl(repository_ctx): _ENVIRONS = [ "TF_NEED_ROCM", - "TF_ROCM_CLANG", "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro - _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, - _ROCM_DISTRO_VERSION, - _ROCM_DISTRO_URL, - _ROCM_DISTRO_HASH, - _ROCM_DISTRO_LINKS, _TF_ROCM_RBE_DOCKER_IMAGE, _TF_ROCM_RBE_POOL, _TF_ROCM_RBE_SINGLE_GPU_POOL, _TF_ROCM_RBE_MULTI_GPU_POOL, - _TF_ROCM_MULTIPLE_PATHS, ] remote_rocm_configure = repository_rule( @@ -843,6 +648,10 @@ remote_rocm_configure = repository_rule( remotable = True, attrs = { "environ": attr.string_dict(), + "rocm_dist": attr.label( + doc = "Label to the rocm_dist directory from hipcc_configure " + + "(e.g. @config_rocm_hipcc//rocm:rocm_dist).", + ), "_find_rocm_config": attr.label( default = Label("//third_party/gpus:find_rocm_config.py"), ), @@ -853,6 +662,10 @@ rocm_configure = repository_rule( implementation = _rocm_autoconf_impl, environ = _ENVIRONS + [_TF_ROCM_CONFIG_REPO], attrs = { + "rocm_dist": attr.label( + doc = "Label to the rocm_dist directory from hipcc_configure " + + "(e.g. @config_rocm_hipcc//rocm:rocm_dist).", + ), "_find_rocm_config": attr.label( default = Label("//third_party/gpus:find_rocm_config.py"), ), diff --git a/workspace2.bzl b/workspace2.bzl index fdf77d1345104..ffc9d18b54e27 100644 --- a/workspace2.bzl +++ b/workspace2.bzl @@ -5,10 +5,12 @@ load("@bazel_skylib//lib:versions.bzl", "versions") load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") +load("@rules_ml_toolchain//cc/llvms/local:local_clang_configure.bzl", "local_clang_configure") load("@rules_ml_toolchain//gpu/rocm:hipcc_configure.bzl", "hipcc_configure") load("@rules_ml_toolchain//gpu/sycl:sycl_configure.bzl", "sycl_configure") load("@rules_ml_toolchain//gpu/sycl:sycl_init_repository.bzl", "sycl_init_repository") load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") +load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/absl:workspace.bzl", absl = "repo") load("//third_party/benchmark:workspace.bzl", benchmark = "repo") load("//third_party/brotli:workspace.bzl", brotli = "repo") @@ -26,7 +28,6 @@ load("//third_party/FP16:workspace.bzl", FP16 = "repo") load("//third_party/fxdiv:workspace.bzl", fxdiv = "repo") load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo") load("//third_party/gloo:workspace.bzl", gloo = "repo") -load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/gutil:workspace.bzl", gutil = "repo") load("//third_party/highway:workspace.bzl", highway = "repo") load("//third_party/highwayhash:workspace.bzl", highwayhash = "repo") @@ -140,7 +141,11 @@ def _tf_toolchains(): tensorrt_configure(name = "local_config_tensorrt") python_configure(name = "local_config_python") hipcc_configure(name = "config_rocm_hipcc") # Must be before rocm_configure. - rocm_configure(name = "local_config_rocm") + rocm_configure( + name = "local_config_rocm", + rocm_dist = "@config_rocm_hipcc//rocm:rocm_dist", + ) + local_clang_configure(name = "local_config_clang") sycl_init_repository() sycl_configure(name = "local_config_sycl") remote_execution_configure(name = "local_config_remote_execution") diff --git a/workspace3.bzl b/workspace3.bzl index c8baec092b1ba..3a1f80aff2dbb 100644 --- a/workspace3.bzl +++ b/workspace3.bzl @@ -50,10 +50,10 @@ def workspace(): # Details: https://github.com/google-ml-infra/rules_ml_toolchain tf_http_archive( name = "rules_ml_toolchain", - sha256 = "0b42f693a60c6050d87db1e0a0eaeb84ab3f54191fce094d86334faedc807da0", - strip_prefix = "rules_ml_toolchain-398d613aea7a4c294da49b79a6d6f3f8732bd84c", + sha256 = "67e4bdf9286c8b2ff9f07e4fa7a17c35fd50be1ed00a658a37e8f52dd64df76c", + strip_prefix = "rules_ml_toolchain-e0c44e92c4e03de3436ccd01ab7db54be7cdc9f1", urls = tf_mirror_urls( - "https://github.com/google-ml-infra/rules_ml_toolchain/archive/398d613aea7a4c294da49b79a6d6f3f8732bd84c.tar.gz", + "https://github.com/google-ml-infra/rules_ml_toolchain/archive/e0c44e92c4e03de3436ccd01ab7db54be7cdc9f1.tar.gz", ), ) diff --git a/xla/service/gpu/llvm_gpu_backend/BUILD b/xla/service/gpu/llvm_gpu_backend/BUILD index 1edcbe50aaf9b..ee3fd7d401737 100644 --- a/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/xla/service/gpu/llvm_gpu_backend/BUILD @@ -296,6 +296,7 @@ xla_cc_test( srcs = ["amdgpu_bitcode_link_test.cc"], data = [ "tests_data/amdgpu.ll", + "@local_config_rocm//rocm:amdgcn_bitcode", ], tags = if_google([ # Embedded libdevice is required for this test, but not supported in the Google-internal build. diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index b85a27ecd027b..f647d6d0630eb 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -954,10 +954,8 @@ cc_library( rocm_library( name = "buffer_comparator_kernel_rocm", - srcs = [ - "buffer_comparator_kernel_rocm.cu.cc", - "//xla/stream_executor/gpu:buffer_comparator_kernel_lib.cu.h", - ], + srcs = ["buffer_comparator_kernel_rocm.cu.cc"], + hdrs = ["//xla/stream_executor/gpu:buffer_comparator_kernel_lib.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", @@ -999,10 +997,8 @@ rocm_library( rocm_library( name = "ragged_all_to_all_kernel_rocm", - srcs = [ - "ragged_all_to_all_kernel_rocm.cc", - "//xla/stream_executor/gpu:ragged_all_to_all_kernel_lib.cu.h", - ], + srcs = ["ragged_all_to_all_kernel_rocm.cc"], + hdrs = ["//xla/stream_executor/gpu:ragged_all_to_all_kernel_lib.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", @@ -1085,10 +1081,8 @@ cc_library( # CUB scan kernel implementation (hipCUB/rocPRIM) - compiled with hipcc rocm_library( name = "cub_scan_kernel_rocm_impl", - srcs = [ - "cub_scan_kernel_rocm.h", - "cub_scan_kernel_rocm_impl.cu.cc", - ], + srcs = ["cub_scan_kernel_rocm_impl.cu.cc"], + hdrs = ["cub_scan_kernel_rocm.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", @@ -1159,9 +1153,9 @@ rocm_library( name = "topk_kernel_rocm", srcs = [ "topk_kernel_rocm_bfloat16.cu.cc", - "topk_kernel_rocm_common.cu.h", "topk_kernel_rocm_float.cu.cc", ], + hdrs = ["topk_kernel_rocm_common.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", @@ -1180,10 +1174,8 @@ rocm_library( rocm_library( name = "repeat_buffer_kernel_rocm", - srcs = [ - "repeat_buffer_kernel_rocm.cc", - "//xla/stream_executor/gpu:repeat_buffer_kernel.cu.h", - ], + srcs = ["repeat_buffer_kernel_rocm.cc"], + hdrs = ["//xla/stream_executor/gpu:repeat_buffer_kernel.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", @@ -1200,10 +1192,8 @@ rocm_library( rocm_library( name = "redzone_allocator_kernel_rocm", - srcs = [ - "redzone_allocator_kernel_rocm.cu.cc", - "//xla/stream_executor/gpu:redzone_allocator_kernel_lib.cu.h", - ], + srcs = ["redzone_allocator_kernel_rocm.cu.cc"], + hdrs = ["//xla/stream_executor/gpu:redzone_allocator_kernel_lib.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", @@ -1221,10 +1211,8 @@ rocm_library( rocm_library( name = "gpu_test_kernels_rocm", testonly = 1, - srcs = [ - "gpu_test_kernels_rocm.cu.cc", - "//xla/stream_executor/gpu:gpu_test_kernels_lib.cu.h", - ], + srcs = ["gpu_test_kernels_rocm.cu.cc"], + hdrs = ["//xla/stream_executor/gpu:gpu_test_kernels_lib.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], linkstatic = True, tags = [ @@ -1260,10 +1248,8 @@ rocm_library( rocm_library( name = "multi_gpu_barrier_kernel_rocm", - srcs = [ - "multi_gpu_barrier_kernel_rocm.cu.cc", - "//xla/stream_executor/gpu:multi_gpu_barrier_kernel.cu.h", - ], + srcs = ["multi_gpu_barrier_kernel_rocm.cu.cc"], + hdrs = ["//xla/stream_executor/gpu:multi_gpu_barrier_kernel.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", @@ -1283,10 +1269,8 @@ rocm_library( rocm_library( name = "all_reduce_kernel_rocm", - srcs = [ - "all_reduce_kernel_rocm.cc", - "//xla/stream_executor/gpu:all_reduce_kernel_lib.cu.h", - ], + srcs = ["all_reduce_kernel_rocm.cc"], + hdrs = ["//xla/stream_executor/gpu:all_reduce_kernel_lib.cu.h"], # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], tags = [ "gpu", diff --git a/xla/stream_executor/rocm/multi_gpu_barrier_kernel_rocm.cu.cc b/xla/stream_executor/rocm/multi_gpu_barrier_kernel_rocm.cu.cc index 474438cb6f8ba..46f22b3b96f98 100644 --- a/xla/stream_executor/rocm/multi_gpu_barrier_kernel_rocm.cu.cc +++ b/xla/stream_executor/rocm/multi_gpu_barrier_kernel_rocm.cu.cc @@ -16,12 +16,14 @@ limitations under the License. #include #include "absl/base/casts.h" +// clang-format off +#include "xla/stream_executor/rocm/collective_signal_rocm.cu.h" // IWYU pragma: keep +// clang-format on #include "xla/stream_executor/gpu/collective_signal.cu.h" #include "xla/stream_executor/gpu/gpu_kernel_registry.h" #include "xla/stream_executor/gpu/multi_gpu_barrier_kernel.cu.h" #include "xla/stream_executor/gpu/multi_gpu_barrier_kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/rocm/collective_signal_rocm.cu.h" // IWYU pragma: keep #include "xla/stream_executor/rocm/rocm_platform_id.h" namespace stream_executor::gpu {