Skip to content

Commit

Permalink
Add Phi model (#356)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonatan Kłosko <[email protected]>
  • Loading branch information
seanmor5 and jonatanklosko authored Mar 1, 2024
1 parent 9d7ce31 commit a09b230
Show file tree
Hide file tree
Showing 10 changed files with 622 additions and 11 deletions.
5 changes: 5 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ defmodule Bumblebee do
"MistralModel" => {Bumblebee.Text.Mistral, :base},
"MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
"MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
"PhiModel" => {Bumblebee.Text.Phi, :base},
"PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
"PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
"PhiForTokenClassification" => {Bumblebee.Text.Phi, :for_token_classification},
"ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
"ResNetModel" => {Bumblebee.Vision.ResNet, :base},
"RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
Expand Down Expand Up @@ -234,6 +238,7 @@ defmodule Bumblebee do
"llama" => :llama,
"mistral" => :llama,
"mbart" => :mbart,
"phi" => :code_gen,
"roberta" => :roberta,
"t5" => :t5,
"whisper" => :whisper,
Expand Down
33 changes: 28 additions & 5 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,21 @@ defmodule Bumblebee.Layers do
* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`
* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.
* `:use_bias` - whether the layer should add bias to the output.
Defaults to `false`
"""
def dense_transposed(%Axon{} = x, units, opts \\ []) do
opts = Keyword.validate!(opts, [:name, kernel_initializer: :glorot_uniform])
opts =
Keyword.validate!(opts, [
:name,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros,
use_bias: false
])

kernel_shape = fn input_shape ->
kernel_shape = Axon.Shape.dense_kernel(input_shape, units)
Expand All @@ -396,13 +408,24 @@ defmodule Bumblebee.Layers do
|> List.to_tuple()
end

bias_shape = &Axon.Shape.dense_bias(&1, units)

kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])

op = fn x, kernel, _opts ->
Nx.dot(x, [-1], kernel, [1])
end
{inputs, op} =
if opts[:use_bias] do
bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], &dense_transposed_impl/4}
else
{[x, kernel], &dense_transposed_impl/3}
end

Axon.layer(op, [x, kernel], name: opts[:name], op_name: :dense_transposed)
Axon.layer(op, inputs, name: opts[:name], op_name: :dense_transposed)
end

deftransformp dense_transposed_impl(x, kernel, bias \\ 0, _opts) do
Nx.dot(x, [-1], kernel, [1])
|> Nx.add(bias)
end

@doc """
Expand Down
8 changes: 4 additions & 4 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,11 @@ defmodule Bumblebee.Layers.Transformer do
* `:max_positions` - the maximum number of distinct positions
* `:rotary_embedding_base` - base for computing rotary embedding frequency. Defaults
to `10_000`.
* `:base` - base for computing rotary embedding frequency. Defaults
to `10_000`.
* `:rotary_percentage` - percentage of hidden dimensions to allocate to rotary embeddings.
Defaults to `1.0`.
* `:percentage` - percentage of hidden dimensions to allocate to rotary embeddings.
Defaults to `1.0`.
* `:name` - the prefix for layer names
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/mistral.ex
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ defmodule Bumblebee.Text.Mistral do

gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false)

hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation))
hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation))

Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false)
end
Expand Down
Loading

0 comments on commit a09b230

Please sign in to comment.