Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transfer serving computation result to binary backend upfront #282

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
7 changes: 3 additions & 4 deletions lib/bumblebee/diffusion/stable_diffusion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
%{image: image}
end

Bumblebee.Utils.Nx.composite_unflatten_batch(output, inputs.size)
output
|> Bumblebee.Utils.Nx.composite_unflatten_batch(inputs.size)
|> Shared.serving_post_computation()
end
end

Expand Down Expand Up @@ -318,9 +320,6 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
end

defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do
# We use binary backend so we are not blocked by the serving computation
outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)

for outputs <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
results =
for outputs = %{image: image} <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
Expand Down
13 changes: 13 additions & 0 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ defmodule Bumblebee.Shared do
Nx.Batch.pad(batch, batch_size - size)
end

@doc """
Shared logic applied after serving computation to the resulting tensor
or container.
"""
@spec serving_post_computation(result) :: result when result: Nx.Tensor.t() | Nx.Container.t()
def serving_post_computation(result) do
# We transfer to binary backend so tensor access in post-processing
# is not blocked by the serving computation. It is also
# necessary when partitions are enabled since we may need to
# concatenate results for input exceeding the expected batch size.
Nx.backend_transfer(result, Nx.BinaryBackend)
end

@doc """
Compiles or wraps the function with just-in-time compilation.

Expand Down
1 change: 1 addition & 0 deletions lib/bumblebee/text/conversation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ defmodule Bumblebee.Text.Conversation do
end

sequences[[.., start_idx..-1//1]]
|> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/fill_mask.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ defmodule Bumblebee.Text.FillMask do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ defmodule Bumblebee.Text.Generation do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
7 changes: 1 addition & 6 deletions lib/bumblebee/text/question_answering.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ defmodule Bumblebee.Text.QuestionAnswering do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)

predict_fun.(params, inputs)
predict_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down Expand Up @@ -103,10 +102,6 @@ defmodule Bumblebee.Text.QuestionAnswering do
{batch, {all_inputs, raw_inputs, multi?}}
end)
|> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, {inputs, raw_inputs, multi?} ->
# We use binary backend so we are not blocked by the serving computation
inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)
outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)

Enum.zip_with(
[raw_inputs, Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(outputs)],
fn [{_question_text, context_text}, inputs, outputs] ->
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/text_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TextClassification do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
5 changes: 1 addition & 4 deletions lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ defmodule Bumblebee.Text.TextEmbedding do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
embedding_fun.(params, inputs)
embedding_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand All @@ -131,9 +131,6 @@ defmodule Bumblebee.Text.TextEmbedding do
{batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
# We use binary backend so we are not blocked by the serving computation
embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)

for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
%{embedding: embedding}
end
Expand Down
9 changes: 1 addition & 8 deletions lib/bumblebee/text/token_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ defmodule Bumblebee.Text.TokenClassification do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand All @@ -88,10 +88,6 @@ defmodule Bumblebee.Text.TokenClassification do
{batch, {all_inputs, multi?}}
end)
|> Nx.Serving.client_postprocessing(fn {scores, _metadata}, {inputs, multi?} ->
# We use binary backend so we are not blocked by the serving computation
scores = Nx.backend_transfer(scores, Nx.BinaryBackend)
inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)

Enum.zip_with(
Utils.Nx.batch_to_list(inputs),
Utils.Nx.batch_to_list(scores),
Expand All @@ -110,9 +106,6 @@ defmodule Bumblebee.Text.TokenClassification do
end

defp gather_raw_entities(scores, tokenizer, inputs) do
# We use binary backend so we are not blocked by the serving computation
scores = Nx.backend_transfer(scores, Nx.BinaryBackend)

{sequence_length, _} = Nx.shape(scores)
flat_special_tokens_mask = Nx.to_flat_list(inputs["special_tokens_mask"])
flat_input_ids = Nx.to_flat_list(inputs["input_ids"])
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/zero_shot_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ defmodule Bumblebee.Text.ZeroShotClassification do
scores = Axon.Activations.softmax(logits[[.., .., entailment_id]])
k = min(top_k, Nx.axis_size(scores, 1))
{top_scores, top_indices} = Nx.top_k(scores, k: k)
{top_scores, top_indices}
{top_scores, top_indices} |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ defmodule Bumblebee.Vision.ImageClassification do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
scores_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
5 changes: 1 addition & 4 deletions lib/bumblebee/vision/image_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ defmodule Bumblebee.Vision.ImageEmbedding do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
embedding_fun.(params, inputs)
embedding_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand All @@ -94,9 +94,6 @@ defmodule Bumblebee.Vision.ImageEmbedding do
{Nx.Batch.concatenate([inputs]), multi?}
end)
|> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
# We use binary backend so we are not blocked by the serving computation
embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)

for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
%{embedding: embedding}
end
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_to_text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ defmodule Bumblebee.Vision.ImageToText do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down
6 changes: 4 additions & 2 deletions test/bumblebee/text/text_embedding_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ defmodule Bumblebee.Text.TextEmbeddingTest do

text = "query: Cats are cute."

assert %{embedding: %Nx.Tensor{} = embedding1} = Nx.Serving.batched_run(test, text)
assert %{embedding: %Nx.Tensor{} = embedding2} = Nx.Serving.batched_run(test, text)
assert [
%{embedding: %Nx.Tensor{} = embedding1},
%{embedding: %Nx.Tensor{} = embedding2}
] = Nx.Serving.batched_run(test, [text, text])

assert_equal(embedding1, embedding2)
end
Expand Down
Loading