Alpaca Dataset Updates and Fixes (#303)
kartikayk authored Feb 4, 2024
1 parent f1537ee commit aaf43de
Showing 6 changed files with 288 additions and 31 deletions.
7 changes: 6 additions & 1 deletion recipes/finetune_llm.py
@@ -106,7 +106,12 @@ def recipe(
     grad_scaler = GradScaler(enabled=False)

     # ---- Load dataset, set up sampler, and dataloader ---- #
-    ds = datasets.get_dataset(params.dataset, split="train", tokenizer=tokenizer)
+    ds = datasets.get_dataset(
+        params.dataset,
+        split="train",
+        tokenizer=tokenizer,
+        train_on_input=params.train_on_input,
+    )
     sampler = DistributedSampler(
         ds,
         num_replicas=world_size,
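The flag threaded through here controls whether prompt tokens contribute to the training loss. A minimal sketch of the behavior the new dataset tests below assert (not the torchtune implementation itself; the ignore value of -100 is an assumption, chosen to match nn.CrossEntropyLoss's default ignore_index):

# Hypothetical illustration of the label construction exercised by the new tests below.
from typing import List, Tuple

CROSS_ENTROPY_IGNORE_IDX = -100  # assumed value

def build_tokens_and_labels(
    prompt_tokens: List[int], response_tokens: List[int], train_on_input: bool
) -> Tuple[List[int], List[int]]:
    tokens = prompt_tokens + response_tokens
    if train_on_input:
        # Loss is computed on every position, prompt included.
        labels = list(tokens)
    else:
        # Mask the prompt so only the response contributes to the loss.
        labels = [CROSS_ENTROPY_IGNORE_IDX] * len(prompt_tokens) + list(response_tokens)
    return tokens, labels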
10 changes: 8 additions & 2 deletions recipes/full_finetune.py
@@ -130,6 +130,7 @@ def setup(self, params: FullFinetuneParams) -> None:
         # setup after both of these are initialized
         self._sampler, self._dataloader = self._setup_data(
             dataset=params.dataset,
+            train_on_input=params.train_on_input,
             shuffle=params.shuffle,
             batch_size=params.batch_size,
         )
@@ -240,15 +241,20 @@ def _setup_loss(self, loss: str) -> nn.Module:
         return loss_fn

     def _setup_data(
-        self, dataset: str, shuffle: bool, batch_size: int
+        self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool
     ) -> Tuple[DistributedSampler, DataLoader]:
         """
         All data related setup happens here. Currently this recipe only supports the
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
         world_size, rank = utils.get_world_size_and_rank()
-        ds = datasets.get_dataset(dataset, split="train", tokenizer=self._tokenizer)
+        ds = datasets.get_dataset(
+            dataset,
+            split="train",
+            tokenizer=self._tokenizer,
+            train_on_input=train_on_input,
+        )
         sampler = DistributedSampler(
             ds,
             num_replicas=world_size,
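For reference, a hedged usage sketch of the updated get_dataset call, mirroring the tests added in this commit. The tokenizer path is a placeholder, and loading the alpaca data requires downloading it from the Hugging Face hub:

from torchtune import datasets
from torchtune.modules.tokenizer import Tokenizer

# Placeholder path; point this at a real SentencePiece tokenizer model.
tokenizer = Tokenizer.from_file("/path/to/tokenizer.model")

# With train_on_input=False, prompt tokens are masked out of the labels.
ds = datasets.get_dataset(
    "alpaca",
    split="train",
    tokenizer=tokenizer,
    train_on_input=False,
)
tokens, labels = ds[0]  # samples come back as (tokenized input, labels)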
1 change: 1 addition & 0 deletions recipes/params.py
@@ -57,6 +57,7 @@ class FullFinetuneParams:

     # Dataset and Sampler
     dataset: str = ""
+    train_on_input: bool = True
     shuffle: bool = True
     batch_size: int = 2

19 changes: 11 additions & 8 deletions recipes/tests/test_finetune_llm.py
@@ -48,16 +48,16 @@ def _fetch_loss_values(self, output) -> Dict[str, float]:

     def _fetch_expected_loss_values(self, ckpt) -> Dict[str, float]:
         small_test_ckpt_loss_values = {
-            "1|1|": 10.5011,
-            "1|2|": 10.5740,
-            "2|1|": 10.5221,
-            "2|2|": 10.4835,
+            "1|1|": 10.5074,
+            "1|2|": 10.5563,
+            "2|1|": 10.5152,
+            "2|2|": 10.4851,
         }
         llama2_7b_ckpt_loss_values = {
-            "1|1|": 1.2381,
-            "1|2|": 1.1042,
-            "2|1|": 1.3086,
-            "2|2|": 0.9908,
+            "1|1|": 1.1333,
+            "1|2|": 1.1199,
+            "2|1|": 1.2614,
+            "2|2|": 0.9486,
         }
         if ckpt == "small_test_ckpt":
             return small_test_ckpt_loss_values
@@ -79,6 +79,7 @@ def test_finetune_llm_loss(self, capsys, pytestconfig):

         kwargs_values = {
             "dataset": "alpaca",
+            "train_on_input": False,
             "seed": 9,
             "shuffle": True,
             "model": ckpt,
@@ -120,6 +121,7 @@ def test_finetune_errors(self, capsys, pytestconfig):

         kwargs_values = {
             "dataset": "alpaca",
+            "train_on_input": False,
             "seed": 9,
             "shuffle": True,
             "model": ckpt,
@@ -157,6 +159,7 @@ def test_finetune_llm_loss_refactored(self, capsys, pytestconfig):

         kwargs_values = {
             "dataset": "alpaca",
+            "train_on_input": False,
             "seed": 9,
             "shuffle": True,
             "model": ckpt,
168 changes: 168 additions & 0 deletions tests/torchtune/datasets/test_alpaca_dataset.py
@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest.mock import patch

import pytest

from torchtune import datasets
from torchtune.datasets.alpaca import CROSS_ENTROPY_IGNORE_IDX
from torchtune.modules.tokenizer import Tokenizer

from tests.test_utils import get_assets_path


class TestAlpacaDataset:
    @pytest.fixture
    def tokenizer(self):
        # m.model is a pretrained SentencePiece model trained using the following command:
        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
        return Tokenizer.from_file(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets.alpaca.load_dataset")
def test_prompt_generation(self, load_dataset, tokenizer):
"""
Test the prompt generation based on the alpaca template is correct.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
},
{
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant.",
},
]

# Expected prompts are taken from the "output" field in
# https://huggingface.co/datasets/tatsu-lab/alpaca
expected_prompts = [
(
"Below is an instruction that describes a task. Write a response that appropriately "
"completes the request.\n\n"
"### Instruction:\nGive three tips for staying healthy.\n\n"
"### Response:\n"
),
(
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\nEvaluate this sentence for spelling and grammar mistakes\n\n"
"### Input:\nHe finnished his meal and left the resturant\n\n"
"### Response:\n"
),
]

alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)

# alpaca_dataset._data contains the raw data loaded from HF's dataset. We need the raw data
# to test the prompt generation since calling __getitem__ on the alpaca_dataset object will
# return the encoded input and label
for idx, sample in enumerate(alpaca_dataset._data):
assert expected_prompts[idx] == alpaca_dataset._generate_prompt(
sample["instruction"], sample["input"]
)

@patch("torchtune.datasets.alpaca.load_dataset")
def test_label_no_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is not masked.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]

alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)
input, labels = alpaca_dataset[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert CROSS_ENTROPY_IGNORE_IDX not in labels

@patch("torchtune.datasets.alpaca.load_dataset")
def test_label_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is masked.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]

alpaca_dataset = datasets.get_dataset(
"alpaca", tokenizer=tokenizer, train_on_input=False
)

# Extract the prompt and tokenize it; we'll need this to test whether we're masking the
# input correctly
sample = alpaca_dataset._data[0]
prompt = alpaca_dataset._generate_prompt(sample["instruction"], sample["input"])
encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)

# Generate the input and labels
input, labels = alpaca_dataset[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)

@patch("torchtune.datasets.alpaca.load_dataset")
def test_alpaca_clean(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is not masked.
"""

# mock the call to HF datasets
load_dataset.return_value = [
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
"2. Exercise regularly to keep your body active and strong."
"3. Get enough sleep and maintain a consistent sleep schedule."
),
}
]

alpaca_dataset = datasets.get_dataset(
"alpaca", tokenizer=tokenizer, use_clean=True
)
input, labels = alpaca_dataset[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert CROSS_ENTROPY_IGNORE_IDX not in labels
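Why the masked positions drop out of the loss: the ignore index written into the labels is expected to line up with the ignore_index used by the cross-entropy loss, so masked prompt tokens contribute nothing to the gradient. A small self-contained sketch of that interaction, assuming the conventional value of -100 for CROSS_ENTROPY_IGNORE_IDX:

import torch
import torch.nn as nn

IGNORE_IDX = -100  # assumed to match CROSS_ENTROPY_IGNORE_IDX

logits = torch.randn(4, 10)  # 4 token positions, vocab size 10
labels = torch.tensor([IGNORE_IDX, IGNORE_IDX, 3, 7])  # first two (prompt) positions masked

loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_IDX)
loss = loss_fn(logits, labels)  # averaged over the two unmasked (response) positions only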