Skip to content

Commit

Permalink
Remove TODO
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Dec 2, 2024
1 parent 8c489cb commit 65bc636
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -306,19 +306,12 @@ defmodule Bumblebee.Text.Generation do

output_policy = model_output_policy(model)

# TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
# match Axon compiler

# Cast all float cache tensors to match the model output. This way
# we make sure the cache we pass as input has the same types as
# the updated cache returned from the model
cache =
Bumblebee.Utils.Nx.map(cache, fn tensor ->
if Nx.Type.integer?(Nx.type(tensor)) do
tensor
else
Axon.MixedPrecision.cast(output_policy, tensor, :output)
end
Axon.MixedPrecision.cast(output_policy, tensor, :output)
end)

Map.put(inputs, "cache", cache)
Expand Down

0 comments on commit 65bc636

Please sign in to comment.