Alpaca Dataset Updates and Fixes #303
Conversation
Looks great, thanks for getting this up and tested so quickly! My few comments are all nits so feel free to take or leave any of them
torchtune/datasets/alpaca.py
instruction = self._data[index]["instruction"],
input = self._data[index]["input"],
output = self._data[index]["output"]
nit: could just define `sample = self._data[index]` to avoid multiple calls
Great catch!
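The suggested refactor could look like the following minimal sketch. The class body is illustrative only (the real `AlpacaDataset` also tokenizes the sample and applies the prompt template, which is omitted here); the point is the single `self._data[index]` lookup reused across fields.

```python
class AlpacaDataset:
    """Minimal sketch of the reviewer's nit: index self._data once per
    __getitem__ instead of performing three separate lookups. Field names
    follow the Alpaca schema; tokenization from the real class is omitted."""

    def __init__(self, data):
        self._data = data

    def __getitem__(self, index):
        sample = self._data[index]  # single lookup, reused below
        return sample["instruction"], sample["input"], sample["output"]


ds = AlpacaDataset([{"instruction": "Add", "input": "1 2", "output": "3"}])
print(ds[0])  # ('Add', '1 2', '3')
```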
torchtune/datasets/alpaca.py
class AlpacaDataset(Dataset):
    """
    PyTorch Representation of the Alpaca Dataset
    Support for the Alpaca dataset and it's variants from HuggingFace Datasets.
nit:
- Support for the Alpaca dataset and it's variants from HuggingFace Datasets.
+ Support for the Alpaca dataset and its variants from HuggingFace Datasets.
alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer)

# alpaca_dataset._data contains the raw data loaded from HF's dataset. We need the raw data
# to test the prompt generation since calling __get__item on the alpaca_dataset object will
nit:
- # to test the prompt generation since calling __get__item on the alpaca_dataset object will
+ # to test the prompt generation since calling __getitem__ on the alpaca_dataset object will
@patch("torchtune.datasets.alpaca.load_dataset")
def test_prompt_generation(self, load_dataset, tokenizer):
    """
    Test the the prompt generation based on the alpaca template is correct.
nit:
- Test the the prompt generation based on the alpaca template is correct.
+ Test that the prompt generation based on the alpaca template is correct.
Thanks so much @ebsmothers for the quick review! Addressed all comments.
]

alpaca_dataset = datasets.get_dataset(
    "alpaca", tokenizer=tokenizer, use_clean=True
Can we parametrize instead with `test_label_masking`, as the only difference is the `use_clean` flag?
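Assuming pytest (which the `@patch` usage above suggests), the suggested parametrization could look like this sketch. The test body here is a placeholder, not the real test; in the real version the body would call `datasets.get_dataset("alpaca", tokenizer=tokenizer, use_clean=use_clean)` and assert on the generated labels.

```python
import pytest

# Sketch of the suggested parametrization: one shared test body runs for
# both values of use_clean instead of two near-duplicate tests.
@pytest.mark.parametrize("use_clean", [False, True])
def test_label_masking(use_clean):
    # Placeholder body: the real test would build the dataset with
    # use_clean=use_clean and check its label masking.
    dataset_variant = "cleaned" if use_clean else "original"
    assert dataset_variant in ("original", "cleaned")
```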
where `instruction`, `input`, and `output` are fields from the dataset.

Masking of the prompt during training is controlled by the `train_on_input` flag, which is
set to `True` by default (ref: https://github.com/tloen/alpaca-lora/blob/main/finetune.py#L49)
What are our thoughts on referring to reference implementations in torchtune? Not sure if citing them implies that we, as torchtune, are certifying that repo as a reference we endorse / want to compare against outwardly.
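As a rough illustration of the flag described above (not torchtune's actual implementation): when `train_on_input` is `False`, the prompt positions can be masked out of the loss using the ignore index that `nn.CrossEntropyLoss` skips by default. The function name and token values below are hypothetical.

```python
IGNORE_INDEX = -100  # nn.CrossEntropyLoss ignores this label value by default

def build_labels(prompt_tokens, response_tokens, train_on_input):
    """Sketch of label construction controlled by train_on_input:
    True  -> train on the full sequence (prompt + response),
    False -> mask the prompt so only the response contributes to the loss."""
    if train_on_input:
        return prompt_tokens + response_tokens
    return [IGNORE_INDEX] * len(prompt_tokens) + response_tokens

print(build_labels([1, 2, 3], [4, 5], train_on_input=False))
# [-100, -100, -100, 4, 5]
```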
Context
Our current Alpaca dataset implementation doesn't allow us to train on the input, i.e. to not mask the input during training. Looking at reference implementations, this is pretty common and the only way we can replicate training curves.
The class is also written in a way which doesn't allow the user to easily switch between the different variants of the Alpaca dataset. Using the clean version of the dataset allows the loss to go down faster.
This PR adds both of these features, plus tests for the Alpaca dataset.
Thanks @ebsmothers for helping find some of these issues!
Changelog
- Added the `use_clean` flag
- Added the `train_on_input` flag

Test plan
- `test_alpaca_dataset` succeeded.

Comment on why we're changing the loss values in `test_finetune_llm.py`
The loss changes because we have a small difference in the input and label generation. This change is:
This creates a small difference in the output, which results in a change in the loss: