Alpaca Dataset Updates and Fixes #303
Changes from 1 commit
@@ -0,0 +1,166 @@
from unittest.mock import patch

import pytest

from torchtune import datasets
from torchtune.datasets.alpaca import CROSS_ENTROPY_IGNORE_IDX
from torchtune.modules.tokenizer import Tokenizer

from tests.test_utils import get_assets_path


class TestAlpacaDataset:
    @pytest.fixture
    def tokenizer(self):
        # m.model is a SentencePiece model pretrained with the following command:
        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
        return Tokenizer.from_file(str(get_assets_path() / "m.model"))
@patch("torchtune.datasets.alpaca.load_dataset") | ||||||
def test_prompt_generation(self, load_dataset, tokenizer): | ||||||
""" | ||||||
Test the the prompt generation based on the alpaca template is correct. | ||||||
""" | ||||||

        # mock the call to HF datasets
        load_dataset.return_value = [
            {
                "instruction": "Give three tips for staying healthy.",
                "input": "",
                "output": (
                    "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
                    "2. Exercise regularly to keep your body active and strong."
                    "3. Get enough sleep and maintain a consistent sleep schedule."
                ),
            },
            {
                "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                "input": "He finnished his meal and left the resturant",
                "output": "He finished his meal and left the restaurant.",
            },
        ]

        # Expected prompts are taken from the "output" field in
        # https://huggingface.co/datasets/tatsu-lab/alpaca
        expected_prompts = [
            (
                "Below is an instruction that describes a task. Write a response that appropriately "
                "completes the request.\n\n"
                "### Instruction:\nGive three tips for staying healthy.\n\n"
                "### Response:"
            ),
            (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\nEvaluate this sentence for spelling and grammar mistakes\n\n"
                "### Input:\nHe finnished his meal and left the resturant\n\n"
                "### Response:"
            ),
        ]

        alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)

        # alpaca_dataset._data contains the raw data loaded from HF's dataset. We need the raw
        # data to test the prompt generation, since calling __getitem__ on the alpaca_dataset
        # object will return the encoded input and label
        for idx, sample in enumerate(alpaca_dataset._data):
            assert expected_prompts[idx] == alpaca_dataset._generate_prompt(
                sample["instruction"], sample["input"]
            )

    @patch("torchtune.datasets.alpaca.load_dataset")
    def test_label_no_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is not masked.
        """

        # mock the call to HF datasets
        load_dataset.return_value = [
            {
                "instruction": "Give three tips for staying healthy.",
                "input": "",
                "output": (
                    "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
                    "2. Exercise regularly to keep your body active and strong."
                    "3. Get enough sleep and maintain a consistent sleep schedule."
                ),
            }
        ]

        alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)
        input, labels = alpaca_dataset[0]

        assert len(input) == len(labels)
        assert labels[-1] == tokenizer.eos_id
        assert input[0] == tokenizer.bos_id
        assert CROSS_ENTROPY_IGNORE_IDX not in labels

    @patch("torchtune.datasets.alpaca.load_dataset")
    def test_label_masking(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when the input is masked.
        """

        # mock the call to HF datasets
        load_dataset.return_value = [
            {
                "instruction": "Give three tips for staying healthy.",
                "input": "",
                "output": (
                    "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
                    "2. Exercise regularly to keep your body active and strong."
                    "3. Get enough sleep and maintain a consistent sleep schedule."
                ),
            }
        ]

        alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer, train_on_input=False)

        # Extract the prompt and tokenize it; we'll need this to test whether we're masking the
        # input correctly
        sample = alpaca_dataset._data[0]
        prompt = alpaca_dataset._generate_prompt(sample["instruction"], sample["input"])
        encoded_prompt = tokenizer.encode(text=prompt, add_bos=True, add_eos=False)

        # Generate the input and labels
        input, labels = alpaca_dataset[0]

        assert len(input) == len(labels)
        assert labels[-1] == tokenizer.eos_id
        assert input[0] == tokenizer.bos_id
        assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == len(encoded_prompt)

    @patch("torchtune.datasets.alpaca.load_dataset")
    def test_alpaca_clean(self, load_dataset, tokenizer):
        """
        Test whether the input and the labels are correctly created when using the cleaned
        version of the dataset.
        """

        # mock the call to HF datasets
        load_dataset.return_value = [
            {
                "instruction": "Give three tips for staying healthy.",
                "input": "",
                "output": (
                    "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables."
                    "2. Exercise regularly to keep your body active and strong."
                    "3. Get enough sleep and maintain a consistent sleep schedule."
                ),
            }
        ]

        alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer, use_clean=True)
        input, labels = alpaca_dataset[0]

        assert len(input) == len(labels)
        assert labels[-1] == tokenizer.eos_id
        assert input[0] == tokenizer.bos_id
        assert CROSS_ENTROPY_IGNORE_IDX not in labels
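
Every test above stubs out the Hugging Face `load_dataset` call instead of hitting the network. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how `unittest.mock.patch` swaps a module-level function for a mock for the duration of one test; the `load_dataset` stub and `first_instruction` helper below are hypothetical stand-ins, not torchtune code:

from unittest.mock import patch

# Hypothetical stand-in for the real Hugging Face loader; in the tests above
# the patch target is "torchtune.datasets.alpaca.load_dataset".
def load_dataset(path, split="train"):
    raise RuntimeError("network access not allowed in tests")

def first_instruction():
    return load_dataset("tatsu-lab/alpaca")[0]["instruction"]

# patch() replaces __main__.load_dataset with a MagicMock for the duration of
# the test and passes the mock in as the first positional argument.
@patch("__main__.load_dataset")
def test_first_instruction(mock_load):
    mock_load.return_value = [{"instruction": "hi", "input": "", "output": "hello"}]
    assert first_instruction() == "hi"

if __name__ == "__main__":
    test_first_instruction()
    print("ok")

The key detail the tests rely on is that the patch target is the name as seen by the module that uses the function (`torchtune.datasets.alpaca.load_dataset`), not the place where the function is defined.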
torchtune/datasets/alpaca.py

@@ -13,20 +13,53 @@
from torchtune.modules import Tokenizer


-_CROSS_ENTROPY_IGNORE_IDX = -100
+CROSS_ENTROPY_IGNORE_IDX = -100

+_PROMPT_TEMPLATE = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response:"
+    ),
+}

class AlpacaDataset(Dataset):
    """
-    PyTorch Representation of the Alpaca Dataset
+    Support for the Alpaca dataset and its variants from Hugging Face Datasets.
    https://huggingface.co/datasets/tatsu-lab/alpaca
-    from Hugging Face.

    Data input format: https://huggingface.co/datasets/tatsu-lab/alpaca#data-instances

+    The input is created using the prompt template from the original alpaca codebase:
+    https://github.com/tatsu-lab/stanford_alpaca/blob/761dc5bfbdeeffa89b8bff5d038781a4055f796a/train.py#L31
+
+    This follows the format:
+    "Below is an instruction that describes a task, paired with an input that provides further context. "
+    "Write a response that appropriately completes the request.\n\n"
+    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:{output}"
+
+    where `instruction`, `input`, and `output` are fields from the dataset.
+
+    Masking of the prompt during training is controlled by the `train_on_input` flag, which is
+    set to `True` by default (ref: https://github.com/tloen/alpaca-lora/blob/main/finetune.py#L49).

Reviewer comment: What are our thoughts on referring to reference implementations in torchtune? Not sure if citing them implies that torchtune is certifying that repo as a reference we endorse and want to compare against in an outward fashion.

+    - If `train_on_input` is True, the prompt is used during training and
+      contributes to the loss.
+    - If `train_on_input` is False, the prompt is masked out (tokens replaced with -100).
+
+    The version of the dataset used is controlled by the `use_clean` flag, which is set to False by default.
+    - If `use_clean` is True, then https://huggingface.co/datasets/yahma/alpaca-cleaned is used.
+    - If `use_clean` is False, then https://huggingface.co/datasets/tatsu-lab/alpaca is used.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data. Tokenizer must implement an `encode` and `decode` method.
+        use_clean (bool): Whether to use the cleaned version of the dataset or not. Default is False.
+        train_on_input (bool): Whether the model is trained on the prompt or not. Default is True.
        **kwargs: Additional keyword arguments to pass to the Alpaca Dataset.

@@ -37,17 +70,32 @@ class AlpacaDataset(Dataset):
    >>> Batch size: 8
    """

-    def __init__(self, tokenizer: Tokenizer, **kwargs) -> None:
-        self._data = load_dataset("tatsu-lab/alpaca", split="train")
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        train_on_input: bool = True,
+        use_clean: bool = False,
+        **kwargs,
+    ) -> None:
+        dataset_path = "yahma/alpaca-cleaned" if use_clean else "tatsu-lab/alpaca"
+        self._data = load_dataset(dataset_path, split="train")
        self._tokenizer = tokenizer
+        self.train_on_input = train_on_input

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
-        return self._transform(self._data[index]["text"])
+        return self._transform(
+            instruction=self._data[index]["instruction"],
+            input=self._data[index]["input"],
+            output=self._data[index]["output"],
+        )

Reviewer comment: nit: could just define sample = self._data[index] to avoid multiple calls.
Author reply: Great catch!
-    def _transform(self, sample: str) -> Tuple[List[int], List[int]]:
+    def _transform(self, instruction: str, input: str, output: str) -> Tuple[List[int], List[int]]:
        """
        Split a sample on ``response`` tag to create input and labels.

@@ -57,16 +105,38 @@ def _transform(self, sample: str) -> Tuple[List[int], List[int]]:
        Returns:
            Tuple of encoded inputs and labels.
        """
-        response_tag = "\n\n### Response:\n"
-        inst_inp_response_tag = sample[: sample.index(response_tag) + len(response_tag)]
-        response = sample[sample.index(response_tag) + len(response_tag) :]
-        inst_inp_response_tag = self._tokenizer.encode(
-            inst_inp_response_tag, add_bos=True, add_eos=False
+        prompt = self._generate_prompt(instruction, input)
+        prompt_with_response = prompt + output
+
+        # always add bos; LlamaTokenizer sets this to True by default and neither
+        # alpaca-lora nor the original authors change this
+        encoded_prompt = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=False)
+        encoded_prompt_with_response = self._tokenizer.encode(
+            text=prompt_with_response, add_bos=True, add_eos=True
        )
-        response = self._tokenizer.encode(response, add_bos=False, add_eos=True)
-        input = inst_inp_response_tag + response
-        label = [
-            _CROSS_ENTROPY_IGNORE_IDX for _ in range(len(inst_inp_response_tag))
-        ] + response
-        assert len(input) == len(label)
-        return input, label
+        labels = encoded_prompt_with_response.copy()
+
+        if not self.train_on_input:
+            labels[:len(encoded_prompt)] = [CROSS_ENTROPY_IGNORE_IDX] * len(encoded_prompt)
+
+        assert len(encoded_prompt_with_response) == len(labels)
+        return encoded_prompt_with_response, labels

+    def _generate_prompt(self, instruction: str, input: str) -> str:
+        """
+        Generate the prompt from instruction and input.
+
+        Args:
+            instruction (str): Instruction text.
+            input (str): Input text.
+
+        Returns:
+            Prompt text.
+        """
+        if input:
+            prompt = _PROMPT_TEMPLATE["prompt_input"].format(
+                instruction=instruction, input=input
+            )
+        else:
+            prompt = _PROMPT_TEMPLATE["prompt_no_input"].format(instruction=instruction)
+        return prompt
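
To make the template dispatch concrete, here is a short standalone sketch of the same branch `_generate_prompt` takes. The templates are copied from `_PROMPT_TEMPLATE` above; no torchtune import is needed, and `generate_prompt` is a free-function stand-in for the method:

# Standalone sketch of the template dispatch in _generate_prompt.
PROMPT_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)

def generate_prompt(instruction: str, input: str) -> str:
    # An empty string is falsy, so samples with no extra context fall through
    # to the shorter template -- the same branch _generate_prompt takes.
    if input:
        return PROMPT_INPUT.format(instruction=instruction, input=input)
    return PROMPT_NO_INPUT.format(instruction=instruction)

print(generate_prompt("Give three tips for staying healthy.", "").endswith("### Response:"))  # True
print("### Input:" in generate_prompt("Evaluate this sentence", "He finnished his meal"))  # True

This mirrors the two expected prompts checked in test_prompt_generation above: one sample with an empty `input` field, one with context.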
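
One detail worth calling out: `CROSS_ENTROPY_IGNORE_IDX = -100` is not arbitrary; it matches the default `ignore_index` of PyTorch's cross-entropy loss, so labels masked by `train_on_input=False` are dropped from the loss with no extra configuration. A minimal demonstration (the tensor sizes and token ids below are made up for illustration):

import torch
import torch.nn.functional as F

CROSS_ENTROPY_IGNORE_IDX = -100  # default ignore_index of F.cross_entropy

vocab_size = 8
logits = torch.randn(5, vocab_size)  # 5 token positions, hypothetical model output
labels = torch.tensor([CROSS_ENTROPY_IGNORE_IDX, CROSS_ENTROPY_IGNORE_IDX, 3, 1, 4])

# Positions labeled -100 are excluded from the loss, so only the 3 unmasked
# "response" tokens are averaged over -- the effect train_on_input=False relies on.
loss = F.cross_entropy(logits, labels)

# Equivalent to computing the loss over the unmasked positions only:
manual = F.cross_entropy(logits[2:], labels[2:])
assert torch.allclose(loss, manual)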