diff --git a/lib/bumblebee/audio/speech_to_text_whisper.ex b/lib/bumblebee/audio/speech_to_text_whisper.ex index a538fddd..99434aa1 100644 --- a/lib/bumblebee/audio/speech_to_text_whisper.ex +++ b/lib/bumblebee/audio/speech_to_text_whisper.ex @@ -67,7 +67,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - generate_fun.(params, inputs) + generate_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options diff --git a/lib/bumblebee/diffusion/stable_diffusion.ex b/lib/bumblebee/diffusion/stable_diffusion.ex index 8661a721..7b483732 100644 --- a/lib/bumblebee/diffusion/stable_diffusion.ex +++ b/lib/bumblebee/diffusion/stable_diffusion.ex @@ -284,7 +284,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion do %{image: image} end - Bumblebee.Utils.Nx.composite_unflatten_batch(output, inputs.size) + output + |> Bumblebee.Utils.Nx.composite_unflatten_batch(inputs.size) + |> Shared.serving_post_computation() end end @@ -318,9 +320,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do end defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do - # We use binary backend so we are not blocked by the serving computation - outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend) - for outputs <- Bumblebee.Utils.Nx.batch_to_list(outputs) do results = for outputs = %{image: image} <- Bumblebee.Utils.Nx.batch_to_list(outputs) do diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index dbd72869..d789ab28 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -276,6 +276,19 @@ defmodule Bumblebee.Shared do Nx.Batch.pad(batch, batch_size - size) end + @doc """ + Shared logic applied after serving computation to the resulting tensor + or container. 
+ """ + @spec serving_post_computation(result) :: result when result: Nx.Tensor.t() | Nx.Container.t() + def serving_post_computation(result) do + # We transfer to binary backend so tensor access in post-processing + # is not blocked by the serving computation. It is also + # necessary when partitions are enabled since we may need to + # concatenate results for input exceeding the expected batch size. + Nx.backend_transfer(result, Nx.BinaryBackend) + end + @doc """ Compiles or wraps the function with just-in-time compilation. diff --git a/lib/bumblebee/text/conversation.ex b/lib/bumblebee/text/conversation.ex index e6c9e959..c5ce4d64 100644 --- a/lib/bumblebee/text/conversation.ex +++ b/lib/bumblebee/text/conversation.ex @@ -77,6 +77,7 @@ defmodule Bumblebee.Text.Conversation do end sequences[[.., start_idx..-1//1]] + |> Shared.serving_post_computation() end end, defn_options diff --git a/lib/bumblebee/text/fill_mask.ex b/lib/bumblebee/text/fill_mask.ex index 798cfa2a..aa1179d1 100644 --- a/lib/bumblebee/text/fill_mask.ex +++ b/lib/bumblebee/text/fill_mask.ex @@ -75,7 +75,7 @@ defmodule Bumblebee.Text.FillMask do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - scores_fun.(params, inputs) + scores_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 77345aa7..36ffa596 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -864,7 +864,7 @@ defmodule Bumblebee.Text.Generation do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - generate_fun.(params, inputs) + generate_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options diff --git a/lib/bumblebee/text/question_answering.ex b/lib/bumblebee/text/question_answering.ex index c9d4f322..59970c49 100644 --- a/lib/bumblebee/text/question_answering.ex +++ b/lib/bumblebee/text/question_answering.ex @@ -66,8 +66,7 @@ 
defmodule Bumblebee.Text.QuestionAnswering do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - - predict_fun.(params, inputs) + predict_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options @@ -103,10 +102,6 @@ defmodule Bumblebee.Text.QuestionAnswering do {batch, {all_inputs, raw_inputs, multi?}} end) |> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, {inputs, raw_inputs, multi?} -> - # We use binary backend so we are not blocked by the serving computation - inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend) - outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend) - Enum.zip_with( [raw_inputs, Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(outputs)], fn [{_question_text, context_text}, inputs, outputs] -> diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index 2519d9e7..c5c25804 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TextClassification do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - scores_fun.(params, inputs) + scores_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex index c1fd93ba..5842e21e 100644 --- a/lib/bumblebee/text/text_embedding.ex +++ b/lib/bumblebee/text/text_embedding.ex @@ -107,7 +107,7 @@ defmodule Bumblebee.Text.TextEmbedding do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - embedding_fun.(params, inputs) + embedding_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options @@ -131,9 +131,6 @@ defmodule Bumblebee.Text.TextEmbedding do {batch, multi?} end) |> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? 
-> - # We use binary backend so we are not blocked by the serving computation - embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend) - for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do %{embedding: embedding} end diff --git a/lib/bumblebee/text/token_classification.ex b/lib/bumblebee/text/token_classification.ex index 341beac6..08f6f56f 100644 --- a/lib/bumblebee/text/token_classification.ex +++ b/lib/bumblebee/text/token_classification.ex @@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TokenClassification do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - scores_fun.(params, inputs) + scores_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options @@ -88,10 +88,6 @@ defmodule Bumblebee.Text.TokenClassification do {batch, {all_inputs, multi?}} end) |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, {inputs, multi?} -> - # We use binary backend so we are not blocked by the serving computation - scores = Nx.backend_transfer(scores, Nx.BinaryBackend) - inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend) - Enum.zip_with( Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(scores), @@ -110,9 +106,6 @@ defmodule Bumblebee.Text.TokenClassification do end defp gather_raw_entities(scores, tokenizer, inputs) do - # We use binary backend so we are not blocked by the serving computation - scores = Nx.backend_transfer(scores, Nx.BinaryBackend) - {sequence_length, _} = Nx.shape(scores) flat_special_tokens_mask = Nx.to_flat_list(inputs["special_tokens_mask"]) flat_input_ids = Nx.to_flat_list(inputs["input_ids"]) diff --git a/lib/bumblebee/text/zero_shot_classification.ex b/lib/bumblebee/text/zero_shot_classification.ex index bfe2ad66..32a6cdc0 100644 --- a/lib/bumblebee/text/zero_shot_classification.ex +++ b/lib/bumblebee/text/zero_shot_classification.ex @@ -81,7 +81,7 @@ defmodule Bumblebee.Text.ZeroShotClassification do scores = Axon.Activations.softmax(logits[[.., .., entailment_id]]) k = min(top_k, 
Nx.axis_size(scores, 1)) {top_scores, top_indices} = Nx.top_k(scores, k: k) - {top_scores, top_indices} + {top_scores, top_indices} |> Shared.serving_post_computation() end end, defn_options diff --git a/lib/bumblebee/vision/image_classification.ex b/lib/bumblebee/vision/image_classification.ex index 972519c4..cb86456a 100644 --- a/lib/bumblebee/vision/image_classification.ex +++ b/lib/bumblebee/vision/image_classification.ex @@ -57,7 +57,7 @@ defmodule Bumblebee.Vision.ImageClassification do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - scores_fun.(params, inputs) + scores_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options diff --git a/lib/bumblebee/vision/image_embedding.ex b/lib/bumblebee/vision/image_embedding.ex index c51fb8bd..d1975ca6 100644 --- a/lib/bumblebee/vision/image_embedding.ex +++ b/lib/bumblebee/vision/image_embedding.ex @@ -80,7 +80,7 @@ defmodule Bumblebee.Vision.ImageEmbedding do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - embedding_fun.(params, inputs) + embedding_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options @@ -94,9 +94,6 @@ defmodule Bumblebee.Vision.ImageEmbedding do {Nx.Batch.concatenate([inputs]), multi?} end) |> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? 
-> - # We use binary backend so we are not blocked by the serving computation - embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend) - for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do %{embedding: embedding} end diff --git a/lib/bumblebee/vision/image_to_text.ex b/lib/bumblebee/vision/image_to_text.ex index b526786e..29d96b8d 100644 --- a/lib/bumblebee/vision/image_to_text.ex +++ b/lib/bumblebee/vision/image_to_text.ex @@ -49,7 +49,7 @@ defmodule Bumblebee.Vision.ImageToText do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - generate_fun.(params, inputs) + generate_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options diff --git a/test/bumblebee/text/text_embedding_test.exs b/test/bumblebee/text/text_embedding_test.exs index a9b9974e..150eb723 100644 --- a/test/bumblebee/text/text_embedding_test.exs +++ b/test/bumblebee/text/text_embedding_test.exs @@ -96,8 +96,10 @@ defmodule Bumblebee.Text.TextEmbeddingTest do text = "query: Cats are cute." - assert %{embedding: %Nx.Tensor{} = embedding1} = Nx.Serving.batched_run(test, text) - assert %{embedding: %Nx.Tensor{} = embedding2} = Nx.Serving.batched_run(test, text) + assert [ + %{embedding: %Nx.Tensor{} = embedding1}, + %{embedding: %Nx.Tensor{} = embedding2} + ] = Nx.Serving.batched_run(test, [text, text]) assert_equal(embedding1, embedding2) end