Cannot build XLA with ROCM in the latest versions #58
Hey @costaraphael, here are some notes from Jax; perhaps you are missing some of these packages. There are likely many changes on XLA main that we would need to account for in order to build, but let's see if we can address that specific error first. FTR the ROCm experience is rough, because we don't have the resources to test it properly and we generally address issues as people run into them.
Hey @jonatanklosko! Thanks for taking a look at this! Yeah, I was having another look at this yesterday, and indeed the second error I have up there seems to be a problem with my setup. I didn't have enough time to attempt fixing it, but I will at some point this week and I'll come back here with the results. Thanks again for the Jax material you linked; if it works it might save me a couple of hours trying to understand XLA's Bazel setup 😄
I expected this to be the case, so no worries! Apparently, ROCm is being a pain in all ML ecosystems, even in Python 😅
@jonatanklosko quick question: when you mention a lack of resources for testing, do you mean actual GPUs? Or people? Or both? 😅
I'd say both. XLA is still evolving quickly, so whenever we update there are changes, and we test those across devices/GPUs. Currently most people use Nvidia, so that's our primary focus; we are a small group of people and ROCm is not a high priority, but reports/contributions in this regard are welcome :)
@jonatanklosko I had a look at this, and I was able to move forward a bit more: I got past all the checks for ROCm libraries. I installed everything in the Jax material you linked, with the slight addition that I also had to install
The bad news is that the problem still seems to be in Bazel land. Building with:
yields
I had a quick look, and it seems like this is because the https://github.com/elixir-nx/xla/blob/cf9753eb4c70312ab4b195cca3f568779e731c59/extension/BUILD file is pointing to targets that used to exist but no longer do.
@costaraphael you can try building on the `jk-bump` branch.
Hey @jonatanklosko thank you so much for looking into this! I've tried building it like:

```elixir
Mix.install(
  [
    {:xla, github: "elixir-nx/xla", branch: "jk-bump"}
  ],
  system_env: %{
    "XLA_BUILD" => "true",
    "XLA_TARGET" => "rocm"
  }
)
```

It initially gave me an error message complaining about [...]. After installing [...]:
Hmm, this again looks more environment-specific (not Bazel). It looks similar to tensorflow/tensorflow#61354, where the suggestion is to try Clang 16. Are these header files present in your file system? If not, you can also try locating [...]
@jonatanklosko I've checked the environment, and all the files listed there do exist. What I think the problem was is that those specific header files were being included (directly or indirectly) by [...]. I suspect this is occurring because the setup looks by default into the [...]. So in other words, Bazel knows about [...]. I tried pointing to the actual ROCm folder using
Looking at the target causing the error, I see that the Abseil headers are only included if CUDA is configured, which is a problem since the source file includes the Abseil headers regardless of whether CUDA is configured 😅 I manually edited the Bazel definition there to use
I looked a bit into the error, and it seems like:
I'm kinda lost as to what to do from here 😓 The source file is set up as if it were GPU-agnostic, but it all seems very CUDA-specific. I tried checking AMD's fork of XLA, but their project structure is fairly different (and 400+ commits behind at this point).
Correcting my comment above: the supposedly undefined [...]. My suspicion is that these packages are not being loaded properly, because if you check the entrypoint for hipCUB, if neither [...]. I tried to confirm this by adding a
Which tells me a bunch of stuff that should be loaded was not, probably due to something missing (in the environment or in the build flags). I'm going to keep looking at this, but let me know if you have any ideas 😅
Just an update: I almost managed to build it inside the ROCm Dockerfile ([...]).
A couple of errors were actually fixed very recently on XLA main (including the [...]). At the end I ran into a linking error, which is the same one I ran into on macOS when building locally (duplicate symbols). XLA was recently vendored into tensorflow and there may be some structural changes that lead to duplicated files.
@jonatanklosko that is amazing news! I'll give it a go tomorrow using the latest XLA revisions and the envs you shared! I'm also going to have a look at that [...]. I'll report back here with the results.
FTR @costaraphael please try the following and let me know if there are any runtime errors:

```elixir
Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", ref: "cef5a12d", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", ref: "cef5a12d", override: true}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  }
)

Nx.global_default_backend(EXLA.Backend)
Nx.iota({3})
```
@jonatanklosko it worked! ❤️ 🎉
I tried going further with some more involved tasks, but they all failed:
(BTW, they fail either when trying to run the serving or when trying to start a serving process.) I'm not sure if these ☝️ were supposed to just work at this point, nor if this issue is the right place to talk about them; just mentioning it for completeness' sake.
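For readers following along, here is a rough sketch (not from the original report) of what those two paths look like with `Nx.Serving`; it assumes a `serving` built as in the snippets later in this thread and uses a placeholder process name:

```elixir
# Sketch only: `serving` is assumed to be built as in the later snippets,
# and MyServing is a hypothetical process name.

# Path 1: running the serving inline, without starting a process.
Nx.Serving.run(serving, "some example sentence")

# Path 2: starting the serving as a supervised process and calling it by name.
Kino.start_child!({Nx.Serving, serving: serving, name: MyServing})
Nx.Serving.batched_run(MyServing, ["some example sentence"])
```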
Can you also run the EXLA tests?

```shell
git clone https://github.com/elixir-nx/nx && \
  cd nx/exla && \
  git checkout cef5a12d && \
  mix deps.get && \
  XLA_ARCHIVE_URL="https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz" EXLA_TARGET=rocm mix test
```
Looks like ROCm/ROCm#1796. Does it change anything if you set [...]?
@jonatanklosko I've run it and found the same error I mentioned above:
I did some searching on the first error, and apparently setting [...]. Now I'm also getting the segmentation faults I mentioned in the other comment:
plus this error:
I'm currently looking for libraries I may have missed through [...]
I actually went looking for that dynamic lib, and found it:
So I'm not sure where it is looking for it.
You can also try running the EXLA tests in Docker, just to see if any of these errors are environment-specific. Here's a simple Dockerfile with ROCm and Elixir:

```dockerfile
FROM hexpm/elixir:1.15.4-erlang-26.0.2-ubuntu-focal-20230126 AS elixir

FROM rocm/dev-ubuntu-20.04:5.7-complete

# Set the missing UTF-8 locale, otherwise Elixir warns
ENV LC_ALL C.UTF-8

# Make sure installing packages (like tzdata) doesn't prompt for configuration
ENV DEBIAN_FRONTEND noninteractive

# Install Erlang and Elixir
# Erlang runtime dependencies, see https://github.com/hexpm/bob/blob/3b5721dccdfe9d59766f374e7b4fb7fb8a7c720e/priv/scripts/docker/erlang-ubuntu-focal.dockerfile#L41-L45
RUN apt-get update && apt-get install -y --no-install-recommends libodbc1 libssl1.1 libsctp1

# We copy the top-level directory first to preserve symlinks in /usr/local/bin
COPY --from=elixir /usr/local /usr/ELIXIR_LOCAL
RUN cp -r /usr/ELIXIR_LOCAL/lib/* /usr/local/lib && \
  cp -r /usr/ELIXIR_LOCAL/bin/* /usr/local/bin && \
  rm -rf /usr/ELIXIR_LOCAL

# ---

RUN apt-get install -y git curl build-essential

ENV ROCM_PATH "/opt/rocm-5.7.0"
```
I'm already running it in a Docker container 😅 I'll update the Dockerfile to look exactly like yours, though (it will take a while to run). In the meantime, I forced the dynamic library lookup path by setting [...]

(☝️ could not find it by running with [...])

I'll let you know once the new Docker image is up and running and I've had a chance to run the tests.
@jonatanklosko I ran the tests using the Dockerfile you mentioned, and it is definitely cleaner! Unfortunately I'm still seeing errors. I kept running the tests with the [...]:

Caused by any test that touches [...]
There are also 7 tests failing due to precision errors (posting the screenshot to get the visual diff). After skipping the tests for the three functions mentioned above, only these precision errors remain (everything else seems to be working fine).
@jonatanklosko I was running the tests in parallel (no [...]).
Enabling that XLA flag yields the following logs/error:
It seems like this is coming from the complex numbers example, which for some reason never got hit when I was running the tests with [...]. Maybe it is related?
Precision errors and features not being implemented are definitely expected.
Yeah, I was actually expecting more issues TBH 😅 It blows me away how fast the roadblocks in this issue got lifted ❤️ I'll have some time tomorrow to take a jab at this. My plan is to look at the SEGFAULT issues first: I'll try to retrieve and make sense of a core dump, and failing that, I'll try following the code path triggered by [...]. Does this approach make sense? Or do you think I should try tackling something else / a different approach?
@costaraphael are you able to run the Bumblebee models now, or do you still get segfaults there too? The segfaults related to complex conv may just be that ROCm doesn't support complex numbers there, but complex numbers shouldn't be a blocker for running models :) As for segfaults, when debugging these in the past I did:
@jonatanklosko I was able to run embeddings and text generation/conversational with zero problems! Stable diffusion and speech-to-text are still causing SEGFAULTs. I also had a small issue using [...]. Thanks for the SEGFAULT investigation tips! They'll come in handy later!
This is most likely about the params, which are loaded onto one of the GPUs. Try passing [...]
@jonatanklosko will pre-allocation help? If the data is loaded on GPU 0, I think preallocation will still complain when data on GPU 0 is attempted to be loaded onto GPU 1. I think you need both: to load the params onto the host and then preallocate, right?
Ah, I think you are right! @costaraphael so you also need [...]
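To make that combination concrete, here is a minimal sketch (mirroring the options used in the snippets below, with `hf_model` assumed from the surrounding code) of loading the params on the host client while letting the serving preallocate them on the GPUs:

```elixir
# Sketch: params stay on the host backend at load time...
{:ok, model} = Bumblebee.load_model({:hf, hf_model}, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, hf_model})

# ...and preallocate_params: true moves them onto each GPU partition when the
# serving compiles its computation.
serving =
  Bumblebee.Text.text_embedding(model, tokenizer,
    defn_options: [compiler: EXLA, client: :rocm],
    preallocate_params: true
  )
```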
So, this is what I have in a Livebook so far:

```elixir
Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", ref: "cef5a12d", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", ref: "cef5a12d", override: true},
    {:bumblebee, "~> 0.4.2"},
    {:kino_bumblebee, "~> 0.4.0"}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  },
  config: [nx: [default_backend: {EXLA.Backend, client: :host}]]
)
```

```elixir
source = "Can I get medicare?"

compare_to = [
  "How do I get a replacement Medicare card?",
  "What is the monthly premium for Medicare Part B?",
  "How do I terminate my Medicare Part B (medical insurance)?",
  "How do I sign up for Medicare?",
  "Can I sign up for Medicare Part B if I am working and have health insurance through an employer?",
  "How do I sign up for Medicare Part B if I already have Part A?",
  "What are Medicare late enrollment penalties?",
  "What is Medicare and who can get it?",
  "How can I get help with my Medicare Part A and Part B premiums?",
  "What are the different parts of Medicare?",
  "Will my Medicare premiums be higher because of my higher income?",
  "What is TRICARE ?",
  "Should I sign up for Medicare Part B if I have Veterans' Benefits?"
]

hf_model = "sentence-transformers/all-MiniLM-L6-v2"
```

```elixir
{:ok, model} = Bumblebee.load_model({:hf, hf_model}, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, hf_model})

serving =
  Bumblebee.Text.text_embedding(model, tokenizer,
    compile: [batch_size: 32, sequence_length: 512],
    defn_options: [compiler: EXLA, client: :rocm],
    embedding_processor: :l2_norm,
    preallocate_params: true
  )

Kino.start_child!({Nx.Serving, serving: serving, partitions: true, name: TextEmbeddingServing})
```

```elixir
import Kino.Shorts

[orig | compare] =
  TextEmbeddingServing
  |> Nx.Serving.batched_run([source | compare_to])
  |> Enum.map(& &1.embedding)

scores =
  orig
  |> Nx.dot(compare |> Nx.stack() |> Nx.vectorize(:similarity))
  |> Nx.multiply(100)

grid([
  markdown("This is how the phrase `#{source}` compares to the other phrases:"),
  scores
  |> Nx.to_list()
  |> Enum.zip_with(compare_to, &%{phrase: &2, similarity: &1})
  |> data_table()
])
```

The error I get is:
Let me know if you want me to open a parallel discussion for this in the forum or in a new issue.
I am not at home, but you may be able to reproduce this locally @jonatanklosko by using [...]. @costaraphael do you have the full stacktrace for that error message?
@josevalim here:
What is the code on line 3?
My bad:

```elixir
|> Nx.Serving.batched_run([source | compare_to])
```

I also tried passing [...]
BTW, the serving does work with the current setup, but it alternates between working and not working. My assumption is that this is because Nx is jumping between the two GPUs on each run.
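One way to test that assumption (a sketch, not something from the thread) is to pin the compiled function to a single device with the same `:client`/`:device_id` options that appear elsewhere in this conversation and see whether the alternation disappears:

```elixir
# Pin compilation and execution to GPU 0 only (sketch; assumes the rocm client
# configured as in the snippets above).
pinned = Nx.Defn.jit(&Nx.add/2, compiler: EXLA, client: :rocm, device_id: 0)
pinned.(Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6]))
```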
The bug is here: https://github.com/elixir-nx/nx/blob/main/exla/lib/exla/backend.ex#L322 (it is considering the client but not the device). I will work on a fix tomorrow and let you know. :)
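As an illustration of that diagnosis (hypothetical data, not EXLA's actual structs): two buffers can live on the same client yet on different devices, so a check that compares only the client name wrongly treats them as compatible:

```elixir
# Hypothetical shapes, for illustration only.
buffer = %{client: :rocm, device_id: 0}
executable = %{client: :rocm, device_id: 1}

# What a client-only check concludes:
buffer.client == executable.client
#=> true

# What a check that also considers the device concludes:
buffer.client == executable.client and buffer.device_id == executable.device_id
#=> false
```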
I unfortunately was not able to make any progress on the SEGFAULT yesterday 😞 I was not able to retrieve a core dump, because I'm accessing the GPU through a container in K8s and the base image is Ubuntu, which uses [...]. I then tried following the code, but got completely lost trying to follow what happens inside XLA 😅

One new piece of information I learned is that the compilation itself seems to be working fine:

```elixir
iex(2)> left = Nx.reshape(Nx.iota({9}), {1, 1, 3, 3})
#Nx.Tensor<
  s64[1][1][3][3]
  EXLA.Backend<rocm:0, 0.1590974901.976879655.142578>
  [
    [
      [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]
      ]
    ]
  ]
>
iex(3)> right = Nx.reshape(Nx.iota({4}), {4, 1, 1, 1})
#Nx.Tensor<
  s64[4][1][1][1]
  EXLA.Backend<rocm:0, 0.1590974901.976879655.142580>
  [
    [
      [
        [0]
      ]
    ],
    [
      [
        [1]
      ]
    ],
    [
      [
        [2]
      ]
    ],
    [
      [
        [3]
      ]
    ]
  ]
>
iex(4)> fun = Nx.Defn.jit(&Nx.conv/3)
#Function<133.35555145/3 in Nx.Defn.Compiler.fun/2>
iex(5)> fun.(left, right, [])
Segmentation fault (core dumped)
```

So the SEGFAULT must be happening within XLA. I'll try to get full access to a proper VM with a GPU, which should allow me to retrieve/analyze the core dump generated.
@costaraphael FTR I released a new XLA and the relevant changes are already on EXLA main. I also built a new precompiled binary for ROCm, so for further tests you can use this:

```elixir
Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/0.6.0/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  }
)
```
@jonatanklosko thanks! I've tried running this and everything that was working before continues to work now! I even did some tests running Mistral on the latest builds of Bumblebee and it worked beautifully ❤️

I'm still having a hard time with multiple GPUs, though the failure mode is a bit different now. Take the following notebook:

```elixir
Mix.install(
  [
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
    {:kino, "~> 0.11.2"},
    {:bumblebee, "~> 0.4.2"}
  ],
  system_env: %{
    "XLA_ARCHIVE_URL" =>
      "https://static.jonatanklosko.com/builds/0.6.0/xla_extension-x86_64-linux-gnu-rocm.tar.gz"
  },
  config: [nx: [default_backend: {EXLA.Backend, client: :host}]]
)
```

```elixir
source = "Can I get medicare?"

compare_to = [
  "How do I get a replacement Medicare card?",
  "What is the monthly premium for Medicare Part B?",
  "How do I terminate my Medicare Part B (medical insurance)?",
  "How do I sign up for Medicare?",
  "Can I sign up for Medicare Part B if I am working and have health insurance through an employer?",
  "How do I sign up for Medicare Part B if I already have Part A?",
  "What are Medicare late enrollment penalties?",
  "What is Medicare and who can get it?",
  "How can I get help with my Medicare Part A and Part B premiums?",
  "What are the different parts of Medicare?",
  "Will my Medicare premiums be higher because of my higher income?",
  "What is TRICARE ?",
  "Should I sign up for Medicare Part B if I have Veterans' Benefits?"
]

hf_model = "sentence-transformers/all-MiniLM-L6-v2"
```

```elixir
{:ok, model} = Bumblebee.load_model({:hf, hf_model}, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, hf_model})

serving =
  Bumblebee.Text.text_embedding(model, tokenizer,
    compile: [batch_size: 32, sequence_length: 512],
    defn_options: [compiler: EXLA, client: :rocm],
    embedding_processor: :l2_norm
  )

Kino.start_child({Nx.Serving, serving: serving, partitions: true, name: TextEmbeddingServing})
```

```elixir
Nx.Serving.batched_run(TextEmbeddingServing, [source | compare_to])
```

Running the last cell still alternates between success and the error:
Where [...]

Trying to pass [...]:

```elixir
{:error,
 {:shutdown,
  {:failed_to_start_child, Nx.Serving,
   {%ArgumentError{
      message: "EXLA computation (defn) is allocated on client rocm #1 (rocm) but one of the input tensors are allocated on rocm #0 (rocm).\n\nEXLA only transfers tensors allocated on host to other clients. You can force `:host` as your default backend with:\n\n # via config\n config :nx, default_backend: {EXLA.Backend, client: :host}\n\n # via API\n Nx.global_default_backend({EXLA.Backend, client: :host})\n\nOtherwise ensure your tensors are allocated on the same client-device pair as your numerical definitions (defn). The default client-device is rocm #0 (rocm)\n"
    },
    [
      {EXLA.Defn.Buffers, :from_nx!, 3, [file: ~c"lib/exla/defn/buffers.ex", line: 125]},
      {EXLA.Defn.Buffers, :filter_by_indexes_map, 4, [file: ~c"lib/exla/defn/buffers.ex", line: 25]},
      {EXLA.Defn.Buffers, :filter_by_indexes_map, 4, [file: ~c"lib/exla/defn/buffers.ex", line: 25]},
      {EXLA.Defn, :maybe_outfeed, 7, [file: ~c"lib/exla/defn.ex", line: 344]},
      {:timer, :tc, 2, [file: ~c"timer.erl", line: 270]},
      {EXLA.Defn, :"-__compile__/4-fun-3-", 7, [file: ~c"lib/exla/defn.ex", line: 285]},
      {Nx.Defn, :do_jit_apply, 3, [file: ~c"lib/nx/defn.ex", line: 433]},
      {Bumblebee.Text.TextEmbedding, :"-text_embedding/3-fun-7-", 7,
       [file: ~c"lib/bumblebee/text/text_embedding.ex", line: 85]}
    ]}}}}
```

I can try to have a look at this over the weekend!
@costaraphael the startup error is very surprising; do you get the same error for this?

```elixir
{:ok, model} = Bumblebee.load_model({:hf, hf_model}, backend: {EXLA.Backend, client: :host})
params = model.params

Nx.Defn.jit_apply(&Function.identity/1, [params], compiler: EXLA, client: :rocm, device_id: 0)
Nx.Defn.jit_apply(&Function.identity/1, [params], compiler: EXLA, client: :rocm, device_id: 1)
```
@costaraphael we found the issue; it's fixed on EXLA main, so you can try [...]
@jonatanklosko I tested this and it does work! <3

I found a small bug though when dealing with batches larger than the configured batch size (e.g. I want to embed 30 sentences but the batch size is 24). I see that [...]. However, there's a bug when collecting the results of the computation:

I managed to investigate this one and found the bug to be here. Essentially, the code aggregates the results from the different GPUs into a single tensor, but since the results come from different devices, [...]:

```elixir
tensors =
  [
    [Nx.tensor([1, 2, 3, 4], backend: {EXLA.Backend, client: :rocm, device_id: 0})],
    [Nx.tensor([5, 6, 7, 8], backend: {EXLA.Backend, client: :rocm, device_id: 1})]
  ]

Enum.zip_with(tensors, &Nx.concatenate(&1))
```

Using [...]:

```elixir
tensors =
  [
    [Nx.tensor([1, 2, 3, 4], backend: {EXLA.Backend, client: :rocm, device_id: 0})],
    [Nx.tensor([5, 6, 7, 8], backend: {EXLA.Backend, client: :rocm, device_id: 1})]
  ]

tensors
|> Enum.zip_with(fn tensors ->
  tensors
  |> Enum.map(&Nx.backend_transfer(&1, {EXLA.Backend, client: :host}))
  |> Nx.concatenate()
end)
```

But then I realized something like this would require some sort of abstraction to allow using [...]. I didn't go ahead and submit a fix because I'm pretty sure there's already some abstraction to fix it with the former approach that I'm unaware of 😅 WDYT?
@josevalim this seems similar to [...]
@costaraphael we addressed that in Bumblebee (elixir-nx/bumblebee#282), so main should work as expected :)
@jonatanklosko @josevalim it is working perfectly ❤️ I now have zero blockers for demoing distributed^2 text embedding! 🎉 🎉 🎉

There is some other stuff I still want to pursue here, like the [...]. I also learned a lot about EXLA/Nx/Bumblebee while investigating this, so hopefully I'll be able to contribute more in the future!

On a similar note, if there's anything I or my company can do to help test/validate ROCm builds (aside from just using them), please do let me know.

Thanks for the massive, MASSIVE help 💜 I feel like I owe you both at the very least a pint/coffee/beverage of choice 😄
@costaraphael fantastic!! 🐈⬛🔥 Thanks a lot for testing and reporting all the issues; it's great that we managed to find so many improvements :D I think we can close this issue. Feel free to open one for the segfaults; those are most likely upstream XLA issues that will hopefully get resolved eventually, but if you manage to trace the core dump, perhaps we can report it there :)
@costaraphael regarding the conv segfaults: #63 (comment) :)
@jonatanklosko it works! 🚀 🚀 🚀 I ran stable diffusion and whisper, and it just works. I tried initially just setting [...]. I'm pretty sure this is because the env var is read during the BEAM startup process, and by the time [...]. Anyway, just a caveat to keep in mind, as this was the first env var where just throwing it inside [...]
@costaraphael you are correct; you may want to put it in your [...]
I wonder if there's any valid reason one might want to have that env var in the [...]
It probably is not applicable after boot, but warning on specific env vars is probably too specific. Another approach in Livebook is to use [...]
Hey folks!
First of all thanks for the great work supporting Elixir's ML ecosystem! <3
I'm trying to set up a demo of GPU distribution at my company, but whenever I try to compile `xla` 0.5.1 with `XLA_BUILD=true` and `XLA_TARGET=rocm` I get the following error:

I did some digging, and I think it is because the openxla version the Makefile is pointing to doesn't have the `rocm` configuration:

- `.bazelrc` file at the version the Makefile points to: https://github.com/openxla/xla/blob/b938cfdf2d4e9a5f69c494a316e92638c1a119ef/.bazelrc
- `.bazelrc` file at the current version: https://github.com/openxla/xla/blob/main/.bazelrc

The `rocm` config was apparently added back in this commit: openxla/xla@98b6197.

I tried forcing XLA to download the latest passing build of openxla (openxla/xla@dba73eb) by editing the Makefile locally, but it failed with the following log:
So my guess is that it's not going to be as simple as just pointing to the new version 😅
I'd love to help get this sorted, but I'll need some pointers on where to start looking.
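For reference, the failing build attempt can be expressed with `Mix.install` the way other comments in this thread do (a sketch; the version and env vars come from the issue text, the exact project setup may differ):

```elixir
Mix.install(
  [
    {:xla, "~> 0.5.1"}
  ],
  system_env: %{
    # Force a from-source build targeting ROCm, as described above.
    "XLA_BUILD" => "true",
    "XLA_TARGET" => "rocm"
  }
)
```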