Defer LayerScale initialization for compatibility with "meta" devices #476

Open · wants to merge 1 commit into main
Conversation

baldassarreFe

Context

PyTorch introduced a "meta" device that allows creating a model without allocating memory for its weights: https://pytorch.org/docs/stable/meta.html

Meta devices are useful when a model is too big for a single GPU: create the model on a meta device, configure sharding across multiple devices, allocate memory per device, and initialize the weights last.

Usage example from the docs:

with torch.device("meta"):
    m = nn.Linear(20, 30)
m.to_empty(device="cuda")
m.reset_parameters()

Issue

The current implementation of LayerScale does not support meta devices: the weight gamma is initialized directly in the constructor, so after moving from the meta device to a physical device with to_empty() the weight is left uninitialized (all zeros in the example below) instead of holding the desired value.

Example:

from typing import Union

import torch
from torch import Tensor, nn

class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # gamma is initialized eagerly, inside the constructor
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

with torch.device("meta"):
    ls = LayerScale(1024, init_values=1e-5)
ls.to_empty(device="cuda")
ls.gamma.mean(), ls.gamma.std()
# 0.0, 0.0
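
For comparison (not part of the original report), constructing the same module directly on a real device still initializes gamma in the constructor as expected, so the problem only shows up on the meta-device path:

# Direct construction, no meta device: gamma holds init_values as intended.
ls = LayerScale(1024, init_values=1e-5)
ls.gamma.mean(), ls.gamma.std()
# 1e-05, 0.0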

Solution

Move the weight initialization into a separate reset_parameters() method and call it from the constructor. When meta devices are not used, the behavior is identical to the original implementation. In a meta-device context, nn.init.constant_ is a no-op, and reset_parameters() can be called again later if/when needed.

Example:

from typing import Optional, Union

import torch
from torch import Tensor, nn

class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.init_values = init_values
        # Allocate the parameter without initializing it...
        self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        # ...then initialize it; on a meta device this is a no-op.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

with torch.device("meta"):
    ls = LayerScale(1024, init_values=1e-5)
ls.to_empty(device="cuda")
ls.reset_parameters()
ls.gamma.mean(), ls.gamma.std()
# 1e-05, 0.0
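
As a usage sketch (not part of this PR), the same pattern scales to a whole model: build it under the meta device, materialize storage with to_empty(), then call reset_parameters() on every submodule that defines it. The nn.Sequential wrapper and the "cuda" target below are just illustrative choices.

import torch
from torch import nn

# Build the model on the meta device; assumes the patched LayerScale above.
with torch.device("meta"):
    model = nn.Sequential(
        nn.Linear(1024, 1024),
        LayerScale(1024, init_values=1e-5),
    )

# Allocate real (uninitialized) storage, then initialize module by module.
model.to_empty(device="cuda")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()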

facebook-github-bot added the "CLA Signed" label on Oct 26, 2024.