Defer LayerScale initialization for compatibility with "meta" devices #476
Context
PyTorch introduced a "meta" device that allows creating a model without allocating memory for its weights: https://pytorch.org/docs/stable/meta.html
Meta devices are useful when a model is too big for a single GPU: create the model on a meta device, configure sharding across multiple devices, allocate memory per device, and initialize the weights last.
Usage example from the docs:
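(The snippet from the docs is not reproduced here; the following is a minimal sketch of the same pattern, using `nn.Linear` as a stand-in module.)

```python
import torch
import torch.nn as nn

# Create the module on the meta device: no memory is allocated for the weights.
with torch.device("meta"):
    model = nn.Linear(10_000, 10_000)

# Later, materialize (uninitialized) storage on a real device...
model = model.to_empty(device="cpu")

# ...and initialize the parameters explicitly.
model.reset_parameters()
```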
Issue
The current implementation of `LayerScale` does not support meta devices: the weight `gamma` is initialized immediately in the constructor, and moving the module from a meta device to a physical device results in the weight being set to 0.0 rather than the desired value.

Example:
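(The original example is not reproduced above; the sketch below illustrates the failure mode, assuming a `LayerScale` implementation along these lines.)

```python
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # Original-style implementation (assumed): gamma is created and
    # initialized eagerly in the constructor.
    def __init__(self, dim: int, init_values: float = 1e-5) -> None:
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma

# On the meta device the constructor runs, but gamma carries no data.
with torch.device("meta"):
    ls = LayerScale(dim=4)

# Materializing the module allocates storage without re-running the
# constructor, so the intended init_values are lost; and since there is no
# reset_parameters() method, there is no way to restore them afterwards.
ls = ls.to_empty(device="cpu")
print(ls.gamma)  # not 1e-5 (reported as 0.0 in this issue)
```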
Solution
Move the initialization of the weights to an accessory function, `reset_parameters()`, and call it from the constructor. When not using meta devices, the behavior is identical to the original implementation. In a meta-device context, `nn.init.constant_` is a no-op and `reset_parameters()` can be called later if/when needed.

Example:
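(Again a sketch rather than the exact patch; the class layout and default values are assumptions.)

```python
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    def __init__(self, dim: int, init_values: float = 1e-5) -> None:
        super().__init__()
        self.init_values = init_values
        # Allocate the parameter without baking in its value...
        self.gamma = nn.Parameter(torch.empty(dim))
        # ...and delegate initialization to an accessory method.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Fills gamma with init_values; effectively a no-op on a meta device.
        nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma

# Regular use: behavior identical to the eager version.
ls = LayerScale(dim=4)
print(ls.gamma)  # filled with 1e-5

# Meta-device use: construct now, materialize and initialize later.
with torch.device("meta"):
    ls_meta = LayerScale(dim=4)
ls_meta = ls_meta.to_empty(device="cpu")
ls_meta.reset_parameters()
print(ls_meta.gamma)  # filled with 1e-5
```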