
Cannot build XLA with ROCM in the latest versions #58

Closed · costaraphael opened this issue Oct 18, 2023 · 51 comments

@costaraphael

Hey folks!

First of all thanks for the great work supporting Elixir's ML ecosystem! <3

I'm trying to set up a demo of GPU distribution at my company, but whenever I try to compile xla 0.5.1 with XLA_BUILD=true and XLA_TARGET=rocm I get the following error:

==> xla
Compiling 2 files (.ex)
Generated xla app
rm -f /root/.cache/xla_extension/xla-b938cfdf2d4e9a5f69c494a316e92638c1a119ef/xla/extension && \
        ln -s "/livebook/test_app/deps/xla/extension" /root/.cache/xla_extension/xla-b938cfdf2d4e9a5f69c494a316e92638c1a119ef/xla/extension && \
        cd /root/.cache/xla_extension/xla-b938cfdf2d4e9a5f69c494a316e92638c1a119ef && \
        bazel build --define "framework_shared_object=false" -c opt   --config=rocm --action_env=HIP_PLATFORM=hcc //xla/extension:xla_extension && \
        mkdir -p /root/.cache/xla/0.5.1/cache/build/ && \
        cp -f /root/.cache/xla_extension/xla-b938cfdf2d4e9a5f69c494a316e92638c1a119ef/bazel-bin/xla/extension/xla_extension.tar.gz /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz
ERROR: Config value 'rocm' is not defined in any .rc file
make: *** [Makefile:26: /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz] Error 2
could not compile dependency :xla, "mix compile" failed. Errors may have been logged above. You can recompile this dependency with "mix deps.compile xla --force", update it with "mix deps.update xla" or clean it with "mix deps.clean xla"
==> test_app
** (Mix) Could not compile with "make" (exit status: 2).
You need to have gcc and make installed. If you are using
Ubuntu or any other Debian-based system, install the packages
"build-essential". Also install "erlang-dev" package if not
included in your Erlang/OTP version. If you're on Fedora, run
"dnf group install 'Development Tools'".

I did some digging, and I think it is because the openxla revision the Makefile points to doesn't have the rocm configuration.

The rocm config was apparently only added later, in this commit: openxla/xla@98b6197.
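A quick way to check whether a given checkout defines that config is to grep its .bazelrc, which is where the build:* definitions in the logs come from (path per the cache layout in the output above):

grep -n "build:rocm" /root/.cache/xla_extension/xla-b938cfdf2d4e9a5f69c494a316e92638c1a119ef/.bazelrc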

I tried forcing XLA to download the latest passing build of openxla (openxla/xla@dba73eb) by editing the Makefile locally, but it failed with the following log:

Starting local Bazel server and connecting to it...
INFO: Reading 'startup' options from /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --windows_enable_symlinks
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'build' from /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'build' from /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc:
  'build' options: --define framework_shared_object=true --define tsl_protobuf_header_only=true --define=use_fast_cpp_protos=true --define=allow_oversize_protos=true --spawn_strategy=standalone -c opt --announce_rc --define=grpc_no_ares=true --noincompatible_remove_legacy_whole_archive --features=-force_no_whole_archive --enable_platform_specific_config --define=with_xla_support=true --config=short_logs --config=v2 --define=no_aws_support=true --define=no_hdfs_support=true --experimental_cc_shared_library --experimental_link_static_libraries_once=false --incompatible_enforce_config_setting_visibility
INFO: Found applicable config definition build:short_logs in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:v2 in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
INFO: Found applicable config definition build:rocm in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --crosstool_top=@local_config_rocm//crosstool:toolchain --define=using_rocm_hipcc=true --define=tensorflow_mkldnn_contraction_kernel=0 --repo_env TF_NEED_ROCM=1 --config=no_tfrt
INFO: Found applicable config definition build:no_tfrt in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
INFO: Found applicable config definition build:linux in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --host_copt=-w --copt=-Wno-all --copt=-Wno-extra --copt=-Wno-deprecated --copt=-Wno-deprecated-declarations --copt=-Wno-ignored-attributes --copt=-Wno-array-bounds --copt=-Wunused-result --copt=-Werror=unused-result --copt=-Wswitch --copt=-Werror=switch --copt=-Wno-error=unused-but-set-variable --define=PREFIX=/usr --define=LIBDIR=$(PREFIX)/lib --define=INCLUDEDIR=$(PREFIX)/include --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --config=dynamic_kernels --experimental_guard_against_concurrent_changes
INFO: Found applicable config definition build:dynamic_kernels in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --define=dynamic_loaded_kernels=true --copt=-DAUTOLOAD_DYNAMIC_KERNELS
Loading: 
DEBUG: /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/third_party/repo.bzl:132:14: 
Warning: skipping import of repository 'tf_runtime' because it already exists.
DEBUG: /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/third_party/repo.bzl:132:14: 
Warning: skipping import of repository 'llvm-raw' because it already exists.
Loading: 0 packages loaded
Loading: 0 packages loaded
    currently loading: xla/extension
INFO: Repository local_config_rocm instantiated at:
  /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/WORKSPACE:19:15: in <toplevel>
  /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/workspace2.bzl:90:19: in workspace
  /root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/workspace2.bzl:624:19: in workspace
  /root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/workspace2.bzl:78:19: in _tf_toolchains
Repository rule rocm_configure defined at:
  /root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl:832:33: in <toplevel>
Loading: 0 packages loaded
    currently loading: xla/extension
ERROR: An error occurred during the fetch of repository 'local_config_rocm':
   Traceback (most recent call last):
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 810, column 38, in _rocm_autoconf_impl
                _create_local_rocm_repository(repository_ctx)
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 546, column 35, in _create_local_rocm_repository
                rocm_config = _get_rocm_config(repository_ctx, bash_bin)
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 393, column 30, in _get_rocm_config
                config = find_rocm_config(repository_ctx)
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 371, column 26, in find_rocm_config
                exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config])
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/remote_config/common.bzl", line 230, column 13, in execute
                fail(
Error in fail: Repository command failed
ERROR: MIOpen version file "None" not found
ERROR: /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/WORKSPACE:19:15: fetching rocm_configure rule //external:local_config_rocm: Traceback (most recent call last):
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 810, column 38, in _rocm_autoconf_impl
                _create_local_rocm_repository(repository_ctx)
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 546, column 35, in _create_local_rocm_repository
                rocm_config = _get_rocm_config(repository_ctx, bash_bin)
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 393, column 30, in _get_rocm_config
                config = find_rocm_config(repository_ctx)
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/gpus/rocm_configure.bzl", line 371, column 26, in find_rocm_config
                exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config])
        File "/root/.cache/bazel/_bazel_root/d366f579e16ea48eba76171a9c9ace02/external/tsl/third_party/remote_config/common.bzl", line 230, column 13, in execute
                fail(
Error in fail: Repository command failed
ERROR: MIOpen version file "None" not found
ERROR: Skipping '//xla/extension:xla_extension': no such package '@local_config_rocm//rocm': Repository command failed
ERROR: MIOpen version file "None" not found
WARNING: Target pattern parsing failed.
ERROR: no such package '@local_config_rocm//rocm': Repository command failed
ERROR: MIOpen version file "None" not found
INFO: Elapsed time: 39.473s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (0 packages loaded)
make: *** [Makefile:26: /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz] Error 1

So my guess is that it's not going to be as simple as just pointing to the new version 😅

I'd love to help get this sorted, but I'll need some pointers of where to start looking.

@jonatanklosko
Member

Hey @costaraphael, here are some notes from Jax. Perhaps you are missing some of these packages; miopen-hip looks relevant?
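If that's the missing piece, something along these lines should pull it in (exact package names depend on the ROCm apt repository; the -dev variant carries the headers):

sudo apt install miopen-hip miopen-hip-dev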

There are likely many changes on XLA main that we would need to account for to build, but let's see if we can address that specific error first.

FTR the ROCm experience is rough, because we don't have resources to properly test it and we generally try to address issues as people try using it.

@costaraphael
Author

Hey @jonatanklosko! Thanks for taking a look at this!

Yeah, I had another look at this yesterday, and indeed the second error I posted above seems to be a problem with my setup. I didn't have enough time to attempt a fix, but I will at some point this week, and I'll come back here with the results.

Thanks again for the Jax material you linked, if it works it might save me a couple of hours trying to understand the Bazel setup XLA has 😄

FTR the ROCm experience is rough, because we don't have resources to properly test it and we generally try to address issues as people try using it.

I expected this to be the case, so no worries! Apparently, ROCm is being a pain in all ML ecosystems, even in Python 😅

@costaraphael
Author

@jonatanklosko quick question, when you mention lack of resources for testing, do you mean actual GPUs? Or people? Or both? 😅

@jonatanklosko
Member

I'd say both. XLA is still evolving quickly, so whenever we update there are changes, and we test those across devices/GPUs. Currently most people use Nvidia, so that's our primary focus. We are a small group of people and ROCm is not high priority, but reports/contributions in this regard are welcome :)

@costaraphael
Author

costaraphael commented Oct 27, 2023

@jonatanklosko I had a look at this, and I was able to move forward a bit more:

I was able to get past all the checks for ROCm libraries. I installed everything in the Jax material you linked, with the slight addition that I also had to install rocblas, rocsolver, and the dev headers for miopen. If anyone finds this later, the final apt command looks like:

sudo apt install miopen-hip-dev hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
    rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs \
    rocsolver-dev rocblas-dev

The bad news is that the problem still seems to be in Bazel-land. Building with:

OPENXLA_GIT_REV=dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57 XLA_BUILD=true XLA_TARGET=rocm mix deps.compile

yields

===> Analyzing applications...
===> Compiling telemetry
==> xla
Compiling 2 files (.ex)
Generated xla app
rm -f /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/extension && \
    ln -s "/tests/my_app/deps/xla/extension" /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/extension && \
    cd /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57 && \
    bazel build --define "framework_shared_object=false" -c opt   --config=rocm --action_env=HIP_PLATFORM=hcc //xla/extension:xla_extension && \
    mkdir -p /root/.cache/xla/0.5.1/cache/build/ && \
    cp -f /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/bazel-bin/xla/extension/xla_extension.tar.gz /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz
INFO: Reading 'startup' options from /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --windows_enable_symlinks
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'build' from /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'build' from /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc:
  'build' options: --define framework_shared_object=true --define tsl_protobuf_header_only=true --define=use_fast_cpp_protos=true --define=allow_oversize_protos=true --spawn_strategy=standalone -c opt --announce_rc --define=grpc_no_ares=true --noincompatible_remove_legacy_whole_archive --features=-force_no_whole_archive --enable_platform_specific_config --define=with_xla_support=true --config=short_logs --config=v2 --define=no_aws_support=true --define=no_hdfs_support=true --experimental_cc_shared_library --experimental_link_static_libraries_once=false --incompatible_enforce_config_setting_visibility
INFO: Found applicable config definition build:short_logs in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:v2 in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
INFO: Found applicable config definition build:rocm in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --crosstool_top=@local_config_rocm//crosstool:toolchain --define=using_rocm_hipcc=true --define=tensorflow_mkldnn_contraction_kernel=0 --repo_env TF_NEED_ROCM=1 --config=no_tfrt
INFO: Found applicable config definition build:no_tfrt in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/ir,tensorflow/compiler/mlir/tfrt/ir/mlrt,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/mlrt,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/compiler/mlir/tfrt/transforms/mlrt,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/runtime_fallback/test,tensorflow/core/runtime_fallback/test/gpu,tensorflow/core/runtime_fallback/test/saved_model,tensorflow/core/runtime_fallback/test/testdata,tensorflow/core/tfrt/stubs,tensorflow/core/tfrt/tfrt_session,tensorflow/core/tfrt/mlrt,tensorflow/core/tfrt/mlrt/attribute,tensorflow/core/tfrt/mlrt/kernel,tensorflow/core/tfrt/mlrt/bytecode,tensorflow/core/tfrt/mlrt/interpreter,tensorflow/compiler/mlir/tfrt/translate/mlrt,tensorflow/compiler/mlir/tfrt/translate/mlrt/testdata,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug,tensorflow/core/tfrt/saved_model/python,tensorflow/core/tfrt/graph_executor/python,tensorflow/core/tfrt/saved_model/utils
INFO: Found applicable config definition build:linux in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --host_copt=-w --copt=-Wno-all --copt=-Wno-extra --copt=-Wno-deprecated --copt=-Wno-deprecated-declarations --copt=-Wno-ignored-attributes --copt=-Wno-array-bounds --copt=-Wunused-result --copt=-Werror=unused-result --copt=-Wswitch --copt=-Werror=switch --copt=-Wno-error=unused-but-set-variable --define=PREFIX=/usr --define=LIBDIR=$(PREFIX)/lib --define=INCLUDEDIR=$(PREFIX)/include --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --config=dynamic_kernels --experimental_guard_against_concurrent_changes
INFO: Found applicable config definition build:dynamic_kernels in file /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/.bazelrc: --define=dynamic_loaded_kernels=true --copt=-DAUTOLOAD_DYNAMIC_KERNELS
Loading: 
Loading: 0 packages loaded
Analyzing: target //xla/extension:xla_extension (1 packages loaded, 0 targets configured)
Analyzing: target //xla/extension:xla_extension (34 packages loaded, 10 targets configured)
Analyzing: target //xla/extension:xla_extension (59 packages loaded, 206 targets configured)
ERROR: /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/extension/BUILD:12:10: no such target '//xla/service:memory_space_assignment_proto_cc_impl': target 'memory_space_assignment_proto_cc_impl' not declared in package 'xla/service' defined by /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/service/BUILD (Tip: use `query "//xla/service:*"` to see all the targets in that package) and referenced by '//xla/extension:libxla_extension.so'
ERROR: /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/extension/BUILD:12:10: no such target '//xla/service/gpu:hlo_op_profile_proto_cc_impl': target 'hlo_op_profile_proto_cc_impl' not declared in package 'xla/service/gpu' defined by /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/service/gpu/BUILD (Tip: use `query "//xla/service/gpu:*"` to see all the targets in that package) and referenced by '//xla/extension:libxla_extension.so'
ERROR: /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/extension/BUILD:12:10: no such target '//xla/stream_executor:dnn_proto_cc_impl': target 'dnn_proto_cc_impl' not declared in package 'xla/stream_executor' defined by /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/stream_executor/BUILD (Tip: use `query "//xla/stream_executor:*"` to see all the targets in that package) and referenced by '//xla/extension:libxla_extension.so'
ERROR: /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/extension/BUILD:12:10: no such target '//xla/pjrt:tpu_client': target 'tpu_client' not declared in package 'xla/pjrt' defined by /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57/xla/pjrt/BUILD (did you mean 'pjrt_client'? Tip: use `query "//xla/pjrt:*"` to see all the targets in that package) and referenced by '//xla/extension:libxla_extension.so'
ERROR: Analysis of target '//xla/extension:xla_extension' failed; build aborted: 
INFO: Elapsed time: 6.019s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (68 packages loaded, 206 targets configured)
make: *** [Makefile:26: /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz] Error 1

I had a quick look, and it seems like this is because the https://github.com/elixir-nx/xla/blob/cf9753eb4c70312ab4b195cca3f568779e731c59/extension/BUILD file points at targets that used to exist but don't anymore.
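The error messages even hint at how to track down what those targets became; for instance, to list what actually exists in xla/service now:

cd /root/.cache/xla_extension/xla-dba73eb7c7c6dbc589f3fe3334cabcbdebd53e57 && \
    bazel query "//xla/service:*" | grep -i memory_space_assignment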

@jonatanklosko
Member

@costaraphael you can try building on the jk-bump branch. Unfortunately the new XLA revision fails to build for me locally on M1, so I need to figure this out first in order to check if we also need changes in EXLA.

@costaraphael
Author

Hey @jonatanklosko thank you so much for looking into this! I've tried building it like:

Mix.install(
  [
    {:xla, github: "elixir-nx/xla", branch: "jk-bump"}
  ],
  system_env: %{
    "XLA_BUILD" => "true",
    "XLA_TARGET" => "rocm"
  }
)

It initially gave me an error message complaining about hipcub and rocprim headers not being there (which they weren't).
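The fix was installing the corresponding dev packages (names as in the apt repository used here):

sudo apt install hipcub-dev rocprim-dev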

With those installed, it started compiling just fine, until it failed with the following error:

[1,905 / 6,472] Compiling xla/service/hlo_rematerialization.cc; 36s local ... (16 actions running)
[1,906 / 6,472] Compiling xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc; 28s local ... (16 actions, 15 running)
[1,908 / 6,472] Compiling xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc; 29s local ... (16 actions, 15 running)
[1,912 / 6,472] Compiling xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc; 30s local ... (15 actions, 14 running)
[1,912 / 6,472] Compiling xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc; 32s local ... (16 actions running)
[1,913 / 6,472] Compiling xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc; 33s local ... (16 actions running)
[1,916 / 6,472] Compiling xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc; 34s local ... (16 actions, 15 running)
[1,918 / 6,472] Compiling xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc; 35s local ... (16 actions, 15 running)
ERROR: /root/.cache/xla_extension/xla-cb7121e143cc52fe931f3e724bc7fa115363b4b6/xla/stream_executor/rocm/BUILD:509:11: Compiling xla/stream_executor/rocm/rocm_helpers.cu.cc failed: undeclared inclusion(s) in rule '//xla/stream_executor/rocm:rocm_helpers':
this rule is missing dependency declarations for the following files included by 'xla/stream_executor/rocm/rocm_helpers.cu.cc':
  '/opt/rocm-5.7.1/include/hip/hip_version.h'
  '/opt/rocm-5.7.1/include/hip/hip_runtime.h'
  '/opt/rocm-5.7.1/include/hip/hip_common.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_runtime.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_common.h'
  '/opt/rocm-5.7.1/include/hip/hip_runtime_api.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/host_defines.h'
  '/opt/rocm-5.7.1/include/hip/driver_types.h'
  '/opt/rocm-5.7.1/include/hip/texture_types.h'
  '/opt/rocm-5.7.1/include/hip/channel_descriptor.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_channel_descriptor.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_vector_types.h'
  '/opt/rocm-5.7.1/include/hip/surface_types.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_runtime_pt_api.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/hip_ldg.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_atomic.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_device_functions.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/math_fwd.h'
  '/opt/rocm-5.7.1/include/hip/hip_vector_types.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/device_library_decls.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_warp_functions.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_unsafe_atomics.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_surface_functions.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/texture_fetch_functions.h'
  '/opt/rocm-5.7.1/include/hip/hip_texture_types.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/ockl_image.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/texture_indirect_functions.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_math_functions.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/hip_fp16_math_fwd.h'
  '/opt/rocm-5.7.1/include/hip/library_types.h'
  '/opt/rocm-5.7.1/include/hip/hip_bfloat16.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_bfloat16.h'
  '/opt/rocm-5.7.1/include/hip/hip_fp16.h'
  '/opt/rocm-5.7.1/include/hip/amd_detail/amd_hip_fp16.h'
Warning: HIP_PLATFORM=hcc is deprecated. Please use HIP_PLATFORM=amd. 
clang++: warning: argument unused during compilation: '-fcuda-flush-denormals-to-zero' [-Wunused-command-line-argument]
Warning: HIP_PLATFORM=hcc is deprecated. Please use HIP_PLATFORM=amd. 
Target //xla/extension:xla_extension failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 538.439s, Critical Path: 101.57s
INFO: 1936 processes: 442 internal, 1494 local.
FAILED: Build did NOT complete successfully
make: *** [Makefile:27: /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz] Error 1

@jonatanklosko
Member

Hmm, this again looks more environment-specific (not Bazel). It looks similar to tensorflow/tensorflow#61354; the suggestion there is to try Clang 16.

Are these header files present in your file system? If not, you can also try locating hip_version.h; perhaps it's elsewhere.
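Something like this can confirm where the header actually lives:

find /opt -name hip_version.h 2>/dev/null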

@costaraphael
Author

@jonatanklosko I've checked the environment, and all the files listed there do exist. I think the problem is that those specific header files were being included (directly or indirectly) by xla/stream_executor/rocm/rocm_helpers.cu.cc without being listed as part of a cc_library in Bazel.

I suspect this is happening because the setup looks in /opt/rocm by default for the ROCm headers, but that folder is a symlink to /opt/rocm-5.7.1 via /etc/alternatives (to support multiple installed versions, I suppose). So all the header files are listed by Bazel via the symlink, but at some point they are being included through the actual folder.

So in other words, Bazel knows about /opt/rocm/include/hip/hip_version.h, but knows nothing about /opt/rocm-5.7.1/include/hip/hip_version.h, even though they are the same file.
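This is easy to confirm on such a setup (paths as described above):

# /opt/rocm resolves to the versioned directory through /etc/alternatives
readlink -f /opt/rocm
# both paths point at the very same file, but Bazel only knows the first one
ls -l /opt/rocm/include/hip/hip_version.h /opt/rocm-5.7.1/include/hip/hip_version.h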

I tried pointing to the actual ROCm folder using ROCM_PATH=/opt/rocm-5.7.1 (they made it configurable here). This gave me a different error though:

[930 / 4,745] Generating code from table: lib/Target/AMDGPU/AMDGPU.td @llvm-project//llvm:AMDGPUCommonTableGen__gen_asm_matcher_genrule; 27s local ... (16 actions, 15 running)
[941 / 4,745] Generating code from table: lib/Target/AMDGPU/AMDGPU.td @llvm-project//llvm:AMDGPUCommonTableGen__gen_asm_matcher_genrule; 28s local ... (16 actions, 15 running)
ERROR: /root/.cache/xla_extension/xla-cb7121e143cc52fe931f3e724bc7fa115363b4b6/xla/service/gpu/BUILD:1138:23: Compiling xla/service/gpu/cub_sort_kernel.cu.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/service/gpu:cub_sort_kernel_s32) external/local_config_rocm/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer ... (remaining 93 arguments skipped)
In file included from xla/service/gpu/cub_sort_kernel.cu.cc:16:
./xla/service/gpu/cub_sort_kernel.h:22:10: fatal error: absl/status/status.h: No such file or directory
   22 | #include "absl/status/status.h"
      |          ^~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Target //xla/extension:xla_extension failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 125.820s, Critical Path: 70.12s
INFO: 870 processes: 19 internal, 851 local.
FAILED: Build did NOT complete successfully
make: *** [Makefile:27: /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz] Error 1

Looking at the target causing the error, I see that the Abseil headers are only included if CUDA is configured, which is a problem since the source file includes the Abseil headers regardless of whether CUDA is configured 😅
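A crude sketch of the local edit I ended up making (swapping the CUDA-only gate for the generic GPU one; a blanket substitution like this may well touch more targets than intended):

# run from the XLA checkout root; only the cub_sort_kernel targets actually need it
sed -i 's/if_cuda_is_configured/if_gpu_is_configured/g' xla/service/gpu/BUILD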

After that edit, the build seemed to progress further, but broke at:

[283 / 3,891] Generating code from table: lib/Target/AMDGPU/AMDGPUGISel.td @llvm-project//llvm:amdgpu_isel_target_gen__gen_global_isel_genrule; 41s local ... (16 actions running)
ERROR: /root/.cache/xla_extension/xla-cb7121e143cc52fe931f3e724bc7fa115363b4b6/xla/service/gpu/BUILD:1138:23: Compiling xla/service/gpu/cub_sort_kernel.cu.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/service/gpu:cub_sort_kernel_u32_b16) external/local_config_rocm/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer ... (remaining 97 arguments skipped)
In file included from xla/service/gpu/cub_sort_kernel.cu.cc:26:
./xla/service/gpu/gpu_prim_rocm.h:31:8: error: 'float_bit_mask' is not a class template
   31 | struct float_bit_mask<Eigen::half> {
      |        ^~~~~~~~~~~~~~
./xla/service/gpu/gpu_prim_rocm.h:31:36: error: explicit specialization of non-template 'rocprim::detail::float_bit_mask'
   31 | struct float_bit_mask<Eigen::half> {
      |                                    ^
./xla/service/gpu/gpu_prim_rocm.h:39:8: error: 'float_bit_mask' is not a class template
   39 | struct float_bit_mask<tsl::bfloat16> {
      |        ^~~~~~~~~~~~~~
./xla/service/gpu/gpu_prim_rocm.h:39:36: error: 'rocprim::detail::float_bit_mask' is not a template
   39 | struct float_bit_mask<tsl::bfloat16> {
      |                                    ^
./xla/service/gpu/gpu_prim_rocm.h:31:8: note: previous declaration here
   31 | struct float_bit_mask<Eigen::half> {
      |        ^~~~~~~~~~~~~~
./xla/service/gpu/gpu_prim_rocm.h:47:8: error: 'radix_key_codec_base' is not a class template
   47 | struct radix_key_codec_base<Eigen::half>
      |        ^~~~~~~~~~~~~~~~~~~~
./xla/service/gpu/gpu_prim_rocm.h:48:31: error: expected template-name before '<' token
   48 |     : radix_key_codec_floating<Eigen::half, uint16_t> {};
      |                               ^
./xla/service/gpu/gpu_prim_rocm.h:48:31: error: expected '{' before '<' token
./xla/service/gpu/gpu_prim_rocm.h:50:8: error: 'radix_key_codec_base' is not a class template
   50 | struct radix_key_codec_base<tsl::bfloat16>
      |        ^~~~~~~~~~~~~~~~~~~~
./xla/service/gpu/gpu_prim_rocm.h:50:42: error: 'rocprim::detail::radix_key_codec_base' is not a template
   50 | struct radix_key_codec_base<tsl::bfloat16>
      |                                          ^
./xla/service/gpu/gpu_prim_rocm.h:47:8: note: previous declaration here
   47 | struct radix_key_codec_base<Eigen::half>
      |        ^~~~~~~~~~~~~~~~~~~~
./xla/service/gpu/gpu_prim_rocm.h:51:31: error: expected template-name before '<' token
   51 |     : radix_key_codec_floating<tsl::bfloat16, uint16_t> {};
      |                               ^
xla/service/gpu/cub_sort_kernel.cu.cc: In function 'absl::lts_20230125::Status xla::gpu::{anonymous}::CubSortKeys(void*, size_t&, const void*, void*, size_t, bool)':
xla/service/gpu/cub_sort_kernel.cu.cc:53:22: error: 'gpuprim::DeviceRadixSort' has not been declared
   53 |           ? gpuprim::DeviceRadixSort::SortKeysDescending<KeyT>(
      |                      ^~~~~~~~~~~~~~~
xla/service/gpu/cub_sort_kernel.cu.cc:53:62: error: expected primary-expression before '>' token
   53 |           ? gpuprim::DeviceRadixSort::SortKeysDescending<KeyT>(
      |                                                              ^
xla/service/gpu/cub_sort_kernel.cu.cc:56:22: error: 'gpuprim::DeviceRadixSort' has not been declared
   56 |           : gpuprim::DeviceRadixSort::SortKeys<KeyT>(
      |                      ^~~~~~~~~~~~~~~
xla/service/gpu/cub_sort_kernel.cu.cc:56:52: error: expected primary-expression before '>' token
   56 |           : gpuprim::DeviceRadixSort::SortKeys<KeyT>(
      |                                                    ^
xla/service/gpu/cub_sort_kernel.cu.cc: In function 'absl::lts_20230125::Status xla::gpu::{anonymous}::CubSortPairs(void*, size_t&, const void*, void*, const void*, void*, size_t, bool)':
xla/service/gpu/cub_sort_kernel.cu.cc:70:22: error: 'gpuprim::DeviceRadixSort' has not been declared
   70 |           ? gpuprim::DeviceRadixSort::SortPairsDescending<KeyT, ValT>(
      |                      ^~~~~~~~~~~~~~~
xla/service/gpu/cub_sort_kernel.cu.cc:70:63: error: expected primary-expression before ',' token
   70 |           ? gpuprim::DeviceRadixSort::SortPairsDescending<KeyT, ValT>(
      |                                                               ^
xla/service/gpu/cub_sort_kernel.cu.cc:70:69: error: expected primary-expression before '>' token
   70 |           ? gpuprim::DeviceRadixSort::SortPairsDescending<KeyT, ValT>(
      |                                                                     ^
xla/service/gpu/cub_sort_kernel.cu.cc:75:22: error: 'gpuprim::DeviceRadixSort' has not been declared
   75 |           : gpuprim::DeviceRadixSort::SortPairs<KeyT, ValT>(
      |                      ^~~~~~~~~~~~~~~
xla/service/gpu/cub_sort_kernel.cu.cc:75:53: error: expected primary-expression before ',' token
   75 |           : gpuprim::DeviceRadixSort::SortPairs<KeyT, ValT>(
      |                                                     ^
xla/service/gpu/cub_sort_kernel.cu.cc:75:59: error: expected initializer before '>' token
   75 |           : gpuprim::DeviceRadixSort::SortPairs<KeyT, ValT>(
      |                                                           ^
Target //xla/extension:xla_extension failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 61.626s, Critical Path: 42.00s
INFO: 248 processes: 17 internal, 231 local.
FAILED: Build did NOT complete successfully
make: *** [Makefile:27: /root/.cache/xla/0.5.1/cache/build/xla_extension-x86_64-linux-gnu-rocm.tar.gz] Error 1

I looked a bit into the error, but I'm kinda lost as to what to do from here 😓 The source file is set up as if it were GPU-agnostic, yet it all seems very CUDA-specific. I tried checking AMD's fork of XLA, but their project structure is fairly different (and 400+ commits behind at this point).

@costaraphael
Author

Correcting my comment above: DeviceRadixSort also exists for ROCm, implemented via hipCUB, which is included in the gpu_prim_rocm.h file (https://github.com/openxla/xla/blob/cb7121e143cc52fe931f3e724bc7fa115363b4b6/xla/service/gpu/gpu_prim_rocm.h).

The supposedly undefined float_bit_mask template is also defined, in rocPRIM.

My suspicion is that these packages are not being loaded properly: if you check the entrypoint for hipCUB, when neither __HIP_PLATFORM_AMD__ nor __HIP_PLATFORM_NVIDIA__ is defined, it essentially does nothing (which would explain these definitions not existing).

I tried to confirm this by adding a #define __HIP_PLATFORM_AMD__ line to gpu_prim_rocm.h before the hipCUB headers are imported, and got a bunch of undefined errors, like:

bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp:164:24: error: there are no arguments to 'HIP_KERNEL_NAME' that depend on a template parameter, so a declaration of 'HIP_KERNEL_NAME' must be available [-fpermissive]
  164 |     hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_arg_minmax_kernel<config>),
      |                        ^~~~~~~~~~~~~~~
In file included from bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/hipcub/backend/rocprim/hipcub.hpp:65,
                 from bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/hipcub/hipcub.hpp:36,
                 from ./xla/service/gpu/gpu_prim_rocm.h:23,
                 from xla/service/gpu/cub_sort_kernel.cu.cc:26:

This tells me that a bunch of stuff that should have been loaded was not, probably due to something missing (in the environment or in the build flags).

I'm going to keep looking at this, but let me know if you have any ideas 😅

@jonatanklosko
Member

jonatanklosko commented Nov 2, 2023

Just an update: I almost managed to build it inside the ROCm Docker image (rocm/dev-ubuntu-20.04:5.7-complete) using a couple of flags:

ENV TF_ROCM_AMDGPU_TARGETS "gfx900,gfx906,gfx908,gfx90a,gfx1030"
ENV TF_ROCM_VERSION "50700"
ENV ROCM_PATH "/opt/rocm-5.7.0"

A couple of errors were actually fixed very recently on XLA main (including the if_cuda_is_configured -> if_gpu_is_configured typo and the missing templates).

In the end I ran into a linking error, the same one I get on macOS when building locally (duplicate symbols). XLA was recently vendored into tensorflow, and there may be some structural changes that lead to duplicated files.

@costaraphael
Author

@jonatanklosko that is amazing news! I'll give it a go using the latest XLA revisions and the envs you shared tomorrow!

I'm also going to have a look at the rocm-libs package that the Dockerfile you mentioned installs (there might be some lib I missed in my setup; I wasn't aware of this meta-package).

I'll report back here with the results.

@jonatanklosko
Member

FTR TF_ROCM_VERSION is actually not necessary.

@costaraphael please try the following and let me know if there are any runtime errors:

Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", ref: "cef5a12d", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", ref: "cef5a12d", override: true}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  }
)

Nx.global_default_backend(EXLA.Backend)

Nx.iota({3})

@costaraphael
Author

@jonatanklosko it worked! ❤️ 🎉

2023-11-03 10:51:34.434237: E xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
10:51:35.462 [info] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
10:51:35.462 [info] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
10:51:35.463 [info] XLA service 0x7f5b70e79cf0 initialized for platform ROCM (this does not guarantee that XLA will be used). Devices:
10:51:35.463 [info]   StreamExecutor device (0): AMD Instinct MI100, AMDGPU ISA version: gfx908:sramecc+:xnack-
10:51:35.463 [info]   StreamExecutor device (1): AMD Instinct MI100, AMDGPU ISA version: gfx908:sramecc+:xnack-
10:51:35.463 [info] Using BFC allocator.
10:51:35.463 [info] XLA backend allocating 30425481216 bytes on device 0 for BFCAllocator.
10:51:35.463 [info] XLA backend allocating 30425481216 bytes on device 1 for BFCAllocator.

#Nx.Tensor<
  s64[3]
  EXLA.Backend<rocm:0, 0.4101236191.154796069.65091>
  [0, 1, 2]
>

I tried going further with some more involved tasks, but they all failed:

  • Stable diffusion and speech-to-text: the serving is set up properly, but it fails with a segmentation fault when running.
  • Text embedding: fails to set up the serving itself with 10:40:56.442 [error] bitcode module is required by this HLO module but was not found at ./opencl.bc

(BTW, they fail either by trying to run the serving or by trying to start a serving process)

I'm not sure if these ☝️ were supposed to just work at this point, nor if this issue is the right place to talk about them; just mentioning them for completeness' sake.

@jonatanklosko
Member

jonatanklosko commented Nov 3, 2023

Can you also run EXLA tests?

git clone https://github.com/elixir-nx/nx && \
  cd nx/exla && \
  git checkout cef5a12d && \
  mix deps.get && \
  XLA_ARCHIVE_URL="https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz" EXLA_TARGET=rocm mix test

@jonatanklosko
Member

Text embedding: fails to set up the serving itself with 10:40:56.442 [error] bitcode module is required by this HLO module but was not found at ./opencl.bc

Looks like ROCm/ROCm#1796, does it change anything if you set ROCM_PATH?

@costaraphael
Author

@jonatanklosko I've run it, and hit the same error I mentioned above:

11:04:35.428 [error] domain=elixir.xla file=xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc line=240  bitcode module is required by this HLO module but was not found at ./opencl.bc


  9) test unary ops atanh (EXLA.MLIR.ExecutableTest)
     test/exla/mlir/executable_test.exs:314
     ** (RuntimeError) bitcode module not found at ./opencl.bc
     code: result_mlir = Nx.Defn.jit_apply(function, [t])
     stacktrace:
       (exla 0.7.0-dev) lib/exla/mlir/module.ex:110: EXLA.MLIR.Module.unwrap!/1
       (exla 0.7.0-dev) lib/exla/mlir/module.ex:96: EXLA.MLIR.Module.compile/5
       (stdlib 5.1.1) timer.erl:270: :timer.tc/2
       (exla 0.7.0-dev) lib/exla/defn.ex:434: anonymous fn/12 in EXLA.Defn.compile/8
       (exla 0.7.0-dev) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
       (stdlib 5.1.1) timer.erl:270: :timer.tc/2
       (exla 0.7.0-dev) lib/exla/defn.ex:409: EXLA.Defn.compile/8
       (exla 0.7.0-dev) lib/exla/defn.ex:272: EXLA.Defn.__compile__/4
       (exla 0.7.0-dev) lib/exla/defn.ex:258: EXLA.Defn.__jit__/5
       (nx 0.7.0-dev) lib/nx/defn.ex:433: Nx.Defn.do_jit_apply/3
       test/exla/mlir/executable_test.exs:325: (test)

I did some searching on the first error, and apparently setting ROCM_PATH=/opt/rocm-5.7.1 solved it.
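(i.e., exporting it in the shell the tests run from:)

export ROCM_PATH=/opt/rocm-5.7.1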

Now I'm also getting the segmentation faults I mentioned in the other comment:

terminate called after throwing an instance of 'absl::lts_20230802::BadStatusOrAccess'
  what():  Bad StatusOr access: FAILED_PRECONDITION: Could not load dynamic library 'libhipblaslt.so'; dlerror: libhipblaslt.so: cannot open shared object file: No such file or directory
Aborted (core dumped)

plus this error:

[FATAL] xla/stream_executor/rocm/rocm_dnn.cc:1846 Unsupported DNN data type: tf.float64 (dnn::DataType::kDouble)
Aborted (core dumped)

I'm currently looking for libraries I may have missed through apt search rocm. So far I've installed rocm-libs, rocm-hip-libraries, and rocm-hip-runtime-dev, but no luck.

@costaraphael
Author

I went looking for that dynamic lib and actually found it:

# stat /opt/rocm-5.7.1/lib/libhipblaslt.so
  File: /opt/rocm-5.7.1/lib/libhipblaslt.so -> libhipblaslt.so.0
  Size: 17        	Blocks: 0          IO Block: 4096   symbolic link
Device: 1000b9h/1048761d	Inode: 11685698    Links: 1
Access: (0777/lrwxrwxrwx)  Uid: (    0/    root)   Gid: (    0/    root)
Access: 2023-11-03 11:18:05.123205115 +0000
Modify: 2023-10-07 01:41:56.000000000 +0000
Change: 2023-11-03 10:25:18.069698132 +0000
 Birth: -

So I'm not sure where it is looking for it.

@jonatanklosko
Member

You can also try running the EXLA tests in Docker, just to see if any of these errors are environment-specific. Here's a simple Dockerfile with ROCm and Elixir:

Dockerfile

FROM hexpm/elixir:1.15.4-erlang-26.0.2-ubuntu-focal-20230126 AS elixir

FROM rocm/dev-ubuntu-20.04:5.7-complete

# Set the missing UTF-8 locale, otherwise Elixir warns
ENV LC_ALL C.UTF-8

# Make sure installing packages (like tzdata) doesn't prompt for configuration
ENV DEBIAN_FRONTEND noninteractive

# Install Erlang and Elixir

# Erlang runtime dependencies, see https://github.com/hexpm/bob/blob/3b5721dccdfe9d59766f374e7b4fb7fb8a7c720e/priv/scripts/docker/erlang-ubuntu-focal.dockerfile#L41-L45
RUN apt-get update && apt-get install -y --no-install-recommends libodbc1 libssl1.1 libsctp1

# We copy the top-level directory first to preserve symlinks in /usr/local/bin
COPY --from=elixir /usr/local /usr/ELIXIR_LOCAL
RUN cp -r /usr/ELIXIR_LOCAL/lib/* /usr/local/lib && \
  cp -r /usr/ELIXIR_LOCAL/bin/* /usr/local/bin && \
  rm -rf /usr/ELIXIR_LOCAL

# ---

RUN apt-get install -y git curl build-essential

ENV ROCM_PATH "/opt/rocm-5.7.0"

@costaraphael
Author

I'm already running it in a Docker container 😅 I'll update the Dockerfile to look exactly like yours though (it will take a while to run).

In the meantime, I forced the dynamic library lookup path by setting LD_LIBRARY_PATH="/opt/rocm-5.7.1/lib" in the mix test command (full invocation sketched at the end of this comment). It got rid of the library lookup error, but I'm still left with (running with --trace to help find the offending tests):

[FATAL] xla/stream_executor/rocm/rocm_dnn.cc:1846 Unsupported DNN data type: tf.float64 (dnn::DataType::kDouble)
Aborted (core dumped)

(☝️ could not find it by running with --trace)

EXLA.Defn.ExprTest [test/exla/defn/expr_test.exs]
  * test cholesky works on a 4x4 matrix [L#3939]Memory access fault by GPU node-1 (Agent handle: 0x7f9b3ca6da40) on address 0xfffffffff000. Reason: Page not present or supervisor privilege.
Aborted (core dumped)
EXLA.BackendTest [test/exla/backend_test.exs]
  * test Nx.LinAlg.svd/2 [L#156]:0:rocdevice.cpp            :2690: 1346318956343 us: [pid:168451 tid:0x7f3ee527b700] Callback: Queue 0x7f3b37600000 aborting with error : HSA_STATUS_ERROR_MEMORY_APERTURE_VIOLATION: The agent attempted to access memory beyond the largest legal address. code: 0x29
Aborted (core dumped)
EXLA.BackendTest [test/exla/backend_test.exs]
  * doctest Nx.conv/3 (705) [L#62]Segmentation fault (core dumped)

I'll let you know once the new Docker image is up and running and I had a chance to run the tests.
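For reference, the full test invocation at this point, combining the flags mentioned so far in this thread (a sketch):

LD_LIBRARY_PATH="/opt/rocm-5.7.1/lib" ROCM_PATH=/opt/rocm-5.7.1 \
    XLA_ARCHIVE_URL="https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz" \
    EXLA_TARGET=rocm mix test --trace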

@costaraphael
Author

@jonatanklosko I've run the tests using the Dockerfile you mentioned; it is definitely cleaner!

Unfortunately I'm still seeing errors. I kept running the tests with the --trace flag, skipping the tests causing fatal errors and segmentation faults, and was able to sort the failures into two groups:


Segmentation fault (core dumped)

Caused by any test that touches Nx.conv/3 (which explains why neither stable diffusion nor speech-to-text worked).


:0:rocdevice.cpp            :2690: 1376989234258 us: [pid:92262 tid:0x7f1e00303700] Callback: Queue 0x7f1cfc800000 aborting with error : HSA_STATUS_ERROR_MEMORY_APERTURE_VIOLATION: The agent attempted to access memory beyond the largest legal address. code: 0x29
Aborted (core dumped)

Nx.LinAlg.svd/2 or Nx.LinAlg.cholesky/1 seems to be causing this one.


There are also 7 tests failing due to precision errors (posting the screenshot to get the visual diff).

(screenshot: precision error diff)

After skipping the tests for the three functions mentioned above, only these precision errors remain (everything else seems to be working fine).

@costaraphael
Author

@jonatanklosko I was running the tests in parallel (no --trace) and re-enabled the Nx.conv/3 doctest. I'd occasionally get the following error instead of a SEGFAULT:

     ** (RuntimeError) Failed to determine best cudnn convolution algorithm for:
     %cudnn-conv = (c64[1,1,5]{2,1,0}, u8[0]{0}) custom-call(c64[1,1,3]{2,1,0} %p0.2, c64[1,1,3]{2,1,0} %p1.3), window={size=3 pad=2_2}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}

     Original error: INTERNAL: Unsupported convolution datatype

     To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

Enabling that XLA flag yields the following logs/error:

19:55:00.948 [warning] domain=elixir.xla file=xla/service/gpu/conv_algorithm_picker.cc line=1071  Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (c64[1,1,5]{2,1,0}, u8[0]{0}) custom-call(c64[1,1,3]{2,1,0} %p0.2, c64[1,1,3]{2,1,0} %p1.3), window={size=3 pad=2_2}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}

Original error: INTERNAL: Unsupported convolution datatype

As a result, convolution performance may be suboptimal.

19:55:01.066 [warning] domain=elixir.xla file=xla/service/gpu/runtime/support.cc line=58  Intercepted XLA runtime error:
UNIMPLEMENTED: Unimplemented convolution

19:55:01.066 [error] domain=elixir.xla file=xla/pjrt/pjrt_stream_executor_client.cc line=2716  Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.conv.forward' failed: Unimplemented convolution; current tracing scope: cudnn-conv; current profiling annotation: XlaModule:#hlo_module=_Function_29.18096260_2_in_EXLA.Backend.conv_4_.6,program_id=588#.

It seems like this is coming from the complex numbers example, which for some reason never got hit when I was running the tests with --trace.

Maybe it is related?

@josevalim
Contributor

Precision errors and features not being implemented are definitely expected.

@costaraphael
Author

Precision errors and features not being implemented are definitely expected.

Yeah, I was actually expecting more issues TBH 😅 It blows me away how fast the roadblocks in this issue got lifted ❤️


I'll have some time tomorrow to take a stab at this.

My plan is to take a look at the SEGFAULT issues first. I'll try to retrieve and make sense of a core dump; failing that, I'll try following the code path triggered by Nx.conv/3 to look for anything suspicious/unimplemented for ROCm.

Does this approach make sense? Or do you think I should try tackling something else/a different approach?

@jonatanklosko
Member

@costaraphael are you able to run the bumblebee models now, or do you still get segfaults there too?

The segfaults related to complex conv may be just that rocm doesn't support complex numbers there, but complex numbers shouldn't be a blocker for running models :)


As for segfaults, when debugging these in the past I did the following (consolidated as a shell sketch after the list):

  1. Edit /etc/sysctl.conf to include kernel.core_pattern = /var/crash/core-%e-%s-%u-%g-%p-%t
  2. sudo sysctl -p
  3. ulimit -c unlimited
  4. Run whatever causes the segfault
  5. Run gdb ~/.asdf/installs/erlang/24.0.6/erts-12.0.4/bin/beam.smp -core /var/crash/FILE (beam.smp path according to your Erlang installation)
  6. bt to backtrace
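Roughly, as shell commands (a sketch; the beam.smp path varies with your Erlang installation, and the core file name follows the pattern configured in step 1):

# 1-2: make the kernel write core dumps to /var/crash with a predictable name
echo 'kernel.core_pattern = /var/crash/core-%e-%s-%u-%g-%p-%t' | sudo tee -a /etc/sysctl.conf
sudo sysctl -p
# 3: allow core dumps of any size in the current shell
ulimit -c unlimited
# 4: run whatever causes the segfault, then:
# 5: open the dump against the emulator binary that produced it
gdb ~/.asdf/installs/erlang/24.0.6/erts-12.0.4/bin/beam.smp -core /var/crash/core-beam.smp-...
# 6: inside gdb, run: bt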

@costaraphael
Author

@jonatanklosko I was able to run embeddings and text generation/conversational with zero problems! Stable diffusion and speech-to-text are still causing SEGFAULTs.

I also had a small issue using partitions: true to use both GPUs (XLA seems to be able to talk to both GPUs just fine, but I'm getting an error every time the serving runs on GPU 1 saying "the computation is allocated on GPU 0 but input tensors are allocated on GPU 1"). I'm almost positive this is either something I'm messing up or unrelated to XLA, so I'm not investigating it just yet.

Thanks for the SEGFAULT investigation tips! They'll come in handy later!

@jonatanklosko
Member

"the computation is allocated on GPU 0 but input tensors are allocated on GPU 1"

This is most likely about the params, which are loaded onto one of the GPUs. Try passing preallocate_params: true in the serving options, like Bumblebee.Text.generation(..., preallocate_params: true).

@josevalim
Contributor

@jonatanklosko will pre-allocation help? If the data is loaded on GPU-0, I think preallocation will still complain when the data on GPU-0 is loaded into GPU-1. I think you need to both load the params onto the host and then preallocate, right?

@jonatanklosko
Member

Ah, I think you are right!

@costaraphael so you also need Bumblebee.load_model({:hf, "..."}, backend: {EXLA.Backend, client: :host}) :)

@costaraphael
Author

So, this is what I have on a Livebook so far:

Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", ref: "cef5a12d", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", ref: "cef5a12d", override: true},
    {:bumblebee, "~> 0.4.2"},
    {:kino_bumblebee, "~> 0.4.0"}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  },
  config: [nx: [default_backend: {EXLA.Backend, client: :host}]]
)
source = "Can I get medicare?"

compare_to = [
  "How do I get a replacement Medicare card?",
  "What is the monthly premium for Medicare Part B?",
  "How do I terminate my Medicare Part B (medical insurance)?",
  "How do I sign up for Medicare?",
  "Can I sign up for Medicare Part B if I am working and have health insurance through an employer?",
  "How do I sign up for Medicare Part B if I already have Part A?",
  "What are Medicare late enrollment penalties?",
  "What is Medicare and who can get it?",
  "How can I get help with my Medicare Part A and Part B premiums?",
  "What are the different parts of Medicare?",
  "Will my Medicare premiums be higher because of my higher income?",
  "What is TRICARE ?",
  "Should I sign up for Medicare Part B if I have Veterans' Benefits?"
]

hf_model = "sentence-transformers/all-MiniLM-L6-v2"
{:ok, model} = Bumblebee.load_model({:hf, hf_model}, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, hf_model})

serving =
  Bumblebee.Text.text_embedding(model, tokenizer,
    compile: [batch_size: 32, sequence_length: 512],
    defn_options: [compiler: EXLA, client: :rocm],
    embedding_processor: :l2_norm,
    preallocate_params: true
  )

Kino.start_child!({Nx.Serving, serving: serving, partitions: true, name: TextEmbeddingServing})
import Kino.Shorts

[orig | compare] =
  TextEmbeddingServing
  |> Nx.Serving.batched_run([source | compare_to])
  |> Enum.map(& &1.embedding)

scores =
  orig
  |> Nx.dot(compare |> Nx.stack() |> Nx.vectorize(:similarity))
  |> Nx.multiply(100)

grid([
  markdown("This is how the phrase `#{source}` compares to the other phrases:"),
  scores
  |> Nx.to_list()
  |> Enum.zip_with(compare_to, &%{phrase: &2, similarity: &1})
  |> data_table()
])

The error I get is:

** (ArgumentError) EXLA computation (defn) is allocated on client rocm #0 (rocm) but one of the input tensors are allocated on rocm #1 (rocm).

EXLA only transfers tensors allocated on host to other clients. You can force `:host` as your default backend with:

    # via config
    config :nx, default_backend: {EXLA.Backend, client: :host}

    # via API
    Nx.global_default_backend({EXLA.Backend, client: :host})

Otherwise ensure your tensors are allocated on the same client-device pair as your numerical definitions (defn). The default client-device is rocm #0 (rocm)

Let me know if you want me to open a parallel discussion for this in the forum or in a new issue.

@josevalim
Contributor

I am not at home but you may be able to reproduce this locally @jonatanklosko by using XLA_FLAGS=--xla_force_host_platform_device_count=2.
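
Something like this should do it (a sketch reusing the dependency specs from the snippets above; XLA_FLAGS is only read when the client initializes, so setting it via system_env should be enough):

Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true}
  ],
  system_env: %{
    # Expose two virtual host devices, so the multi-device behaviour
    # can be reproduced without two physical GPUs
    "XLA_FLAGS" => "--xla_force_host_platform_device_count=2"
  }
)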

@costaraphael do you have the full stacktrace for that error message?

@costaraphael
Author

@josevalim here:

    (exla 0.7.0-dev) lib/exla/defn/buffers.ex:125: EXLA.Defn.Buffers.from_nx!/3
    (exla 0.7.0-dev) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
    (exla 0.7.0-dev) lib/exla/defn.ex:344: EXLA.Defn.maybe_outfeed/7
    (stdlib 5.0.2) timer.erl:270: :timer.tc/2
    (exla 0.7.0-dev) lib/exla/defn.ex:285: anonymous fn/7 in EXLA.Defn.__compile__/4
    (nx 0.7.0-dev) lib/nx/defn.ex:433: Nx.Defn.do_jit_apply/3
    (nx 0.7.0-dev) lib/nx.ex:13506: Nx.slice/4
    #cell:2gu6ce5zg2ju4jstrn77ab3utg4zqh3s:3: (file)

@josevalim
Contributor

What is the code on line 3?

#cell:2gu6ce5zg2ju4jstrn77ab3utg4zqh3s:3: (file)

@costaraphael
Author

My bad:

  |> Nx.Serving.batched_run([source | compare_to])

I also tried passing Nx.backend_copy/1 or Nx.backend_transfer/1 to batched_run with the same result.
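
I passed them as the third argument of batched_run, e.g.:

TextEmbeddingServing
|> Nx.Serving.batched_run([source | compare_to], &Nx.backend_copy/1)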

@costaraphael
Author

BTW, the serving does work with the current setup, but it alternates between working and failing. My assumption is that this is because Nx is alternating between the two GPUs on each run.

@josevalim
Contributor

The bug is here: https://github.com/elixir-nx/nx/blob/main/exla/lib/exla/backend.ex#L322

It is considering the client but not the device. I will work on a fix tomorrow and let you know. :)
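
Roughly speaking, the check is doing the equivalent of this (hypothetical field names, not the actual EXLA source):

# Only the client is compared, so a buffer on rocm device 1 passes
# for a computation compiled for rocm device 0
buffer.client_name == executable.client_name

whereas it needs to compare the device as well:

buffer.client_name == executable.client_name and
  buffer.device_id == executable.device_id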

@costaraphael
Author

costaraphael commented Nov 9, 2023

I unfortunately was not able to make any progress on the SEGFAULT yesterday 😞

I was not able to retrieve a core dump: I'm accessing the GPU through a container in K8s, the base image is Ubuntu, which uses apport to handle core dumps, and apport is disabled by default in containers (unless you boot the container with systemd). I also don't have access to a privileged container there.

I then tried reading through the code, but got completely lost trying to follow what happens inside XLA 😅

One new piece of information I learned is that compilation itself seems to work fine:

iex(2)> left = Nx.reshape(Nx.iota({9}), {1, 1, 3, 3})
#Nx.Tensor<
  s64[1][1][3][3]
  EXLA.Backend<rocm:0, 0.1590974901.976879655.142578>
  [
    [
      [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]
      ]
    ]
  ]
>
iex(3)> right = Nx.reshape(Nx.iota({4}), {4, 1, 1, 1})
#Nx.Tensor<
  s64[4][1][1][1]
  EXLA.Backend<rocm:0, 0.1590974901.976879655.142580>
  [
    [
      [
        [0]
      ]
    ],
    [
      [
        [1]
      ]
    ],
    [
      [
        [2]
      ]
    ],
    [
      [
        [3]
      ]
    ]
  ]
>
iex(4)> fun = Nx.Defn.jit(&Nx.conv/3)
#Function<133.35555145/3 in Nx.Defn.Compiler.fun/2>
iex(5)> fun.(left, right, [])
Segmentation fault (core dumped)

So the SEGFAULT must be happening within XLA.

I'll try to get full access to a proper VM with a GPU, which should allow me to retrieve/analyze the core dump generated.

@jonatanklosko
Member

@costaraphael FTR I released a new XLA and the relevant changes are already on EXLA main. I also built a new precompiled binary for ROCm, so for further tests you can use this:

Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/0.6.0/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  }
)

@costaraphael
Author

@jonatanklosko thanks! I've tried running this and everything that was working before continues to work now! I even did some tests running Mistral on the latest builds of Bumblebee and it worked beautifully ❤️

I'm still having a hard time with multiple GPUs, though the failure mode is a bit different now. Take the following notebook:

Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
    {:kino, "~> 0.11.2"},
    {:bumblebee, "~> 0.4.2"}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/0.6.0/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  },
  config: [nx: [default_backend: {EXLA.Backend, client: :host}]]
)
source = "Can I get medicare?"

compare_to = [
  "How do I get a replacement Medicare card?",
  "What is the monthly premium for Medicare Part B?",
  "How do I terminate my Medicare Part B (medical insurance)?",
  "How do I sign up for Medicare?",
  "Can I sign up for Medicare Part B if I am working and have health insurance through an employer?",
  "How do I sign up for Medicare Part B if I already have Part A?",
  "What are Medicare late enrollment penalties?",
  "What is Medicare and who can get it?",
  "How can I get help with my Medicare Part A and Part B premiums?",
  "What are the different parts of Medicare?",
  "Will my Medicare premiums be higher because of my higher income?",
  "What is TRICARE ?",
  "Should I sign up for Medicare Part B if I have Veterans' Benefits?"
]

hf_model = "sentence-transformers/all-MiniLM-L6-v2"
{:ok, model} = Bumblebee.load_model({:hf, hf_model}, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, hf_model})

serving =
  Bumblebee.Text.text_embedding(model, tokenizer,
    compile: [batch_size: 32, sequence_length: 512],
    defn_options: [compiler: EXLA, client: :rocm],
    embedding_processor: :l2_norm
  )

Kino.start_child({Nx.Serving, serving: serving, partitions: true, name: TextEmbeddingServing})
Nx.Serving.batched_run(TextEmbeddingServing, [source | compare_to])

Running the last cell still alternates between succeeding and raising the error:

** (exit) exited in: Nx.Serving.local_batched_run(TextEmbeddingServing, ["Can I get medicare?", "How do I get a replacement Medicare card?", "What is the monthly premium for Medicare Part B?", "How do I terminate my Medicare Part B (medical insurance)?", "How do I sign up for Medicare?", "Can I sign up for Medicare Part B if I am working and have health insurance through an employer?", "How do I sign up for Medicare Part B if I already have Part A?", "What are Medicare late enrollment penalties?", "What is Medicare and who can get it?", "How can I get help with my Medicare Part A and Part B premiums?", "What are the different parts of Medicare?", "Will my Medicare premiums be higher because of my higher income?", "What is TRICARE ?", "Should I sign up for Medicare Part B if I have Veterans' Benefits?"])
    ** (EXIT) an exception was raised:
        ** (ArgumentError) EXLA computation (defn) is allocated on client rocm #1 (rocm) but one of the input tensors are allocated on rocm #0 (rocm).

EXLA only transfers tensors allocated on host to other clients. You can force `:host` as your default backend with:

    # via config
    config :nx, default_backend: {EXLA.Backend, client: :host}

    # via API
    Nx.global_default_backend({EXLA.Backend, client: :host})

Otherwise ensure your tensors are allocated on the same client-device pair as your numerical definitions (defn). The default client-device is rocm #0 (rocm)

            (exla 0.7.0-dev) lib/exla/defn/buffers.ex:125: EXLA.Defn.Buffers.from_nx!/3
            (exla 0.7.0-dev) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
            (exla 0.7.0-dev) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
            (exla 0.7.0-dev) lib/exla/defn.ex:344: EXLA.Defn.maybe_outfeed/7
            (stdlib 5.0.2) timer.erl:270: :timer.tc/2
            (exla 0.7.0-dev) lib/exla/defn.ex:285: anonymous fn/7 in EXLA.Defn.__compile__/4
            (nx 0.7.0-dev) lib/nx/defn.ex:313: anonymous fn/4 in Nx.Defn.compile/3
            (nx 0.7.0-dev) lib/nx/serving.ex:1826: anonymous fn/2 in Nx.Serving.Default.handle_batch/3
    (nx 0.7.0-dev) lib/nx/serving.ex:1012: Nx.Serving.local_batched_run!/3
    #cell:6nk3gtq45uvbqbagjgjxis5kbe5aekcj:1: (file)

Where #cell:6nk3gtq45uvbqbagjgjxis5kbe5aekcj:1 is the batched_run call. Since the stacktrace is different from last time, I figured the fix worked but then something else broke further down the line.

Trying to pass preallocate_params: true to the serving makes the Nx.Serving process fail at startup with:

{:error,
 {:shutdown,
  {:failed_to_start_child, Nx.Serving,
   {%ArgumentError{
      message: "EXLA computation (defn) is allocated on client rocm #1 (rocm) but one of the input tensors are allocated on rocm #0 (rocm).\n\nEXLA only transfers tensors allocated on host to other clients. You can force `:host` as your default backend with:\n\n    # via config\n    config :nx, default_backend: {EXLA.Backend, client: :host}\n\n    # via API\n    Nx.global_default_backend({EXLA.Backend, client: :host})\n\nOtherwise ensure your tensors are allocated on the same client-device pair as your numerical definitions (defn). The default client-device is rocm #0 (rocm)\n"
    },
    [
      {EXLA.Defn.Buffers, :from_nx!, 3, [file: ~c"lib/exla/defn/buffers.ex", line: 125]},
      {EXLA.Defn.Buffers, :filter_by_indexes_map, 4, [file: ~c"lib/exla/defn/buffers.ex", line: 25]},
      {EXLA.Defn.Buffers, :filter_by_indexes_map, 4, [file: ~c"lib/exla/defn/buffers.ex", line: 25]},
      {EXLA.Defn, :maybe_outfeed, 7, [file: ~c"lib/exla/defn.ex", line: 344]},
      {:timer, :tc, 2, [file: ~c"timer.erl", line: 270]},
      {EXLA.Defn, :"-__compile__/4-fun-3-", 7, [file: ~c"lib/exla/defn.ex", line: 285]},
      {Nx.Defn, :do_jit_apply, 3, [file: ~c"lib/nx/defn.ex", line: 433]},
      {Bumblebee.Text.TextEmbedding, :"-text_embedding/3-fun-7-", 7,
       [file: ~c"lib/bumblebee/text/text_embedding.ex", line: 85]}
    ]}}}}

I can try to have a look at this over the weekend!

@jonatanklosko
Member

@costaraphael the startup error is very surprising. Do you get the same error for this?

{:ok, model} = Bumblebee.load_model({:hf, hf_model}, backend: {EXLA.Backend, client: :host})
params = model.params
Nx.Defn.jit_apply(&Function.identity/1, [params], compiler: EXLA, client: :rocm, device_id: 0)
Nx.Defn.jit_apply(&Function.identity/1, [params], compiler: EXLA, client: :rocm, device_id: 1)

@jonatanklosko
Member

@costaraphael we found the issue, it's fixed on EXLA main, so you can try preallocate_params: true once again and it should work :)

@costaraphael
Author

@jonatanklosko I tested this and it does work! <3

I found a small bug though when dealing with batches larger than the configured batch size (e.g. I want to embed 30 sentences but the batch size is 24).

I see that Nx.Serving is ready to deal with this by splitting the batch and sending the chunks to each GPU for processing in parallel (which is awesome to have out of the box, BTW!).

However, there's a bug when collecting the results of the computation:

** (ArgumentError) EXLA computation (defn) is allocated on client rocm #1 (rocm) but one of the input tensors are allocated on rocm #0 (rocm).

EXLA by default only transfers tensors allocated on host to other clients. You can force `:host` as your default backend with:

    # via config
    config :nx, default_backend: {EXLA.Backend, client: :host}

    # via API
    Nx.global_default_backend({EXLA.Backend, client: :host})

Otherwise ensure your tensors are allocated on the same client-device pair as your numerical definitions (defn). The default client-device is rocm #0 (rocm)

    (exla 0.7.0-dev) lib/exla/defn/buffers.ex:125: EXLA.Defn.Buffers.from_nx!/3
    (exla 0.7.0-dev) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
    (exla 0.7.0-dev) lib/exla/defn.ex:344: EXLA.Defn.maybe_outfeed/7
    (stdlib 5.0.2) timer.erl:270: :timer.tc/2
    (exla 0.7.0-dev) lib/exla/defn.ex:285: anonymous fn/7 in EXLA.Defn.__compile__/4
    (nx 0.7.0-dev) lib/nx/defn.ex:433: Nx.Defn.do_jit_apply/3
    (nx 0.7.0-dev) lib/nx.ex:14629: Nx.concatenate/2
    #cell:6nk3gtq45uvbqbagjgjxis5kbe5aekcj:1: (file)

(#cell:6nk3gtq45uvbqbagjgjxis5kbe5aekcj:1 in the above is Nx.Serving.batched_run(TextEmbeddingServing, sentences))

I managed to investigate this one and found the bug to be here.

Essentially, the code aggregates the results from the different GPUs into a single tensor, but since the results come from different devices, Nx.concatenate/2 breaks down with the above error. This is a minimal code example that reproduces it:

tensors =
  [
    [Nx.tensor([1, 2, 3, 4], backend: {EXLA.Backend, client: :rocm, device_id: 0})],
    [Nx.tensor([5, 6, 7, 8], backend: {EXLA.Backend, client: :rocm, device_id: 1})]
  ]

Enum.zip_with(tensors, &Nx.concatenate(&1))

Using Nx.backend_transfer/2 would fix it, something like:

tensors =
  [
    [Nx.tensor([1, 2, 3, 4], backend: {EXLA.Backend, client: :rocm, device_id: 0})],
    [Nx.tensor([5, 6, 7, 8], backend: {EXLA.Backend, client: :rocm, device_id: 1})]
  ]

tensors
|> Enum.zip_with(fn tensors ->
  tensors
  |> Enum.map(&Nx.backend_transfer(&1, {EXLA.Backend, client: :host}))
  |> Nx.concatenate()
end)

But then I realized something like this would require some sort of abstraction to use EXLA while keeping Nx backend-agnostic, or else using Nx.BinaryBackend and leaving some performance on the table.
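
For illustration, the Nx.BinaryBackend variant I mean (backend-agnostic, but with an extra host round-trip) would be something like:

tensors
|> Enum.zip_with(fn tensors ->
  tensors
  # Every backend can transfer to the binary backend, at the cost
  # of copying the data out of the device
  |> Enum.map(&Nx.backend_transfer(&1, Nx.BinaryBackend))
  |> Nx.concatenate()
end)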

I didn't go ahead and submit a fix because I'm pretty sure there's already some abstraction for the former approach that I'm unaware of 😅

WDYT?

@jonatanklosko
Member

@josevalim this seems similar to distributed_postprocessing/2, but in this case maybe we should implicitly transfer to the same device, perhaps via a callback on the compiler, since it's related to __partitions_options__. WDYT?

@jonatanklosko
Member

@costaraphael we addressed that in Bumblebee (elixir-nx/bumblebee#282), so main should work as expected :)

@costaraphael
Author

@jonatanklosko @josevalim it is working perfectly ❤️

I now have zero blockers for demoing distributed^2 text embedding! 🎉 🎉 🎉

There is some other stuff I still want to pursue here, like the Nx.conv/3 SEGFAULTs, but at this moment the "cannot build XLA with ROCM in the latest versions" is no longer true, and hasn't been for a while at this point 😅 So if you want to close this issue, I can open a new one for the SEGFAULTs to keep you posted on my progress.

I also learned a lot about EXLA/Nx/Bumblebee while investigating this, so hopefully I'll be able to contribute more in the future! On a similar note, if there's anything I or my company can do to help test/validate ROCm builds (aside from just using it), please do let me know.

Thanks for the massive, MASSIVE help 💜 I feel like I owe you both at the very least a pint/coffee/beverage of choice 😄

@jonatanklosko
Member

@costaraphael fantastic!! 🐈‍⬛🔥 Thanks a lot for testing and reporting all the issues, it's great that we managed to find so many improvements :D

I think we can close this issue. Feel free to open one for the segfaults; those are most likely upstream XLA issues that will hopefully get resolved eventually, but if you manage to trace the core dump, perhaps we can report it there :)

@jonatanklosko
Member

@costaraphael regarding conv segfaults #63 (comment) :)

@costaraphael
Author

@jonatanklosko it works! 🚀 🚀 🚀 I ran Stable Diffusion and Whisper, and it just works.

I initially tried just setting ELIXIR_ERL_OPTIONS in the system_env option of Mix.install/2, but it does not work; I had to set it when starting the process itself.

I'm pretty sure this is because the env var is read during BEAM startup, and by the time Mix.install/2 is evaluated the VM has already booted.

Anyway, just a caveat to keep in mind, as this was the first env var where just throwing it inside Mix.install/2 wouldn't work.

@jonatanklosko
Member

@costaraphael you are correct, you may want to put it in your .bashrc or similar. For release deployment those options go in rel/vm.args.eex :)

@costaraphael
Author

I wonder if there's any valid reason one might want to put that env var in the system_env option of Mix.install/2, or whether it should issue a warning telling the dev it won't have any effect 🤔

@jonatanklosko
Member

It's probably not applicable after boot, but warning on specific env vars is probably too narrow. Another approach in Livebook is to use System.put_env (for env vars that don't affect deps installation), and that wouldn't work either.
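
For example (a sketch, reusing XLA_FLAGS from earlier in this thread):

# Works: XLA_FLAGS is only read later, when the EXLA client initializes
System.put_env("XLA_FLAGS", "--xla_force_host_platform_device_count=2")

# Has no effect: the VM consumed ELIXIR_ERL_OPTIONS before any code ran
System.put_env("ELIXIR_ERL_OPTIONS", "...")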
