Alpaca Dataset Updates and Fixes #303

Merged · 5 commits · Feb 4, 2024

Changes from 1 commit
166 changes: 166 additions & 0 deletions tests/torchtune/datasets/test_alpaca_dataset.py
@@ -0,0 +1,166 @@
from unittest.mock import patch

import pytest

from torchtune import datasets
from torchtune.datasets.alpaca import CROSS_ENTROPY_IGNORE_IDX
from torchtune.modules.tokenizer import Tokenizer

from tests.test_utils import get_assets_path


class TestAlpacaDataset:
@pytest.fixture
def tokenizer(self):
        # m.model is a pretrained SentencePiece model, trained with the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return Tokenizer.from_file(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets.alpaca.load_dataset")
def test_prompt_generation(self, load_dataset, tokenizer):
"""
Test the the prompt generation based on the alpaca template is correct.
Review comment (Contributor): nit
Suggested change:
-        Test the the prompt generation based on the alpaca template is correct.
+        Test that the prompt generation based on the alpaca template is correct.
"""

# mock the call to HF datasets
        load_dataset.return_value = (
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
f'1.Eat a balanced diet and make sure to include plenty of fruits and vegetables.'
f'2. Exercise regularly to keep your body active and strong.'
f'3. Get enough sleep and maintain a consistent sleep schedule.'
)
},
{
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant."
}
]
)

# Expected prompts are taken from the "output" field in
# https://huggingface.co/datasets/tatsu-lab/alpaca
expected_prompts = [
(
f'Below is an instruction that describes a task. Write a response that appropriately '
f'completes the request.\n\n'
f'### Instruction:\nGive three tips for staying healthy.\n\n'
f'### Response:'
),
(
f'Below is an instruction that describes a task, paired with an input that provides further context. '
f'Write a response that appropriately completes the request.\n\n'
f'### Instruction:\nEvaluate this sentence for spelling and grammar mistakes\n\n'
f'### Input:\nHe finnished his meal and left the resturant\n\n'
f'### Response:'
)
]

alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)

# alpaca_dataset._data contains the raw data loaded from HF's dataset. We need the raw data
# to test the prompt generation since calling __get__item on the alpaca_dataset object will
Review comment (Contributor): nit
Suggested change:
-        # to test the prompt generation since calling __get__item on the alpaca_dataset object will
+        # to test the prompt generation since calling __getitem__ on the alpaca_dataset object will
# return the encoded input and label
for idx, sample in enumerate(alpaca_dataset._data):
assert (
expected_prompts[idx] == alpaca_dataset._generate_prompt(sample["instruction"], sample["input"])
)

@patch("torchtune.datasets.alpaca.load_dataset")
def test_label_no_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is not masked.
"""

# mock the call to HF datasets
        load_dataset.return_value = (
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
f'1.Eat a balanced diet and make sure to include plenty of fruits and vegetables.'
f'2. Exercise regularly to keep your body active and strong.'
f'3. Get enough sleep and maintain a consistent sleep schedule.'
)
}
]
)

alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)
input, labels = alpaca_dataset[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert CROSS_ENTROPY_IGNORE_IDX not in labels

@patch("torchtune.datasets.alpaca.load_dataset")
def test_label_masking(self, load_dataset, tokenizer):
"""
Test whether the input and the labels are correctly created when the input is masked.
"""

# mock the call to HF datasets
        load_dataset.return_value = (
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
f'1.Eat a balanced diet and make sure to include plenty of fruits and vegetables.'
f'2. Exercise regularly to keep your body active and strong.'
f'3. Get enough sleep and maintain a consistent sleep schedule.'
)
}
]
)

alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer, train_on_input=False)

# Extract the prompt and tokenize it; we'll need this to test whether we're masking the
# input correctly
sample = alpaca_dataset._data[0]
prompt = alpaca_dataset._generate_prompt(sample["instruction"], sample["input"])
encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)

# Generate the input and labels
input, labels = alpaca_dataset[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)

@patch("torchtune.datasets.alpaca.load_dataset")
def test_alpaca_clean(self, load_dataset, tokenizer):
"""
        Test whether the input and the labels are correctly created when the cleaned version of the dataset is used.
"""

# mock the call to HF datasets
        load_dataset.return_value = (
[
{
"instruction": "Give three tips for staying healthy.",
"input": "",
"output": (
f'1.Eat a balanced diet and make sure to include plenty of fruits and vegetables.'
f'2. Exercise regularly to keep your body active and strong.'
f'3. Get enough sleep and maintain a consistent sleep schedule.'
)
}
]
)

alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer, use_clean=True)
input, labels = alpaca_dataset[0]

assert len(input) == len(labels)
assert labels[-1] == tokenizer.eos_id
assert input[0] == tokenizer.bos_id
assert CROSS_ENTROPY_IGNORE_IDX not in labels
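
The masking behavior asserted in `test_label_masking` can be sketched in isolation. A minimal illustration with made-up token ids (the ids are hypothetical; the logic mirrors the `_transform` changes in the diff below):

CROSS_ENTROPY_IGNORE_IDX = -100

# Hypothetical token ids: bos + prompt tokens, then response tokens + eos.
encoded_prompt = [1, 305, 411, 87]
encoded_prompt_with_response = [1, 305, 411, 87, 512, 99, 2]

labels = encoded_prompt_with_response.copy()
labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(encoded_prompt)

# Prompt tokens are excluded from the loss; the response and eos still contribute.
assert labels == [-100, -100, -100, -100, 512, 99, 2]
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)
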
108 changes: 89 additions & 19 deletions torchtune/datasets/alpaca.py
@@ -13,20 +13,53 @@
from torchtune.modules import Tokenizer


-_CROSS_ENTROPY_IGNORE_IDX = -100
+CROSS_ENTROPY_IGNORE_IDX = -100

_PROMPT_TEMPLATE = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}


class AlpacaDataset(Dataset):
"""
-    PyTorch Representation of the Alpaca Dataset
+    Support for the Alpaca dataset and it's variants from HuggingFace Datasets.
Review comment (Contributor): nit
Suggested change:
-    Support for the Alpaca dataset and it's variants from HuggingFace Datasets.
+    Support for the Alpaca dataset and its variants from HuggingFace Datasets.
-    https://huggingface.co/datasets/tatsu-lab/alpaca
-    from Hugging Face.

Data input format: https://huggingface.co/datasets/tatsu-lab/alpaca#data-instances

The input is created using the prompt template from the original alpaca codebase:
https://github.com/tatsu-lab/stanford_alpaca/blob/761dc5bfbdeeffa89b8bff5d038781a4055f796a/train.py#L31

    The template has the following format:
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:{output}"

where `instruction`, `input`, and `output` are fields from the dataset.

Masking of the prompt during training is controlled by the `train_on_input` flag, which is
set to `True` by default (ref: https://github.com/tloen/alpaca-lora/blob/main/finetune.py#L49)
Review comment (Member): what are our thoughts on referring to reference implementations in torchtune? Not sure if citing them implies that we as torchtune are certifying that repo as a reference we endorse / want to compare against in an outward fashion.
- If `train_on_input` is True, the prompt is used during training and
contributes to the loss.
- If `train_on_input` is False, the prompt is masked out (tokens replaced with -100)

    The version of the dataset used is controlled by the `use_clean` flag, which is set to False by default.
- If `use_clean` is True, then https://huggingface.co/datasets/yahma/alpaca-cleaned is used
- If `use_clean` is False, then https://huggingface.co/datasets/tatsu-lab/alpaca is used

Args:
        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement `encode` and `decode` methods.
use_clean (bool): Whether to use the cleaned version of the dataset or not. Default is False.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
**kwargs: Additional keyword arguments to pass to the Alpaca Dataset.


@@ -37,17 +70,32 @@ class AlpacaDataset(Dataset):
>>> Batch size: 8
"""

-    def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
-        self._data = load_dataset("tatsu-lab/alpaca", split="train")
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        train_on_input: bool = True,
+        use_clean: bool = True,
+        **kwargs
+    ) -> None:
+        dataset_path = (
+            "yahma/alpaca-cleaned" if use_clean
+            else "tatsu-lab/alpaca"
+        )
+        self._data = load_dataset(dataset_path, split="train")
         self._tokenizer = tokenizer
+        self.train_on_input = train_on_input

def __len__(self):
return len(self._data)

def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
-        return self._transform(self._data[index]["text"])
+        return self._transform(
+            instruction=self._data[index]["instruction"],
+            input=self._data[index]["input"],
+            output=self._data[index]["output"]
+        )
Review comment (Contributor): nit: could just define sample = self._data[index] to avoid multiple calls
Reply (Contributor, author): Great catch!
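A sketch of the refactor that comment suggests (illustrative only; not part of this commit):

        sample = self._data[index]
        return self._transform(
            instruction=sample["instruction"],
            input=sample["input"],
            output=sample["output"],
        )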

-    def _transform(self, sample: str) -> Tuple[List[int], List[int]]:
+    def _transform(self, instruction: str, input: str, output: str) -> Tuple[List[int], List[int]]:
"""
Split a sample on ``response`` tag to create input and labels.

@@ -57,16 +105,38 @@ def _transform(self, sample: str) -> Tuple[List[int], List[int]]:
Returns:
Tuple of encoded inputs and labels.
"""
-        response_tag = "\n\n### Response:\n"
-        inst_inp_response_tag = sample[: sample.index(response_tag) + len(response_tag)]
-        response = sample[sample.index(response_tag) + len(response_tag) :]
-        inst_inp_response_tag = self._tokenizer.encode(
-            inst_inp_response_tag, add_bos=True, add_eos=False
-        )
-        response = self._tokenizer.encode(response, add_bos=False, add_eos=True)
-        input = inst_inp_response_tag + response
-        label = [
-            _CROSS_ENTROPY_IGNORE_IDX for _ in range(len(inst_inp_response_tag))
-        ] + response
-        assert len(input) == len(label)
-        return input, label
+        prompt = self._generate_prompt(instruction, input)
+        prompt_with_response = prompt + output
+
+        # add bos always; LlamaTokenizer sets this to True by default and neither
+        # alpaca-lora nor the original authors change this
+        encoded_prompt = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=False)
+        encoded_prompt_with_response = self._tokenizer.encode(
+            text=prompt_with_response, add_bos=True, add_eos=True
+        )
+        labels = encoded_prompt_with_response.copy()
+
+        if not self.train_on_input:
+            labels[: len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(encoded_prompt)
+
+        assert len(encoded_prompt_with_response) == len(labels)
+        return encoded_prompt_with_response, labels

def _generate_prompt(self, instruction: str, input: str) -> str:
"""
Generate prompt from instruction and input.

Args:
instruction (str): Instruction text.
input (str): Input text.

Returns:
Prompt text.
"""
if input:
prompt = _PROMPT_TEMPLATE["prompt_input"].format(
instruction=instruction, input=input
)
else:
prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(instruction=instruction)
return prompt
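
For reference, a minimal usage sketch of the API exercised in the tests above (the SentencePiece model path is a placeholder, not something shipped with this PR):

from torchtune import datasets
from torchtune.modules.tokenizer import Tokenizer

# Placeholder path to a pretrained SentencePiece model (see the test fixture).
tokenizer = Tokenizer.from_file("/path/to/m.model")

# use_clean=True loads yahma/alpaca-cleaned instead of tatsu-lab/alpaca;
# train_on_input=False masks prompt tokens with CROSS_ENTROPY_IGNORE_IDX.
alpaca_ds = datasets.get_dataset(
    "alpaca", tokenizer=tokenizer, train_on_input=False, use_clean=True
)

# Render the prompt for the first raw sample, then fetch its encoded pair.
sample = alpaca_ds._data[0]
print(alpaca_ds._generate_prompt(sample["instruction"], sample["input"]))
input_ids, labels = alpaca_ds[0]
assert len(input_ids) == len(labels)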