diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index f6e5b4d9..3c66ba8e 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -379,7 +379,8 @@ defmodule Bumblebee do
     * `:params_filename` - the file with the model parameters to be loaded
 
     * `:log_params_diff` - whether to log missing, mismatched and unused
-      parameters. Defaults to `true`
+      parameters. By default, the diff is logged only if some parameters
+      cannot be loaded
 
     * `:backend` - the backend to allocate the tensors on. It is either
       an atom or a tuple in the shape `{backend, options}`
@@ -416,7 +417,7 @@ defmodule Bumblebee do
         :architecture,
         :params_filename,
         :backend,
-        log_params_diff: true
+        :log_params_diff
       ])
 
     spec_response =
@@ -444,8 +445,6 @@ defmodule Bumblebee do
     # TODO: support format: :auto | :axon | :pytorch
     format = :pytorch
     filename = opts[:params_filename] || @params_filename[format]
-    log_params_diff = opts[:log_params_diff]
-    backend = opts[:backend]
 
     input_template = module.input_template(spec)
 
@@ -453,10 +452,11 @@ defmodule Bumblebee do
     with {:ok, path} <- download(repository, filename) do
       params =
-        Bumblebee.Conversion.PyTorch.load_params!(model, input_template, path,
-          log_params_diff: log_params_diff,
-          backend: backend,
-          params_mapping: params_mapping
+        Bumblebee.Conversion.PyTorch.load_params!(
+          model,
+          input_template,
+          path,
+          [params_mapping: params_mapping] ++ Keyword.take(opts, [:backend, :log_params_diff])
         )
 
       {:ok, params}
     end
diff --git a/lib/bumblebee/conversion/pytorch.ex b/lib/bumblebee/conversion/pytorch.ex
index c00b1a8b..265be781 100644
--- a/lib/bumblebee/conversion/pytorch.ex
+++ b/lib/bumblebee/conversion/pytorch.ex
@@ -14,7 +14,8 @@ defmodule Bumblebee.Conversion.PyTorch do
   ## Options
 
     * `:log_params_diff` - whether to log missing, mismatched and unused
-      parameters. Defaults to `true`
+      parameters. By default, the diff is logged only if some parameters
+      cannot be loaded
 
     * `:backend` - the backend to allocate the tensors on. It is either
       an atom or a tuple in the shape `{backend, options}`
@@ -26,7 +27,7 @@ defmodule Bumblebee.Conversion.PyTorch do
   """
   @spec load_params!(Axon.t(), map(), Path.t(), keyword()) :: map()
   def load_params!(model, input_template, path, opts \\ []) do
-    opts = Keyword.validate!(opts, log_params_diff: true, backend: nil, params_mapping: %{})
+    opts = Keyword.validate!(opts, [:log_params_diff, :backend, params_mapping: %{}])
 
     with_default_backend(opts[:backend], fn ->
       pytorch_state = Bumblebee.Conversion.PyTorch.Loader.load!(path)
@@ -39,15 +40,17 @@ defmodule Bumblebee.Conversion.PyTorch do
       {params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
 
+      params_complete? = diff.missing == [] and diff.mismatched == []
+
       params =
-        if diff.missing == [] and diff.mismatched == [] do
+        if params_complete? do
           params
         else
           {init_fun, _} = Axon.build(model, compiler: Nx.Defn.Evaluator)
           init_fun.(input_template, params)
         end
 
-      if opts[:log_params_diff] do
+      if Keyword.get(opts, :log_params_diff, not params_complete?) do
         log_params_diff(diff)
       end
 
       params
diff --git a/lib/bumblebee/diffusion/layers/unet.ex b/lib/bumblebee/diffusion/layers/unet.ex
index dbec3e48..4cb9576f 100644
--- a/lib/bumblebee/diffusion/layers/unet.ex
+++ b/lib/bumblebee/diffusion/layers/unet.ex
@@ -316,6 +316,9 @@ defmodule Bumblebee.Diffusion.Layers.UNet do
       num_blocks: depth,
       num_attention_heads: num_heads,
       hidden_size: hidden_size,
+      query_use_bias: false,
+      key_use_bias: false,
+      value_use_bias: false,
       layer_norm: [
         epsilon: 1.0e-5
       ],
diff --git a/lib/bumblebee/huggingface/transformers/utils.ex b/lib/bumblebee/huggingface/transformers/utils.ex
index 79d9348b..b0753c79 100644
--- a/lib/bumblebee/huggingface/transformers/utils.ex
+++ b/lib/bumblebee/huggingface/transformers/utils.ex
@@ -28,7 +28,7 @@ defmodule Bumblebee.HuggingFace.Transformers.Utils do
   @spec map_params_source_layer_names(
           Transformers.Model.params_source(),
           (String.t() -> String.t())
-        ) :: Transformers.Model.params_source()
+        ) :: Transformers.Model.layer_name() | Transformers.Model.params_source()
   def map_params_source_layer_names(%{} = params_source, fun) do
     Map.new(params_source, fn {param_name, {sources, source_fun}} ->
       sources = for {layer_name, param_name} <- sources, do: {fun.(layer_name), param_name}
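
Usage sketch of the resulting behavior, not part of the patch. `Bumblebee.load_model/2` and the `:log_params_diff` option are the ones touched by the first hunk; the model repository is illustrative:

    # Default: the params diff is logged only when some parameters are
    # missing or mismatched
    {:ok, model_info} = Bumblebee.load_model({:hf, "bert-base-uncased"})

    # Opt back in to unconditional logging (the old default of `true`)
    {:ok, _model_info} =
      Bumblebee.load_model({:hf, "bert-base-uncased"}, log_params_diff: true)

    # Or silence the diff entirely, even when the load is partial
    {:ok, _model_info} =
      Bumblebee.load_model({:hf, "bert-base-uncased"}, log_params_diff: false)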