diff --git a/tests/assets/sentencepiece.model b/tests/assets/sentencepiece.model new file mode 100644 index 0000000000..4e28ff6ebd Binary files /dev/null and b/tests/assets/sentencepiece.model differ diff --git a/tests/torchtune/models/t5/__init__.py b/tests/torchtune/models/t5/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/tests/torchtune/models/t5/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/torchtune/models/t5/test_t5_encoder.py b/tests/torchtune/models/t5/test_t5_encoder.py new file mode 100644 index 0000000000..dbb8dbb472 --- /dev/null +++ b/tests/torchtune/models/t5/test_t5_encoder.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from torchtune.models.t5._component_builders import t5_encoder +from torchtune.training.seed import set_seed + +VOCAB_SIZE = 512 +MAX_SEQ_LEN = 8 +BSZ = 2 +EMBED_DIM = 2 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestT5Encoder: + @pytest.fixture + def model(self): + model = t5_encoder( + embed_dim=EMBED_DIM, + mlp_dim=4, + num_heads=2, + head_dim=EMBED_DIM // 2, + num_layers=2, + rel_pos_num_buckets=4, + rel_pos_max_dist=4, + vocab_size=VOCAB_SIZE, + norm_eps=1e-6, + max_seq_len=MAX_SEQ_LEN, + ) + + for param in model.parameters(): + param.data.uniform_(0, 1) + + return model + + @pytest.fixture + def inputs(self): + return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN)) + + def test_forward(self, model, inputs): + actual = model(inputs) + expected = torch.tensor( + [ + [ + [0.3670, 0.2938], + [0.3692, 0.2921], + [0.3611, 0.2984], + [0.4207, 0.2437], + [0.3447, 0.3106], + [0.3383, 0.3150], + [0.3727, 0.2892], + [0.3996, 0.2653], + ], + [ + [0.3855, 0.2783], + [0.2627, 0.3581], + [0.3601, 0.2992], + [0.3473, 0.3087], + [0.3549, 0.3032], + [0.2871, 0.3459], + [0.2753, 0.3520], + [0.2285, 0.3728], + ], + ] + ) + assert actual.shape == (BSZ, MAX_SEQ_LEN, EMBED_DIM) + torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) + + def test_backward(self, model, inputs): + y = model(inputs) + loss = y.mean() + loss.backward() diff --git a/tests/torchtune/models/t5/test_t5_tokenizer.py b/tests/torchtune/models/t5/test_t5_tokenizer.py new file mode 100644 index 0000000000..ceeb0f3f2e --- /dev/null +++ b/tests/torchtune/models/t5/test_t5_tokenizer.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
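+# These tests rely on the small SentencePiece model checked in at
+# tests/assets/sentencepiece.model (added in this change). Assuming a standard
+# torchtune dev setup, they can be run on their own with, e.g.:
+#   pytest tests/torchtune/models/t5
+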
+import pytest + +from tests.common import ASSETS +from torchtune.models.t5._model_builders import t5_tokenizer + + +class TestT5Tokenizer: + @pytest.fixture + def tokenizer(self): + return t5_tokenizer(str(ASSETS / "sentencepiece.model")) + + def test_encoding(self, tokenizer): + texts = [ + "a cow jumping over the moon", + "a helpful AI assistant", + ] + correct_tokens = [ + [3, 9, 9321, 15539, 147, 8, 8114, 1], + [3, 9, 2690, 7833, 6165, 1], + ] + for text, correct in zip(texts, correct_tokens): + tokens = tokenizer.encode(text) + print(tokens) + assert tokens == correct + + def test_decoding(self, tokenizer): + text = "this is torchtune" + assert text == tokenizer.decode(tokenizer.encode(text)) + + def test_call(self, tokenizer): + sample = {"text": "hello world"} + sample = tokenizer(sample) + assert "text" not in sample + assert "tokens" in sample diff --git a/torchtune/models/t5/__init__.py b/torchtune/models/t5/__init__.py new file mode 100644 index 0000000000..da9511099c --- /dev/null +++ b/torchtune/models/t5/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._component_builders import t5_encoder +from ._model_builders import t5_tokenizer, t5_v1p1_xxl_encoder + +__all__ = [ + "t5_encoder", + "t5_tokenizer", + "t5_v1p1_xxl_encoder", +] diff --git a/torchtune/models/t5/_component_builders.py b/torchtune/models/t5/_component_builders.py new file mode 100644 index 0000000000..4867b5036f --- /dev/null +++ b/torchtune/models/t5/_component_builders.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torch import nn + +from torchtune.models.t5._encoder import ( + T5Encoder, + T5EncoderLayer, + T5EncoderSelfAttention, +) +from torchtune.modules.feed_forward import FeedForward +from torchtune.modules.rms_norm import RMSNorm + + +def t5_encoder( + embed_dim: int, + mlp_dim: int, + num_heads: int, + head_dim: int, + num_layers: int, + rel_pos_num_buckets: int, + rel_pos_max_dist: int, + vocab_size: int, + norm_eps: float, + max_seq_len: int, +): + """ + Builder for the T5 encoder. + + T5 paper: https://arxiv.org/abs/1910.10683 + + Args: + embed_dim (int): The model dimension. + mlp_dim (int): The inner dimension of the feed forward layers. + num_heads (int): The number of attention heads. + head_dim (int): The dimension of the attention heads (should equal `embed_dim // num_heads`) + num_layers (int): Number of encoder layers. + rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into. + See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias` + rel_pos_max_dist (int): Maximum distance for relative positions. + Distances beyond this are grouped into the last bucket. + See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias` + vocab_size (int): Vocab size of the model's tokenizer. + norm_eps (float): Small value added to denominator for numerical stability. + max_seq_len (int): The maximum sequence length (context length) of the model. 
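+
+    Example:
+        A minimal sketch; the configuration mirrors the small unit-test model in this
+        change rather than a real T5 release (see ``t5_v1p1_xxl_encoder`` for that):
+
+        >>> import torch
+        >>> encoder = t5_encoder(
+        ...     embed_dim=2,
+        ...     mlp_dim=4,
+        ...     num_heads=2,
+        ...     head_dim=1,
+        ...     num_layers=2,
+        ...     rel_pos_num_buckets=4,
+        ...     rel_pos_max_dist=4,
+        ...     vocab_size=512,
+        ...     norm_eps=1e-6,
+        ...     max_seq_len=8,
+        ... )
+        >>> tokens = torch.randint(0, 512, (2, 8))
+        >>> encoder(tokens).shape
+        torch.Size([2, 8, 2])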
+ + Returns: + T5Encoder + """ + token_embedding = nn.Embedding(vocab_size, embed_dim) + + attn = T5EncoderSelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, embed_dim, bias=False), + k_proj=nn.Linear(embed_dim, embed_dim, bias=False), + v_proj=nn.Linear(embed_dim, embed_dim, bias=False), + output_proj=nn.Linear(embed_dim, embed_dim, bias=False), + ) + + mlp = FeedForward( + gate_proj=nn.Linear(embed_dim, mlp_dim, bias=False), + down_proj=nn.Linear(mlp_dim, embed_dim, bias=False), + up_proj=nn.Linear(embed_dim, mlp_dim, bias=False), + activation=nn.GELU(), + ) + + layer = T5EncoderLayer( + attn=attn, + mlp=mlp, + sa_norm=RMSNorm(embed_dim, eps=norm_eps), + mlp_norm=RMSNorm(embed_dim, eps=norm_eps), + ) + + final_norm = RMSNorm(embed_dim, eps=norm_eps) + + return T5Encoder( + token_embedding=token_embedding, + layer=layer, + final_norm=final_norm, + num_layers=num_layers, + num_heads=num_heads, + rel_pos_num_buckets=rel_pos_num_buckets, + rel_pos_max_dist=rel_pos_max_dist, + max_seq_len=max_seq_len, + ) diff --git a/torchtune/models/t5/_convert_weights.py b/torchtune/models/t5/_convert_weights.py new file mode 100644 index 0000000000..bb2e72f658 --- /dev/null +++ b/torchtune/models/t5/_convert_weights.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.models.convert_weights import get_mapped_key + +# state dict key mappings from HF's format to torchtune's format +_FROM_HF = { + # emb + "encoder.embed_tokens.weight": "token_embedding.weight", + "encoder.block.{}.layer._0.SelfAttention.relative_attention_bias.weight": "relative_position_bias.embedding.weight", + # attn + "encoder.block.{}.layer._0.SelfAttention.q.weight": "layers.{}.attn.q_proj.weight", + "encoder.block.{}.layer._0.SelfAttention.k.weight": "layers.{}.attn.k_proj.weight", + "encoder.block.{}.layer._0.SelfAttention.v.weight": "layers.{}.attn.v_proj.weight", + "encoder.block.{}.layer._0.SelfAttention.o.weight": "layers.{}.attn.output_proj.weight", + # ff + "encoder.block.{}.layer._1.DenseReluDense.wi_0.weight": "layers.{}.mlp.w1.weight", + "encoder.block.{}.layer._1.DenseReluDense.wo.weight": "layers.{}.mlp.w2.weight", + "encoder.block.{}.layer._1.DenseReluDense.wi_1.weight": "layers.{}.mlp.w3.weight", + # norm + "encoder.block.{}.layer._0.layer_norm.weight": "layers.{}.sa_norm.scale", + "encoder.block.{}.layer._1.layer_norm.weight": "layers.{}.mlp_norm.scale", + "encoder.final_layer_norm.weight": "final_norm.scale", +} + +_IGNORE = { + "shared.weight", + "lm_head.weight", +} + + +def t5_encoder_hf_to_tune(state_dict): + converted_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("decoder.") or key in _IGNORE: + continue + + # NOTE: HF's T5 has ".." parts that we do NOT want to be dynamically mapped + # to corresponding ".." parts in our converted state dict. + # This breaks the `get_mapped_key` implementation, so as a temporary hack, + # we add leading underscores to these parts here and in the `_FROM_HF` map above. 
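+        # For illustration (hypothetical block index 3), an HF key is mapped as:
+        #   "encoder.block.3.layer.0.SelfAttention.q.weight"
+        #   -> (underscore hack)  "encoder.block.3.layer._0.SelfAttention.q.weight"
+        #   -> (get_mapped_key)   "layers.3.attn.q_proj.weight"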
+ key = key.replace("layer.0.", "layer._0.").replace("layer.1.", "layer._1.") + + new_key = get_mapped_key(key, _FROM_HF) + converted_state_dict[new_key] = value + return converted_state_dict diff --git a/torchtune/models/t5/_encoder.py b/torchtune/models/t5/_encoder.py new file mode 100644 index 0000000000..7828e9ecc5 --- /dev/null +++ b/torchtune/models/t5/_encoder.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import copy +import math + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from torchtune.modules import MultiHeadAttention + + +class T5Encoder(nn.Module): + """ + The T5 encoder module. + + T5 paper: https://arxiv.org/abs/1910.10683 + + Args: + token_embedding (nn.Embedding): PyTorch embedding layer to place tokens in an embedding space. + layer (nn.Module): A single encoder layer. + final_norm (nn.Module): Module that applies normalization to the output of the encoder + num_layers (int): Number of encoder layers. + num_heads (int): The number of attention heads. + rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into. + See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias` + rel_pos_max_dist (int): Maximum distance for relative positions. + Distances beyond this are grouped into the last bucket. + See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias` + max_seq_len (int): The maximum sequence length (context length) of the model. + """ + + def __init__( + self, + *, + token_embedding: nn.Embedding, + layer: nn.Module, + final_norm: nn.Module, + num_layers: int, + num_heads: int, + rel_pos_num_buckets: int, + rel_pos_max_dist: int, + max_seq_len: int, + ): + super().__init__() + self.token_embedding = token_embedding + self.layers = nn.ModuleList([copy.deepcopy(layer) for i in range(num_layers)]) + self.final_norm = final_norm + self.max_seq_len = max_seq_len + self.relative_position_bias = T5EncoderRelativePositionBias( + num_buckets=rel_pos_num_buckets, + max_dist=rel_pos_max_dist, + num_heads=num_heads, + max_seq_len=max_seq_len, + ) + + def forward(self, tokens: Tensor) -> Tensor: + """ + Args: + tokens (Tensor): input tensor with shape ``[bsz, max_seq_len]`` + + Returns: + Tensor: output tensor with shape [bsz, max_seq_len, embed_dim] + + Raises: + ValueError: if seq_len of tokens is bigger than max_seq_len + """ + # Input validation + bsz, seq_len = tokens.shape + if seq_len > self.max_seq_len: + raise ValueError( + f"seq_len ({seq_len}) of input tensor should be smaller " + f"than max_seq_len ({self.max_seq_len})" + ) + + # Input embedding [bsz, max_seq_len] -> [bsz, max_seq_len, embed_dim] + x = self.token_embedding(tokens) + + # Bias added to the attention scores of every layer (to add relative position information) + rel_pos_bias = self.relative_position_bias() + + # Encoder + for layer in self.layers: + x = layer(x, rel_pos_bias) + + return self.final_norm(x) + + +class T5EncoderLayer(nn.Module): + """ + Single layer of the T5 encoder (standard transformer layer with relative position bias). + + Args: + attn (MultiHeadAttention): Attention module. + mlp (nn.Module): Feed-forward module. + sa_norm (nn.Module): Normalization to be applied before self-attention. + mlp_norm (nn.Module): Normalization to be applied before the feed-forward layer. 
+    """
+
+    def __init__(
+        self,
+        attn: MultiHeadAttention,
+        mlp: nn.Module,
+        sa_norm: nn.Module,
+        mlp_norm: nn.Module,
+    ) -> None:
+        super().__init__()
+        self.attn = attn
+        self.mlp = mlp
+        self.sa_norm = sa_norm
+        self.mlp_norm = mlp_norm
+
+    def forward(self, x: Tensor, rel_pos_bias: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): input tensor with shape [bsz, seq_len, embed_dim]
+            rel_pos_bias (Tensor): relative position bias with shape [1, num_heads, max_seq_len, max_seq_len]
+                See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
+
+        Returns:
+            Tensor: output tensor with shape [bsz, seq_len, embed_dim]
+        """
+        x = x + self.attn(self.sa_norm(x), rel_pos_bias)
+        x = x + self.mlp(self.mlp_norm(x))
+        return x
+
+
+class T5EncoderSelfAttention(nn.Module):
+    """
+    Self-attention for the T5 encoder.
+
+    Standard self-attention with two differences:
+    - No scaling factor
+    - Add "relative position bias" to the attention scores.
+      (See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`)
+
+    Args:
+        embed_dim (int): The model dimension.
+        num_heads (int): Number of attention heads.
+        head_dim (int): Dimension of the attention heads (should equal `embed_dim // num_heads`)
+        q_proj (nn.Module): Projection layer for query.
+        k_proj (nn.Module): Projection layer for key.
+        v_proj (nn.Module): Projection layer for value.
+        output_proj (nn.Module): Projection layer for output.
+
+    Raises:
+        ValueError: If ``embed_dim % num_heads != 0``
+        ValueError: If ``embed_dim // num_heads != head_dim``
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        head_dim: int,
+        q_proj: nn.Module,
+        k_proj: nn.Module,
+        v_proj: nn.Module,
+        output_proj: nn.Module,
+    ):
+        super().__init__()
+        if embed_dim % num_heads != 0:
+            raise ValueError(
+                f"embed_dim ({embed_dim}) must be divisible by "
+                f"num_heads ({num_heads})"
+            )
+        if embed_dim // num_heads != head_dim:
+            raise ValueError(
+                f"head_dim ({head_dim}) must be equal to embed_dim // num_heads"
+            )
+
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+
+        self.q_proj = q_proj
+        self.k_proj = k_proj
+        self.v_proj = v_proj
+        self.output_proj = output_proj
+
+    def forward(self, x: Tensor, rel_pos_bias: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): input tensor with shape [bsz, seq_len, embed_dim]
+            rel_pos_bias (Tensor): relative position bias with shape [1, num_heads, max_seq_len, max_seq_len]
+                See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
+
+        Returns:
+            Tensor: output tensor with shape [bsz, seq_len, embed_dim]
+        """
+        bsz, seq_len, embed_dim = x.shape
+
+        # QKV projections
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+
+        # [bsz, seq_len, embed_dim] -> [bsz, num_heads, seq_len, head_dim]
+        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # attention with relative position bias
+        attn_score = torch.matmul(q, k.transpose(-2, -1))
+        attn_score += rel_pos_bias
+        attn_weight = F.softmax(attn_score.float(), dim=-1).to(attn_score.dtype)
+        attn_out = torch.matmul(attn_weight, v)
+
+        # [bsz, num_heads, seq_len, head_dim] -> [bsz, seq_len, embed_dim]
+        attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, embed_dim)
+
+        return self.output_proj(attn_out)
+
+
+class T5EncoderRelativePositionBias(nn.Module):
+    """
+    Computes binned bidirectional relative position bias for the T5 encoder.
+
+    It places relative positions into buckets and for each bucket, learns bias values for each attention head.
+
+    Args:
+        num_buckets (int): Number of discrete buckets to divide the relative positions into.
+        max_dist (int): Maximum distance for relative positions (distances beyond this are grouped into the last bucket)
+        num_heads (int): Number of attention heads in the transformer.
+        max_seq_len (int): Maximum sequence length (context length).
+    """
+
+    def __init__(
+        self, num_buckets: int, max_dist: int, num_heads: int, max_seq_len: int
+    ):
+        super().__init__()
+        self.max_seq_len = max_seq_len
+
+        # learnable mapping from bucket indices to bias values for each attention head
+        self.embedding = nn.Embedding(num_buckets, num_heads)
+
+        # fixed mapping from relative positions to bucket indices
+        self.register_buffer(
+            "relative_position_to_bucket",
+            _calc_bidirectional_rel_pos_to_bucket(num_buckets, max_dist, max_seq_len),
+            persistent=False,
+        )
+
+    def forward(self) -> Tensor:
+        """
+        Returns:
+            torch.Tensor: relative position bias tensor with shape [1, num_heads, max_seq_len, max_seq_len]
+        """
+        # convert bucket numbers to bias values for each attention head
+        x = self.embedding(self.relative_position_to_bucket)
+
+        # shape [max_seq_len, max_seq_len, num_heads] -> [1, num_heads, max_seq_len, max_seq_len]
+        return x.permute([2, 0, 1]).unsqueeze(0)
+
+
+def _calc_bidirectional_rel_pos_to_bucket(
+    num_buckets: int, max_dist: int, max_seq_len: int
+) -> Tensor:
+    """
+    Calculate the mapping from relative positions to bucket indices.
+
+    NOTE: This is for the T5 encoder (bidirectional), not the decoder (unidirectional).
+
+    Args:
+        num_buckets (int): Number of discrete buckets to divide the relative positions into.
+        max_dist (int): Maximum distance for relative positions (distances beyond this are grouped into the last bucket)
+        max_seq_len (int): Maximum sequence length (context length).
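+
+    Example (values worked out by hand from the bucketing below, using the T5 v1.1
+    defaults ``num_buckets=32``, ``max_dist=128``; shown for illustration only):
+        one half of the buckets covers non-positive relative positions and the other
+        half positive ones; small distances get their own buckets, larger ones are
+        binned logarithmically up to ``max_dist``:
+
+        * relative position  0  -> bucket 0
+        * relative position -3  -> bucket 3,   +3  -> bucket 19
+        * relative position -10 -> bucket 8,   +10 -> bucket 24
+        * distances at or beyond ``max_dist`` fall into the last buckets, 15 and 31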
+ + Returns: + Tensor: shape=[max_seq_len, max_seq_len], range=[0, num_buckets] + """ + query_positions = torch.arange(max_seq_len, dtype=torch.long)[:, None] + key_positions = torch.arange(max_seq_len, dtype=torch.long)[None, :] + relative_positions = key_positions - query_positions + abs_relative_positions = torch.abs(relative_positions) + # relative positions shape: [max_seq_len, max_seq_len] + + # divide the buckets into half for past/present (rel pos <= 0) and half for future (rel pos > 0) + # half of the buckets in each half are for exact relative positions + half_num_buckets = num_buckets // 2 + max_exact = half_num_buckets // 2 + is_exact = abs_relative_positions < max_exact + + # the rest are for logarithmically bigger bins in positions up to max_distance + relative_position_if_not_exact = max_exact + ( + torch.log(abs_relative_positions.float() / max_exact) + / math.log(max_dist / max_exact) + * (half_num_buckets - max_exact) + ).to(torch.long) + relative_position_if_not_exact = torch.min( + relative_position_if_not_exact, + torch.full_like(relative_position_if_not_exact, half_num_buckets - 1), + ) + + # calculate the mapping from relative postion to bucket + relative_position_to_bucket = (relative_positions > 0).to( + torch.long + ) * half_num_buckets + torch.where( + is_exact, abs_relative_positions, relative_position_if_not_exact + ) + + return relative_position_to_bucket diff --git a/torchtune/models/t5/_model_builders.py b/torchtune/models/t5/_model_builders.py new file mode 100644 index 0000000000..641d8e0081 --- /dev/null +++ b/torchtune/models/t5/_model_builders.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.models.t5._component_builders import t5_encoder +from torchtune.models.t5._encoder import T5Encoder +from torchtune.models.t5._tokenizer import T5Tokenizer + + +def t5_v1p1_xxl_encoder(max_seq_len: int = 512) -> T5Encoder: + """ + Builder for the T5 v1.1 XXL (11B parameters) encoder. + + T5 paper: https://arxiv.org/abs/1910.10683 + + 1.1 release: + https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511 + + Args: + max_seq_len (int): The maximum sequence length (context length) of the model. + Default: 512 + + Returns: + T5Encoder: Instantiation of the T5 encoder + """ + return t5_encoder( + embed_dim=4096, + mlp_dim=10240, + num_heads=64, + head_dim=64, + num_layers=24, + rel_pos_num_buckets=32, + rel_pos_max_dist=128, + vocab_size=32128, + norm_eps=1e-6, + max_seq_len=max_seq_len, + ) + + +def t5_tokenizer(path: str, max_seq_len: int = 512, truncate: bool = True): + """ + Builder for the T5 tokenizer. + + Args: + path (str): the path to the T5 sentencepiece tokenizer file + max_seq_len (int): the context length + truncate (bool): whether to truncate the token sequence when longer than max_seq_len + + Returns: + T5Tokenizer: Instantiation of the T5 tokenizer + """ + return T5Tokenizer(path, max_seq_len=max_seq_len, truncate=truncate) diff --git a/torchtune/models/t5/_tokenizer.py b/torchtune/models/t5/_tokenizer.py new file mode 100644 index 0000000000..f89dff00f6 --- /dev/null +++ b/torchtune/models/t5/_tokenizer.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Any, Dict, List + +from torchtune.modules.tokenizers._sentencepiece import SentencePieceBaseTokenizer + + +class T5Tokenizer(SentencePieceBaseTokenizer): + """ + Text tokenizer for T5. + + Args: + path (str): the path to the T5 sentencepiece tokenizer file + max_seq_len (int): the context length + truncate (bool): whether to truncate the token sequence when longer than max_seq_len + """ + + def __init__(self, path: str, max_seq_len: int = 512, truncate: bool = True): + super().__init__(path) + self.max_seq_len = max_seq_len + self.truncate = truncate + + def encode(self, text: str) -> List[int]: + """ + Given a string, return the encoded list of token ids. + + Args: + text (str): The text to encode. + + Returns: + List[int]: The encoded list of token ids. + """ + tokens = super().encode( + text, + add_bos=False, + add_eos=True, + trim_leading_whitespace=False, + prefix=None, + ) + if len(tokens) > self.max_seq_len: + assert self.truncate, ( + "Tokenized text is larger than the maximum sequence length but " + "truncate is set to False." + ) + tokens = tokens[: self.max_seq_len] + tokens[-1] = self.eos_id + return tokens + + def __call__( + self, sample: Dict[str, Any], inference: bool = False + ) -> Dict[str, Any]: + """ + Tokenize the "text" field in the sample. + + Args: + sample (Dict[str, Any]): A sample with a "text" field containing a string to tokenize + inference (bool): Unused by this tokenizer + + Returns: + Dict[str, Any]: The sample with added "tokens" field and the "messages" field removed. + """ + text = sample.pop("text") + sample["tokens"] = self.encode(text) + return sample diff --git a/torchtune/modules/tokenizers/_sentencepiece.py b/torchtune/modules/tokenizers/_sentencepiece.py index 13f75a52ec..0b22b63ee3 100644 --- a/torchtune/modules/tokenizers/_sentencepiece.py +++ b/torchtune/modules/tokenizers/_sentencepiece.py @@ -7,6 +7,7 @@ from typing import List, Optional from sentencepiece import SentencePieceProcessor + from torchtune.modules.tokenizers._utils import BaseTokenizer WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"] diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index fcb4bd131e..a7e15a9768 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -18,6 +18,7 @@ from torchtune.models.clip._convert_weights import clip_text_hf_to_tune from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf +from torchtune.models.t5._convert_weights import t5_encoder_hf_to_tune from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf from torchtune.training.checkpointing._utils import ( FormattedCheckpointFiles, @@ -502,6 +503,10 @@ def load_checkpoint(self) -> Dict[str, Any]: dim=self._config["hidden_size"], head_dim=self._config.get("head_dim", None), ) + elif self._model_type == ModelType.T5_ENCODER: + converted_state_dict[training.MODEL_KEY] = t5_encoder_hf_to_tune( + merged_state_dict, + ) else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index e6a7b1afa1..ef5bfd57c2 100644 --- a/torchtune/training/checkpointing/_utils.py +++ 
b/torchtune/training/checkpointing/_utils.py @@ -57,6 +57,7 @@ class ModelType(Enum): See :func:`~torchtune.models.mistral.mistral_reward_7b` or :func:`~torchtune.models.llama2.llama2_reward_7b` QWEN2 (str): Qwen2 family of models. See :func:`~torchtune.models.qwen2.qwen2` CLIP_TEXT (str): CLIP text encoder. See :func:`~torchtune.models.clip.clip_text_encoder_large` + T5_ENCODER (str): T5 text encoder. See :func:`~torchtune.models.t5.t5_v1p1_xxl_encoder` Example: >>> # Usage in a checkpointer class @@ -77,6 +78,7 @@ class ModelType(Enum): REWARD: str = "reward" QWEN2: str = "qwen2" CLIP_TEXT: str = "clip_text" + T5_ENCODER: str = "t5_encoder" class FormattedCheckpointFiles:
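Taken together, the pieces added in this patch are meant to be used roughly as follows. This is a minimal sketch, not part of the change itself: it assumes the Hugging Face T5 v1.1 XXL encoder weights have already been loaded into a plain state dict named hf_state_dict (for example via the checkpointer path above), and the tokenizer path is illustrative.

    import torch

    from torchtune.models.t5 import t5_tokenizer, t5_v1p1_xxl_encoder
    from torchtune.models.t5._convert_weights import t5_encoder_hf_to_tune

    # Build the encoder and load converted Hugging Face weights
    # (hf_state_dict is assumed to be an already-loaded HF checkpoint).
    model = t5_v1p1_xxl_encoder(max_seq_len=512)
    model.load_state_dict(t5_encoder_hf_to_tune(hf_state_dict))

    # Tokenize a prompt and produce encoder embeddings.
    tokenizer = t5_tokenizer("/path/to/spiece.model", max_seq_len=512)
    tokens = torch.tensor([tokenizer.encode("a cow jumping over the moon")])
    with torch.no_grad():
        embeddings = model(tokens)  # shape [1, seq_len, 4096]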