Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supports loading .safetensors params file #231

Merged
merged 7 commits into from
Aug 4, 2023

Conversation

grzuy
Copy link
Contributor

@grzuy grzuy commented Aug 2, 2023

closes #96

Opening proof of concept as draft while I continue working on some improvements and test coverage, and potentially any other feedback folks have :-)

lib/bumblebee.ex Outdated Show resolved Hide resolved
@grzuy grzuy changed the title feat: supports loading .safetensors params feat: supports loading .safetensors params file Aug 2, 2023
lib/bumblebee.ex Outdated Show resolved Hide resolved
mix.exs Outdated Show resolved Hide resolved
@jonatanklosko
Copy link
Member

Thanks for the PR, a couple minor comments :)

Copy link
Member

@jonatanklosko jonatanklosko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@grzuy grzuy marked this pull request as ready for review August 4, 2023 12:38
@grzuy
Copy link
Contributor Author

grzuy commented Aug 4, 2023

FWIW I plan to explore in a follow up PR also supporting safetensors sharded params files.

mix.exs Outdated Show resolved Hide resolved
Co-authored-by: Jonatan Kłosko <[email protected]>
@jonatanklosko jonatanklosko merged commit 3f691f2 into elixir-nx:main Aug 4, 2023
@jonatanklosko
Copy link
Member

Thanks a lot!

@grzuy grzuy deleted the safetensors branch August 4, 2023 16:45
@grzuy
Copy link
Contributor Author

grzuy commented Aug 4, 2023

FWIW I plan to explore in a follow up PR also supporting safetensors sharded params files.

Mmm, correcting myself.

I think with the changes in this PR one would be able to load sharded safetensors param files.

For example, for

https://huggingface.co/stabilityai/StableBeluga-7B/tree/main

which contains the following files

model-00001-of-00002.safetensors
model-00002-of-00002.safetensors
model.safetensors.index.json

using

Bumblebee.load_model({:hf, "stabilityai/StableBeluga-7B"}, params_filename: "model.safetensors")

should use the existing sharded params loading logic, look at the index file and "just work".

@grzuy
Copy link
Contributor Author

grzuy commented Aug 4, 2023

What would be a good improvement might be adding some decent auto-selection of the preferred file format based on what's available in the model repo without having the user needed to explicitly provide the file name.

@jonatanklosko
Copy link
Member

Good call, so far most repos had the pytorch file and optionally other formats, but as safetensors become more popular there may be cases where it's just safetensors. Currently we do fallbacks, that is, request one file, if doesn't exist request another, and so on. I checked and looks like HF API now allows listing files, so I will later reevaluate if we can improve :)

@jonatanklosko
Copy link
Member

FTR as of #256 we automatically detect if there are no parameters in the pytorch format, but safetensors one is available :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support .safetensors deserialization/serialization
3 participants