Add eot token to ICL generate kwargs #2782

Merged
14 commits merged on Dec 20, 2023
4 changes: 3 additions & 1 deletion composer/datasets/in_context_learning_evaluation.py
@@ -306,7 +306,8 @@ def collate_fn(self, data):
     'generation_length': self.max_answer_length,
     'generation_kwargs': {
         'pad_token_id': self.pad_tok_id,
-        'use_cache': True
+        'use_cache': True,
+        'eos_token_id': self.tokenizer.eos_token_id
     }
 }
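For context on what the new key does: these generation_kwargs are ultimately forwarded to a Hugging Face-style generate() call, and an explicit eos_token_id lets decoding stop as soon as the model emits its end-of-text token instead of relying on whatever the model's generation config happens to contain. A minimal sketch, not Composer code; the gpt2 checkpoint, prompt, and max_new_tokens are illustrative assumptions:

```python
# Sketch: how generation_kwargs like the ones above are consumed by a
# Hugging Face causal LM.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token

generation_kwargs = {
    'pad_token_id': tokenizer.pad_token_id,
    'use_cache': True,
    'eos_token_id': tokenizer.eos_token_id,  # the key this PR adds
}

inputs = tokenizer('Q: What is 2 + 2?\nA:', return_tensors='pt')
# With an explicit eos_token_id, decoding halts as soon as the EOS token is
# produced rather than always running for the full max_new_tokens budget.
outputs = model.generate(**inputs, max_new_tokens=20, **generation_kwargs)
print(tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:]))
```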

@@ -1116,6 +1117,7 @@ def collate_fn(self, data):
         'top_p': self.top_p,
         'top_k': self.top_k,
         'use_cache': True,
+        'eos_token_id': self.tokenizer.eos_token_id
     }
 }
 batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
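This second collate_fn builds sampling-style kwargs (top_p, top_k) rather than greedy ones, and eos_token_id plays the same role there. A rough continuation of the sketch above; do_sample=True is an assumption, since the surrounding lines of that dict are not shown in this hunk:

```python
# Sketch: sampling-based generation with an explicit EOS stop. top_p/top_k only
# take effect when do_sample=True, which is assumed here, not shown in the hunk.
sampling_kwargs = {
    'do_sample': True,  # assumption
    'top_p': 0.95,
    'top_k': 40,
    'use_cache': True,
    'eos_token_id': tokenizer.eos_token_id,  # each sequence stops at EOS
}
outputs = model.generate(**inputs, max_new_tokens=64, **sampling_kwargs)
```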
2 changes: 2 additions & 0 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -545,6 +545,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
    assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
    assert batch['mode'] == 'generate'
    # the maximum generation length from the small test data
    assert batch['generation_length'] == maximum_answer_length
    assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

@@ -559,6 +560,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
         for found, expected in zip(batch['labels'], [['David Seville'], ['Skorpio', 'Scorpio']]))
     assert decoded_batch[0].endswith('Q: Who was the man behind The Chipmunks?\nA:')
     assert decoded_batch[1].endswith('Q: What star sign is Jamie Lee Curtis?\nA:')
+    assert 'eos_token_id' in batch['generation_kwargs']


 @pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl'])
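The new test line only checks that the key is present. A stricter variant, sketched here rather than taken from the PR, would also pin the value to the tokenizer's EOS id:

```python
# Hypothetical extension of the assertion added above; `batch` and `tokenizer`
# are the objects already built in test_qa_task_dataloader.
assert 'eos_token_id' in batch['generation_kwargs']
assert batch['generation_kwargs']['eos_token_id'] == tokenizer.eos_token_id
```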