diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 52e32443..c7fd810d 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -164,6 +164,10 @@ defmodule Bumblebee do
     "MistralModel" => {Bumblebee.Text.Mistral, :base},
     "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
     "MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
+    "PhiModel" => {Bumblebee.Text.Phi, :base},
+    "PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
+    "PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
+    "PhiForTokenClassification" => {Bumblebee.Text.Phi, :for_token_classification},
     "ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
     "ResNetModel" => {Bumblebee.Vision.ResNet, :base},
     "RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -234,6 +238,7 @@ defmodule Bumblebee do
     "llama" => :llama,
     "mistral" => :llama,
     "mbart" => :mbart,
+    "phi" => :code_gen,
     "roberta" => :roberta,
     "t5" => :t5,
     "whisper" => :whisper,
diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex
index bb2a52f7..4bc4a988 100644
--- a/lib/bumblebee/layers.ex
+++ b/lib/bumblebee/layers.ex
@@ -382,9 +382,21 @@ defmodule Bumblebee.Layers do
     * `:kernel_initializer` - initializer for `kernel` weights. Defaults
       to `:glorot_uniform`
 
+    * `:bias_initializer` - initializer for `bias` weights. Defaults
+      to `:zeros`
+
+    * `:use_bias` - whether the layer should add bias to the output.
+      Defaults to `false`
+
   """
   def dense_transposed(%Axon{} = x, units, opts \\ []) do
-    opts = Keyword.validate!(opts, [:name, kernel_initializer: :glorot_uniform])
+    opts =
+      Keyword.validate!(opts, [
+        :name,
+        kernel_initializer: :glorot_uniform,
+        bias_initializer: :zeros,
+        use_bias: false
+      ])
 
     kernel_shape = fn input_shape ->
       kernel_shape = Axon.Shape.dense_kernel(input_shape, units)
@@ -396,13 +408,24 @@ defmodule Bumblebee.Layers do
       |> List.to_tuple()
     end
 
+    bias_shape = &Axon.Shape.dense_bias(&1, units)
+
     kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
 
-    op = fn x, kernel, _opts ->
-      Nx.dot(x, [-1], kernel, [1])
-    end
+    {inputs, op} =
+      if opts[:use_bias] do
+        bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
+        {[x, kernel, bias], &dense_transposed_impl/4}
+      else
+        {[x, kernel], &dense_transposed_impl/3}
+      end
 
-    Axon.layer(op, [x, kernel], name: opts[:name], op_name: :dense_transposed)
+    Axon.layer(op, inputs, name: opts[:name], op_name: :dense_transposed)
+  end
+
+  deftransformp dense_transposed_impl(x, kernel, bias \\ 0, _opts) do
+    Nx.dot(x, [-1], kernel, [1])
+    |> Nx.add(bias)
   end
 
   @doc """
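A minimal sketch of how the extended `dense_transposed/3` layer behaves with `use_bias: true` — this calls the internal `Bumblebee.Layers` building block directly, assumes a recent Axon, and the layer name `"lm_head"` and shapes are purely illustrative:

```elixir
# Build a single dense_transposed layer with a bias parameter.
input = Axon.input("hidden_state", shape: {nil, 8})

model =
  Bumblebee.Layers.dense_transposed(input, 16,
    use_bias: true,
    name: "lm_head"
  )

{init_fn, predict_fn} = Axon.build(model)

# With use_bias: true the "lm_head" scope now holds a "bias" parameter
# next to the transposed "kernel"; with the default (false) it does not.
params = init_fn.(%{"hidden_state" => Nx.template({1, 8}, :f32)}, %{})

# Forward pass: Nx.dot(x, [-1], kernel, [1]) |> Nx.add(bias) -> shape {1, 16}
predict_fn.(params, %{"hidden_state" => Nx.broadcast(0.5, {1, 8})})
```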
diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex
index 6deeeb2b..1b05c174 100644
--- a/lib/bumblebee/layers/transformer.ex
+++ b/lib/bumblebee/layers/transformer.ex
@@ -285,11 +285,11 @@ defmodule Bumblebee.Layers.Transformer do
 
       * `:max_positions` - the maximum number of distinct positions
 
-      * `:rotary_embedding_base` - base for computing rotary embedding frequency. Defaults
-        to `10_000`.
+      * `:base` - base for computing rotary embedding frequency. Defaults
+        to `10_000`.
 
-      * `:rotary_percentage` - percentage of hidden dimensions to allocate to rotary embeddings.
-        Defaults to `1.0`.
+      * `:percentage` - percentage of hidden dimensions to allocate to rotary embeddings.
+        Defaults to `1.0`.
 
   * `:name` - the prefix for layer names
 
diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex
index a045943a..fccde216 100644
--- a/lib/bumblebee/text/mistral.ex
+++ b/lib/bumblebee/text/mistral.ex
@@ -362,7 +362,7 @@ defmodule Bumblebee.Text.Mistral do
     gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false)
 
-    hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation))
+    hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation))
 
     Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false)
   end
 
diff --git a/lib/bumblebee/text/phi.ex b/lib/bumblebee/text/phi.ex
new file mode 100644
index 00000000..6f4f5828
--- /dev/null
+++ b/lib/bumblebee/text/phi.ex
@@ -0,0 +1,459 @@
+defmodule Bumblebee.Text.Phi do
+  alias Bumblebee.Shared
+
+  options =
+    [
+      vocab_size: [
+        default: 51200,
+        doc: """
+        the vocabulary size of the token embedding. This corresponds to the number of distinct
+        tokens that can be represented in model input and output
+        """
+      ],
+      max_positions: [
+        default: 2048,
+        doc: """
+        the vocabulary size of the position embedding. This corresponds to the maximum sequence
+        length that this model can process. Typically this is set to a large value just in case,
+        such as 512, 1024 or 2048
+        """
+      ],
+      hidden_size: [
+        default: 2048,
+        doc: "the dimensionality of hidden layers"
+      ],
+      intermediate_size: [
+        default: 8192,
+        doc: "the dimensionality of intermediate layers"
+      ],
+      num_blocks: [
+        default: 24,
+        doc: "the number of Transformer blocks in the model"
+      ],
+      num_attention_heads: [
+        default: 32,
+        doc: "the number of attention heads for each attention layer in the model"
+      ],
+      num_key_value_heads: [
+        default: nil,
+        doc: """
+        the number of key-value heads used to implement Grouped Query Attention. If
+        this value is set to the same as the number of attention heads, it will use
+        regular MHA. If it's set to 1, it will use MQA, otherwise it uses Grouped Query
+        Attention
+        """
+      ],
+      activation: [
+        default: :gelu_approx_tanh,
+        doc: "the activation function"
+      ],
+      rotary_embedding_percentage: [
+        default: 0.5,
+        doc: "percentage of the query and keys that will have rotary embedding"
+      ],
+      rotary_embedding_base: [
+        default: 10_000,
+        doc: "base for computing rotary embedding frequency"
+      ],
+      layer_norm_epsilon: [
+        default: 1.0e-12,
+        doc: "the epsilon used by the layer normalization layers"
+      ],
+      initializer_scale: [
+        default: 0.02,
+        doc:
+          "the standard deviation of the normal initializer used for initializing kernel parameters"
+      ]
+    ] ++
+      Shared.common_options([
+        :output_hidden_states,
+        :output_attentions,
+        :num_labels,
+        :id_to_label
+      ]) ++ Shared.token_options(pad_token_id: 0)
+
+  @moduledoc """
+  Phi model family.
+
+  ## Architectures
+
+    * `:base` - plain Phi without any head on top
+
+    * `:for_causal_language_modeling` - Phi with a language modeling
+      head. The head returns logits for each token in the original
+      sequence
+
+    * `:for_sequence_classification` - Phi with a sequence
+      classification head. The head returns logits corresponding to
+      possible classes
+
+    * `:for_token_classification` - Phi with a token classification
+      head. The head returns logits for each token in the original
+      sequence
+
+  ## Inputs
+
+    * `"input_ids"` - `{batch_size, sequence_length}`
+
+      Indices of input sequence tokens in the vocabulary.
+
+    * `"attention_mask"` - `{batch_size, sequence_length}`
+
+      Mask indicating which tokens to attend to. This is used to ignore
+      padding tokens, which are added when processing a batch of sequences
+      with different length.
+
+    * `"position_ids"` - `{batch_size, sequence_length}`
+
+      Indices of positions of each input sequence tokens in the position
+      embeddings.
+
+    * `"attention_head_mask"` - `{num_blocks, num_attention_heads}`
+
+      Mask to nullify selected heads of the self-attention blocks in
+      the decoder.
+
+    * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}`
+
+      Embedded representation of `"input_ids"`, which can be specified
+      for more control over how `"input_ids"` are embedded than the
+      model's internal embedding lookup. If `"input_embeddings"` are present,
+      then `"input_ids"` will be ignored.
+
+    * `"cache"`
+
+      A container with cached layer results used to speed up sequential
+      decoding (autoregression). With cache, certain hidden states are
+      taken from the cache, rather than recomputed on every decoding
+      pass. The cache should be treated as opaque and initialized with
+      `Bumblebee.Text.Generation.init_cache/4`.
+
+  ## Configuration
+
+  #{Shared.options_doc(options)}
+  """
+
+  defstruct [architecture: :base] ++ Shared.option_defaults(options)
+
+  @behaviour Bumblebee.ModelSpec
+  @behaviour Bumblebee.Configurable
+  @behaviour Bumblebee.Text.Generation
+
+  import Bumblebee.Utils.Model, only: [join: 2]
+
+  alias Bumblebee.Layers
+
+  @impl true
+  def architectures(),
+    do: [
+      :base,
+      :for_causal_language_modeling,
+      :for_sequence_classification,
+      :for_token_classification
+    ]
+
+  @impl true
+  def config(spec, opts) do
+    spec
+    |> Shared.put_config_attrs(opts)
+    |> Shared.validate_label_options()
+  end
+
+  @impl true
+  def input_template(_spec) do
+    %{
+      "input_ids" => Nx.template({1, 1}, :s64)
+    }
+  end
+
+  @impl true
+  def init_cache(spec, batch_size, max_length, _inputs) do
+    Layers.Decoder.init_cache(batch_size, max_length,
+      hidden_size: spec.hidden_size,
+      decoder_num_attention_heads: spec.num_attention_heads,
+      decoder_num_blocks: spec.num_blocks
+    )
+  end
+
+  @impl true
+  def traverse_cache(_spec, cache, fun) do
+    Layers.Decoder.traverse_cache(cache, fun)
+  end
+
+  @impl true
+  def model(%__MODULE__{architecture: :base} = spec) do
+    inputs = inputs(spec)
+
+    inputs
+    |> core(spec)
+    |> Layers.output()
+  end
+
+  def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do
+    inputs = inputs(spec)
+
+    outputs = core(inputs, spec)
+    logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
+
+    Layers.output(%{
+      logits: logits,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions,
+      cache: outputs.cache
+    })
+  end
+
+  def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do
+    inputs = inputs(spec)
+
+    outputs = core(inputs, spec)
+
+    logits =
+      Axon.dense(outputs.hidden_state, spec.num_labels,
+        kernel_initializer: kernel_initializer(spec),
+        name: "sequence_classification_head.output",
+        use_bias: false
+      )
+
+    pooled_logits =
+      Layers.if_present inputs["input_ids"] do
+        Axon.layer(
+          fn logits, input_ids, _opts ->
+            indices =
+              input_ids
+              |> Nx.not_equal(spec.pad_token_id)
+              |> Nx.sum(axes: [-1])
+              |> Nx.subtract(1)
+              |> Nx.as_type({:s, 64})
+
+            Bumblebee.Utils.Nx.batched_take(logits, indices)
+          end,
+          [logits, inputs["input_ids"]]
+        )
+      else
+        Layers.take_token(logits, axis: 1, index: -1)
+      end
+
+    Layers.output(%{
+      logits: pooled_logits,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions,
+      cache: outputs.cache
+    })
+  end
+
+  def model(%__MODULE__{architecture: :for_token_classification} = spec) do
+    inputs = inputs(spec)
+    outputs = core(inputs, spec)
+
+    logits =
+      outputs.hidden_state
+      |> Axon.dropout(
+        rate: 0.1,
+        name: "token_classification_head.dropout"
+      )
+      |> Axon.dense(spec.num_labels,
+        kernel_initializer: kernel_initializer(spec),
+        name: "token_classification_head.output"
+      )
+
+    Layers.output(%{
+      logits: logits,
+      hidden_states: outputs.hidden_states,
+      attentions: outputs.attentions
+    })
+  end
+
+  defp inputs(spec) do
+    shape = {nil, nil}
+    hidden_shape = {nil, nil, spec.hidden_size}
+
+    attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads}
+
+    Bumblebee.Utils.Model.inputs_to_map([
+      Axon.input("input_ids", optional: true, shape: shape),
+      Axon.input("attention_mask", optional: true, shape: shape),
+      Axon.input("position_ids", optional: true, shape: shape),
+      Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape),
+      Axon.input("input_embeddings", optional: true, shape: hidden_shape),
+      Axon.input("cache", optional: true)
+    ])
+  end
+
+  defp core(inputs, spec) do
+    embeddings =
+      embedder(
+        inputs["input_ids"],
+        inputs["input_embeddings"],
+        spec,
+        name: "embedder"
+      )
+
+    position_ids =
+      Layers.default inputs["position_ids"] do
+        Layers.default_position_ids(embeddings)
+      end
+
+    decoder_outputs =
+      decoder(
+        embeddings,
+        position_ids,
+        inputs["attention_mask"],
+        inputs["attention_head_mask"],
+        inputs["cache"],
+        spec,
+        name: "decoder"
+      )
+
+    hidden_state =
+      Axon.layer_norm(decoder_outputs.hidden_state,
+        name: "output_norm",
+        epsilon: spec.layer_norm_epsilon
+      )
+
+    %{
+      hidden_state: hidden_state,
+      hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state),
+      attentions: decoder_outputs.attentions,
+      cache: decoder_outputs.cache
+    }
+  end
+
+  defp embedder(input_ids, input_embeddings, spec, opts) do
+    name = opts[:name]
+
+    # TODO: Axon needs a way to specify ignoring pad tokens
+    # in gradient
+    Layers.default input_embeddings do
+      Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size,
+        kernel_initializer: kernel_initializer(spec),
+        name: join(name, "token_embedding")
+      )
+    end
+  end
+
+  defp decoder(
+         hidden_state,
+         position_ids,
+         attention_mask,
+         attention_head_mask,
+         cache,
+         spec,
+         opts
+       ) do
+    name = opts[:name]
+
+    Layers.Transformer.blocks(hidden_state,
+      attention_mask: attention_mask,
+      attention_head_mask: attention_head_mask,
+      cache: cache,
+      num_blocks: spec.num_blocks,
+      num_attention_heads: spec.num_attention_heads,
+      num_key_value_heads: spec.num_key_value_heads,
+      hidden_size: spec.hidden_size,
+      kernel_initializer: kernel_initializer(spec),
+      layer_norm: [
+        epsilon: spec.layer_norm_epsilon
+      ],
+      ffn: [
+        intermediate_size: spec.intermediate_size,
+        activation: spec.activation
+      ],
+      block_type: &block_impl/3,
+      causal: true,
+      rotary_embedding: [
+        position_ids: position_ids,
+        max_positions: spec.max_positions,
+        base: spec.rotary_embedding_base,
+        percentage: spec.rotary_embedding_percentage
+      ],
+      query_use_bias: true,
+      key_use_bias: true,
+      value_use_bias: true,
+      output_use_bias: true,
+      output_hidden_states: spec.output_hidden_states,
+      output_attentions: spec.output_attentions,
+      name: join(name, "blocks")
+    )
+  end
+
+  # :parallel block with attention norm applied earlier and without ffn norm
+  defp block_impl(hidden_state, steps, _name) do
+    shortcut = hidden_state
+
+    hidden_state = steps.self_attention_norm.(hidden_state)
+
+    {attention_hidden_state, attention_info} = steps.self_attention.(hidden_state)
+
+    {_hidden_state, cross_attention_info} =
+      steps.cross_attention_maybe.(hidden_state, fn _hidden_state ->
+        raise "cross attention not supported"
+      end)
+
+    ffn_hidden_state = steps.ffn.(hidden_state)
+
+    hidden_state = Axon.add([shortcut, attention_hidden_state, ffn_hidden_state])
+
+    {hidden_state, attention_info, cross_attention_info}
+  end
+
+  defp language_modeling_head(hidden_state, spec, opts) do
+    name = opts[:name]
+
+    # TODO: Tie lm-head to word embedding as a spec option
+    Layers.dense_transposed(hidden_state, spec.vocab_size,
+      kernel_initializer: kernel_initializer(spec),
+      name: join(name, "output"),
+      use_bias: true
+    )
+  end
+
+  defp kernel_initializer(spec) do
+    Axon.Initializers.normal(scale: spec.initializer_scale)
+  end
+
+  defimpl Bumblebee.HuggingFace.Transformers.Config do
+    def load(spec, data) do
+      import Shared.Converters
+
+      opts =
+        convert!(data,
+          vocab_size: {"vocab_size", number()},
+          max_positions: {"max_position_embeddings", number()},
+          hidden_size: {"hidden_size", number()},
+          num_blocks: {"num_hidden_layers", number()},
+          num_attention_heads: {"num_attention_heads", number()},
+          num_key_value_heads: {"num_key_value_heads", number()},
+          intermediate_size: {"intermediate_size", number()},
+          activation: {"hidden_act", activation()},
+          rotary_embedding_base: {"rope_theta", number()},
+          rotary_embedding_percentage: {"partial_rotary_factor", number()},
+          initializer_scale: {"initializer_range", number()},
+          layer_norm_epsilon: {"layer_norm_eps", number()}
+        ) ++ Shared.common_options_from_transformers(data, spec)
+
+      @for.config(spec, opts)
+    end
+  end
+
+  defimpl Bumblebee.HuggingFace.Transformers.Model do
+    def params_mapping(_spec) do
+      %{
+        "embedder.token_embedding" => "model.embed_tokens",
+        "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj",
+        "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj",
+        "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj",
+        "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.dense",
+        "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm",
+        "decoder.blocks.{n}.self_attention.rotary_embedding" =>
+          "model.layers.{n}.self_attn.rotary_emb",
+        "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.fc1",
+        "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.fc2",
+        "output_norm" => "model.final_layernorm",
+        "language_modeling_head.output" => "lm_head",
+        "sequence_classification_head.output" => "score",
+        "token_classification_head.output" => "classifier"
+      }
+    end
+  end
+end
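As a usage sketch for the new architectures (assuming network access to the Hugging Face Hub; the public `microsoft/phi-2` repository is the one referenced by the tokenizer test later in this diff, and the prompt is arbitrary), loading and running a forward pass follows the usual Bumblebee entry points:

```elixir
# Loads the :for_causal_language_modeling variant registered for "PhiForCausalLM".
{:ok, %{model: model, params: params, spec: _spec}} =
  Bumblebee.load_model({:hf, "microsoft/phi-2"})

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/phi-2"})

inputs = Bumblebee.apply_tokenizer(tokenizer, ["def fib(n):"])
outputs = Axon.predict(model, params, inputs)

# Logits over the vocabulary for every prompt position: {batch, sequence, vocab}
Nx.shape(outputs.logits)
```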
diff --git a/lib/bumblebee/text/pre_trained_tokenizer.ex b/lib/bumblebee/text/pre_trained_tokenizer.ex
index 353baa3a..ea2c20a9 100644
--- a/lib/bumblebee/text/pre_trained_tokenizer.ex
+++ b/lib/bumblebee/text/pre_trained_tokenizer.ex
@@ -127,6 +127,16 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
     clip: %{
       special_tokens: %{unk: "<|endoftext|>", pad: "<|endoftext|>", eos: "<|endoftext|>"}
     },
+    code_gen: %{
+      special_tokens: %{
+        unk: "<|endoftext|>",
+        bos: "<|endoftext|>",
+        eos: "<|endoftext|>",
+        # CodeGen doesn't originally have a pad token, however when necessary
+        # we pad with the EOS token
+        pad: "<|endoftext|>"
+      }
+    },
     distilbert: %{
       special_tokens: %{unk: "[UNK]", sep: "[SEP]", pad: "[PAD]", cls: "[CLS]", mask: "[MASK]"}
     },
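The `:code_gen` entry above is what makes padded batching work for Phi: there is no dedicated pad token, so padding falls back to `<|endoftext|>`. A small sketch, assuming the same `microsoft/phi-2` repository as the tests and an illustrative fixed length:

```elixir
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/phi-2"})

# Fix the sequence length so the shorter prompt is padded with the EOS id
tokenizer = Bumblebee.configure(tokenizer, length: 8)

Bumblebee.apply_tokenizer(tokenizer, ["def add(a, b):", "x"])
```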
diff --git a/mix.exs b/mix.exs
index 07ab0ba5..b0a3a366 100644
--- a/mix.exs
+++ b/mix.exs
@@ -94,6 +94,7 @@ defmodule Bumblebee.MixProject do
         Bumblebee.Text.Llama,
         Bumblebee.Text.Mbart,
         Bumblebee.Text.Mistral,
+        Bumblebee.Text.Phi,
         Bumblebee.Text.Roberta,
         Bumblebee.Text.T5,
         Bumblebee.Vision.BlipVision,
diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index 22ed5631..cf82ce51 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -1,4 +1,4 @@
-defmodule Bumblebee.Text.TextGenerationTest do
+defmodule Bumblebee.Text.GenerationTest do
   use ExUnit.Case, async: false
 
   import Bumblebee.TestHelpers
diff --git a/test/bumblebee/text/phi_test.exs b/test/bumblebee/text/phi_test.exs
new file mode 100644
index 00000000..83a2f3d4
--- /dev/null
+++ b/test/bumblebee/text/phi_test.exs
@@ -0,0 +1,95 @@
+defmodule Bumblebee.Text.PhiTest do
+  use ExUnit.Case, async: true
+
+  import Bumblebee.TestHelpers
+
+  @moduletag model_test_tags()
+
+  test ":base" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiModel"})
+
+    assert %Bumblebee.Text.Phi{architecture: :base} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
+
+    assert_all_close(
+      outputs.hidden_state[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [[-0.3275, 0.5231, 0.5690], [0.2239, 0.5028, 0.4599], [-0.0979, 1.0183, 0.3350]]
+      ])
+    )
+  end
+
+  test ":for_sequence_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "bumblebee-testing/tiny-random-PhiForSequenceClassification"}
+             )
+
+    assert %Bumblebee.Text.Phi{architecture: :for_sequence_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 2}
+
+    assert_all_close(
+      outputs.logits,
+      Nx.tensor([[0.1403, -0.1382]])
+    )
+  end
+
+  test ":for_token_classification" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "bumblebee-testing/tiny-random-PhiForTokenClassification"}
+             )
+
+    assert %Bumblebee.Text.Phi{architecture: :for_token_classification} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 10, 2}
+
+    assert_all_close(
+      outputs.logits[[.., 1..3//1, ..]],
+      Nx.tensor([[[-0.0364, -0.1207], [0.2520, 0.0755], [0.0243, 0.0269]]])
+    )
+  end
+
+  test ":for_causal_language_modeling" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-PhiForCausalLM"})
+
+    assert %Bumblebee.Text.Phi{architecture: :for_causal_language_modeling} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.logits) == {1, 10, 1024}
+
+    assert_all_close(
+      outputs.logits[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [[0.2541, 0.0827, 0.0526], [0.1901, 0.1289, 0.0758], [0.1051, 0.0658, -0.1167]]
+      ])
+    )
+  end
+end
diff --git a/test/bumblebee/text/pre_trained_tokenizer_test.exs b/test/bumblebee/text/pre_trained_tokenizer_test.exs
index 91100e21..8f4f91f0 100644
--- a/test/bumblebee/text/pre_trained_tokenizer_test.exs
+++ b/test/bumblebee/text/pre_trained_tokenizer_test.exs
@@ -164,6 +164,24 @@ defmodule Bumblebee.Text.PreTrainedTokenizerTest do
     )
   end
 
+  test ":code_gen" do
+    assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/phi-2"})
+
+    assert %Bumblebee.Text.PreTrainedTokenizer{type: :code_gen} = tokenizer
+
+    inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello everyobdy, how are you?"])
+
+    assert_equal(
+      inputs["input_ids"],
+      Nx.tensor([[15496, 790, 672, 9892, 11, 703, 389, 345, 30]])
+    )
+
+    assert_equal(
+      inputs["attention_mask"],
+      Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])
+    )
+  end
+
   test ":distilbert" do
     assert {:ok, tokenizer} =
              Bumblebee.load_tokenizer({:hf, "distilbert/distilbert-base-uncased"})
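Putting the pieces of this diff together, a hedged end-to-end sketch of serving-based generation with the new Phi support — the serving options, `max_new_tokens` value, and the `EXLA` compiler (an optional dependency) are all illustrative choices, not requirements:

```elixir
{:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/phi-2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/phi-2"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "microsoft/phi-2"})

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 64)

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 256],
    defn_options: [compiler: EXLA]
  )

Nx.Serving.run(serving, "Write a haiku about recursion in programming.")
```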