Skip to content

Commit

Permalink
Add eot token to ICL generate kwargs (#2782)
Browse files Browse the repository at this point in the history
* add custom gen kwargs and stopping on eos token

* modify test

* modify test

* finish

* finish

* finish

* finish
  • Loading branch information
bmosaicml authored Dec 20, 2023
1 parent 96df92d commit a8a261b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
8 changes: 7 additions & 1 deletion composer/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def __init__(
fewshot_random_seed: int,
cot_delimiter: str = '',
):
if tokenizer.eos_token_id is None:
raise ValueError('`InContextLearningQATaskDataset` tokenizer must have non-null `eos_token_id`')
try:
from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues]
except ImportError as e:
Expand Down Expand Up @@ -306,7 +308,8 @@ def collate_fn(self, data):
'generation_length': self.max_answer_length,
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'use_cache': True
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id
}
}

Expand Down Expand Up @@ -948,6 +951,8 @@ def __init__(
top_p: Optional[float] = 0.95,
top_k: Optional[int] = 40,
):
if tokenizer.eos_token_id is None:
raise ValueError('`InContextLearningCodeEvalDataset` tokenizer must have non-null `eos_token_id`')
try:
from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues]
except ImportError as e:
Expand Down Expand Up @@ -1116,6 +1121,7 @@ def collate_fn(self, data):
'top_p': self.top_p,
'top_k': self.top_k,
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id
}
}
batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
Expand Down
30 changes: 30 additions & 0 deletions tests/datasets/test_in_context_learning_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,34 @@ def test_qa_split_batch(tiny_opt_tokenizer, dataset_uri, tmp_path):
assert isinstance(split2['generation_kwargs'], dict)


@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0])
@pytest.mark.parametrize('prompt_string', ['I am a prompt', ''])
def test_qa_task_dataloader_w_null_eos(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fewshot, prompt_string):
    """Constructing a QA ICL dataloader with a tokenizer lacking ``eos_token_id`` must raise.

    The QA dataset adds ``eos_token_id`` to its generation kwargs, so the
    dataset constructor validates that the tokenizer has a non-null EOS id
    and raises ``ValueError`` otherwise.
    """
    pytest.importorskip('datasets')

    local_data = os.path.join(os.path.dirname(__file__), 'local_data')

    tokenizer = tiny_gpt2_tokenizer
    dataset_uri = f'{local_data}/{dataset_uri}'
    batch_size = 4
    seqlen = 512
    # Capture the pad id BEFORE nulling the EOS id: previously the call below
    # read `tokenizer.eos_token_id` after the mutation, silently passing
    # pad_tok_id=None instead of the tokenizer's real pad/EOS id.
    pad_tok_id = tokenizer.eos_token_id
    # Simulate a tokenizer with no EOS token; same object as `tokenizer`.
    tiny_gpt2_tokenizer.eos_token_id = None
    with pytest.raises(ValueError):
        _ = get_icl_task_dataloader('question_answering',
                                    dataset_uri,
                                    tokenizer,
                                    batch_size,
                                    max_seq_len=seqlen,
                                    pad_tok_id=pad_tok_id,
                                    num_fewshot=num_fewshot,
                                    prompt_string=prompt_string,
                                    example_delimiter='\n',
                                    question_prelimiter='Q: ',
                                    continuation_delimiter='\nA:',
                                    destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'))


@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 2])
@pytest.mark.parametrize('prompt_string', ['I am a prompt', ''])
Expand Down Expand Up @@ -545,6 +573,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data

assert batch['generation_length'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

Expand All @@ -559,6 +588,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
for found, expected in zip(batch['labels'], [['David Seville'], ['Skorpio', 'Scorpio']]))
assert decoded_batch[0].endswith('Q: Who was the man behind The Chipmunks?\nA:')
assert decoded_batch[1].endswith('Q: What star sign is Jamie Lee Curtis?\nA:')
assert 'eos_token_id' in batch['generation_kwargs']


@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl'])
Expand Down

0 comments on commit a8a261b

Please sign in to comment.