Add custom stopping criteria to ICL generate tasks #2800

Merged · 45 commits · Jan 15, 2024
Changes from 28 commits

Commits
ca4d3c4
add custome gen kwargs and stopping on eos token
bmosaicml Dec 14, 2023
1e39623
modify test
bmosaicml Dec 14, 2023
09af753
modify test
bmosaicml Dec 14, 2023
47f3c91
Merge branch 'dev' into pass_on_custom_generation_kwargs
bmosaicml Dec 15, 2023
a3501e9
finish
bmosaicml Dec 18, 2023
fadce0e
finish
bmosaicml Dec 18, 2023
92157da
finish
bmosaicml Dec 18, 2023
d137bbc
finish
bmosaicml Dec 18, 2023
b25da3f
Merge branch 'dev' into pass_on_custom_generation_kwargs
bmosaicml Dec 18, 2023
909ed63
finish pr
bmosaicml Dec 20, 2023
e263b5b
Merge branch 'pass_on_custom_generation_kwargs' of github.com:bmosaic…
bmosaicml Dec 20, 2023
32d6668
Merge branch 'dev' into pass_on_custom_generation_kwargs
bmosaicml Dec 20, 2023
4ff16b4
implement early stop
bmosaicml Dec 20, 2023
7ee0a72
Merge branch 'pass_on_custom_generation_kwargs' into add_custom_stopp…
bmosaicml Dec 20, 2023
83a60b7
add tesT
bmosaicml Dec 20, 2023
a031772
Merge branch 'pass_on_custom_generation_kwargs' into add_custom_stopp…
bmosaicml Dec 20, 2023
e5943d6
merge update
bmosaicml Dec 20, 2023
67b4685
Merge branch 'add_custom_stopping_criteria' of github.com:bmosaicml/c…
bmosaicml Dec 20, 2023
aa05076
fix bug
bmosaicml Dec 23, 2023
dce4ef0
bug fix
bmosaicml Dec 23, 2023
cb3c69d
add keys
bmosaicml Dec 26, 2023
786c64c
diff split
bmosaicml Dec 26, 2023
7f20954
fix typo
bmosaicml Dec 26, 2023
b7014a3
Merge branch 'dev' into add_custom_stopping_criteria
bmosaicml Dec 26, 2023
4efbc31
fix precommit
bmosaicml Dec 27, 2023
c7bebfd
Merge branch 'add_custom_stopping_criteria' of github.com:bmosaicml/c…
bmosaicml Dec 27, 2023
d15d808
fix precommit
bmosaicml Dec 27, 2023
6b4fbfb
fix precommit
bmosaicml Dec 27, 2023
999e2ad
fix precommit
bmosaicml Dec 27, 2023
c231f8a
fix precommit
bmosaicml Dec 27, 2023
6e38a71
fix precommit
bmosaicml Dec 27, 2023
dfb3c1e
fix conditional import
bmosaicml Jan 8, 2024
4b7a31e
Merge branch 'dev' into add_custom_stopping_criteria
bmosaicml Jan 8, 2024
14e45fa
add nlp metrics
bmosaicml Jan 10, 2024
8d03bb6
Merge branch 'dev' into add_custom_stopping_criteria
bmosaicml Jan 10, 2024
8a5779c
remove code gen changes
bmosaicml Jan 10, 2024
b9405de
Merge branch 'add_custom_stopping_criteria' of github.com:bmosaicml/c…
bmosaicml Jan 10, 2024
b321283
Merge branch 'dev' into add_custom_stopping_criteria
dakinggg Jan 11, 2024
03964f6
fix nits
bmosaicml Jan 12, 2024
4162bd7
Merge branch 'add_custom_stopping_criteria' of github.com:bmosaicml/c…
bmosaicml Jan 12, 2024
a66c979
Merge branch 'dev' into add_custom_stopping_criteria
bmosaicml Jan 12, 2024
1dac72b
Merge branch 'dev' into add_custom_stopping_criteria
bmosaicml Jan 15, 2024
2550bd0
fix union
bmosaicml Jan 15, 2024
0e4c015
fix
bmosaicml Jan 15, 2024
78ecb42
fix
bmosaicml Jan 15, 2024
172 changes: 97 additions & 75 deletions composer/datasets/in_context_learning_evaluation.py
@@ -16,6 +16,7 @@

from composer.core import DataSpec
from composer.core.data_spec import _default_split_batch, _split_list
from composer.datasets.utils import stop_sequences_criteria
from composer.utils import MissingConditionalImportError, dist, get_file

if TYPE_CHECKING:
@@ -139,21 +140,21 @@ def _read_dataset(self, dataset: Dataset) -> List[Dict[str, str]]:
})
return result

def __init__(self,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
destination_path: str,
question_prelimiter: str,
fewshot_random_seed: int,
cot_delimiter: str = '',
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True):
if tokenizer.eos_token_id is None:
raise ValueError('`InContextLearningQATaskDataset` tokenizer must have non-null `eos_token_id`')
try:
@@ -166,6 +167,8 @@ def __init__(
if dist.get_local_rank() == 0:
get_file(dataset_uri, destination_path, overwrite=True)
dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
self.early_stopping_criteria = early_stopping_criteria
self.do_normalization = do_normalization
self.samples = self._read_dataset(dataset)
self.samples = strip_data(self.samples)
self.tokenizer = tokenizer
@@ -299,17 +302,22 @@ def collate_fn(self, data):
# We will search for the answer within the portion of the model response
# beginning with `cot_delimiter`
cot_delimiter = sample['cot_delimiter']

stopping_criteria = None
if self.early_stopping_criteria:
stopping_criteria = stop_sequences_criteria(self.tokenizer, self.early_stopping_criteria, len(inputs))
batch = {
'input_ids': torch.stack(inputs),
'mode': 'generate',
'labels': answers,
'cot_delimiter': cot_delimiter,
'generation_length': self.max_answer_length,
'stopping_criteria': self.early_stopping_criteria,
'do_normalization': self.do_normalization,
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'use_cache': True,
'stopping_criteria': stopping_criteria,
'eos_token_id': self.tokenizer.eos_token_id,
}
}

@@ -323,7 +331,9 @@ def split_batch(self, batch: Any, microbatch_size: int):
# Don't split kwargs that don't change
# Normally split torch tensors
# List split lists of strings
no_split = ['mode', 'generation_length', 'generation_kwargs', 'cot_delimiter']
no_split = [
'mode', 'generation_length', 'generation_kwargs', 'cot_delimiter', 'do_normalization', 'stopping_criteria'
]
normal_split = ['input_ids', 'attention_mask']
list_split = ['labels']
chunked = {}
@@ -339,7 +349,7 @@ def split_batch(self, batch: Any, microbatch_size: int):
raise ValueError(f'Unexpected key {k}')
num_chunks = len(chunked['input_ids'])
for k, v in batch.items():
if k in no_split:
chunked[k] = [v] * num_chunks
return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]
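For intuition, a toy sketch of the splitting contract above; the shapes and values are illustrative, not taken from this diff:

```python
import torch

batch = {
    'input_ids': torch.zeros(4, 16, dtype=torch.long),  # normal_split: chunked along dim 0
    'labels': [['a'], ['b'], ['c'], ['d']],             # list_split: chunked as a list
    'mode': 'generate',                                 # no_split: replicated
    'stopping_criteria': ['\n\n'],                      # no_split: replicated
}
# With microbatch_size=2, split_batch yields two dicts, each holding
# input_ids of shape (2, 16), two labels, and unmodified copies of every
# no_split entry.
```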

@@ -950,6 +960,7 @@ def __init__(
pass_at_k: int = 1,
top_p: Optional[float] = 0.95,
top_k: Optional[int] = 40,
early_stopping_criteria: Optional[List[str]] = None,
):
if tokenizer.eos_token_id is None:
raise ValueError('`InContextLearningCodeEvalDataset` tokenizer must have non-null `eos_token_id`')
@@ -963,6 +974,7 @@
if dist.get_local_rank() == 0:
get_file(dataset_uri, destination_path, overwrite=True)
dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
self.early_stopping_criteria = early_stopping_criteria
self.samples = list(
dataset.map(
lambda examples: {
@@ -1100,6 +1112,9 @@ def collate_fn(self, data):
test_outputs.append(test_output)
languages.append(language)

stopping_criteria = None
if self.early_stopping_criteria:
stopping_criteria = stop_sequences_criteria(self.tokenizer, self.early_stopping_criteria, len(inputs))
batch = {
'input_ids': torch.stack(inputs),
'mode': 'generate',
@@ -1121,6 +1136,7 @@
'top_p': self.top_p,
'top_k': self.top_k,
'use_cache': True,
'stopping_criteria': stopping_criteria,
'eos_token_id': self.tokenizer.eos_token_id
}
}
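The PR leaves the choice of stop strings to the caller. For code evaluation, a common convention in HumanEval-style harnesses (illustrative here, not mandated by this diff) is to halt when the model begins a new top-level block:

```python
# Illustrative stop strings for code completion tasks; any list of strings
# passed as `early_stopping_criteria` is handled the same way.
early_stopping_criteria = ['\nclass', '\ndef', '\n#', '\nprint']
```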
@@ -1161,23 +1177,24 @@ def split_batch(self, batch: Any, microbatch_size: int):


def build_icl_dataloader(
icl_task_type: str,
dataset_uri: str,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
batch_size: int,
max_seq_len: int,
pad_tok_id: int,
num_fewshot: int,
prompt_string: str, # e.g. 'translate english to french:'
example_delimiter: str, # e.g. '\n'
continuation_delimiter: str, # e.g. ''
destination_path: str,
question_prelimiter: str = '', # e.g. 'Question: '
cot_delimiter: str = '',
fewshot_random_seed: int = 1234,
pass_at_k: int = 1,
generations_per_sample: int = 1,
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True) -> DataSpec:
if icl_task_type == 'multiple_choice':
dataset = InContextLearningMultipleChoiceTaskDataset(dataset_uri,
tokenizer,
@@ -1228,7 +1245,9 @@ def build_icl_dataloader(
destination_path=destination_path,
question_prelimiter=question_prelimiter,
fewshot_random_seed=fewshot_random_seed,
cot_delimiter=cot_delimiter,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization)
effective_batchsize = batch_size
elif icl_task_type == 'code_evaluation':
dataset = InContextLearningCodeEvalDataset(dataset_uri,
@@ -1242,7 +1261,8 @@ def build_icl_dataloader(
code_prelimiter=question_prelimiter,
fewshot_random_seed=fewshot_random_seed,
pass_at_k=pass_at_k,
generations_per_sample=generations_per_sample,
early_stopping_criteria=early_stopping_criteria)
effective_batchsize = batch_size
else:
raise Exception(f'Unrecognized ICL task type: {icl_task_type}')
@@ -1328,7 +1348,9 @@ def get_icl_task_dataloader(
pass_at_k: int = 1,
generations_per_sample: int = 1,
cot_delimiter: str = '',
has_categories: bool = False,
early_stopping_criteria: Optional[List[str]] = None,
do_normalization: bool = True) -> Union[DataSpec, Dict[str, DataSpec]]:
"""This constructs a dataloader (or dataloaders if has_categories is True) capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below:

>>> dl = get_icl_task_dataloader(
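The docstring example is truncated by the diff view; a minimal sketch of a call exercising the two new arguments might look as follows. The model name, dataset URI, and stop strings are illustrative assumptions, not part of this PR:

```python
import transformers

from composer.datasets.in_context_learning_evaluation import get_icl_task_dataloader

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')  # any tokenizer with an eos_token_id

dl = get_icl_task_dataloader(
    icl_task_type='question_answering',  # assumed QA task name
    dataset_uri='./qa_task.jsonl',       # illustrative local path
    tokenizer=tokenizer,
    batch_size=8,
    max_seq_len=1024,
    pad_tok_id=tokenizer.eos_token_id,
    num_fewshot=5,
    prompt_string='',
    example_delimiter='\n',
    continuation_delimiter=' ',
    destination_path='./qa_task_local.jsonl',
    # New in this PR: stop generation on any of these strings, and
    # optionally skip answer normalization in the QA accuracy metric.
    early_stopping_criteria=['\n\n', 'Question:'],
    do_normalization=True,
)
```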
@@ -1381,41 +1403,41 @@ def get_icl_task_dataloader(
categories = sorted(output_files.keys())
for category in categories:
partition_uri = output_files[category]
result_dls[category] = build_icl_dataloader(icl_task_type,
partition_uri,
tokenizer,
batch_size,
max_seq_len,
pad_tok_id,
num_fewshot,
prompt_string,
example_delimiter,
continuation_delimiter,
partition_uri + '_tmp',
question_prelimiter,
cot_delimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization)
return result_dls
else:
return build_icl_dataloader(icl_task_type,
dataset_uri,
tokenizer,
batch_size,
max_seq_len,
pad_tok_id,
num_fewshot,
prompt_string,
example_delimiter,
continuation_delimiter,
destination_path,
question_prelimiter,
cot_delimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
early_stopping_criteria=early_stopping_criteria,
do_normalization=do_normalization)
66 changes: 65 additions & 1 deletion composer/datasets/utils.py
@@ -6,7 +6,7 @@
import logging
import textwrap
import warnings
from typing import TYPE_CHECKING, Callable, List, Tuple, Union

import numpy as np
import torch
@@ -15,6 +15,10 @@
from torchvision.datasets import VisionDataset

from composer.core import Batch
from composer.utils import MissingConditionalImportError

if TYPE_CHECKING:
import transformers

__all__ = [
'add_vision_dataset_transform',
@@ -166,3 +170,63 @@ def add_vision_dataset_transform(dataset: VisionDataset, transform: Callable, is
else:
dataset.transform = transforms.Compose([dataset.transform, transform])
log.warning(transform_added_logstring)


try:
import transformers
del transformers # unused
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp',
conda_package='transformers',
conda_channel='conda-forge') from e


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence.
Slightly modified from: https://github.com/EleutherAI/lm-evaluation-harness/blob/78545d42f2ca95c6fe0ed220d456eeb94f4485e9/lm_eval/utils.py#L614-L649
"""

def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
batch_size: int,
) -> None:
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
# we look back for 2 more tokens than it takes to encode our stop sequence
# because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
# and we don't want to mistakenly not stop a generation because our
# (string) stop sequence was output in a different tokenization

# NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
# and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
self.sequence_id_len = len(self.sequence_ids) + 2
self.tokenizer = tokenizer

def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, -self.sequence_id_len:]

lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if i >= len(lookback_tokens_batch):
# the last batch of a dataset may be smaller than `batch_size`
# automatically set those indices in the done_tracker to True
# since those indices don't show up in the current batch
self.done_tracker[i] = True
break
elif not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker


def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList([
*[MultiTokenEOSCriteria(sequence, tokenizer, batch_size) for sequence in stop_sequences],
])
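Because the helper returns a plain `transformers.StoppingCriteriaList`, it can be passed directly to Hugging Face `generate`. A usage sketch with an illustrative model and stop strings:

```python
import transformers

from composer.datasets.utils import stop_sequences_criteria

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')

prompts = ['Q: What is 2 + 2?\nA:']
inputs = tokenizer(prompts, return_tensors='pt')
criteria = stop_sequences_criteria(tokenizer, ['\n\n', 'Q:'], batch_size=len(prompts))

# Generation consults the criteria after every decoding step and halts the
# batch once each sequence has produced one of the stop strings (or EOS).
out = model.generate(**inputs, max_new_tokens=32, stopping_criteria=criteria)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```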
14 changes: 12 additions & 2 deletions composer/metrics/nlp.py
@@ -270,13 +270,23 @@ def update(self, outputs: List[str], labels: List[List[str]], batch: Optional[Di
if batch is None:
batch = {}
cot_delimiter = batch.get('cot_delimiter', '')
do_normalization = batch.get('do_normalization', True)
stopping_criteria = batch.get('stopping_criteria', None)
for sample_output, sample_labels in zip(outputs, labels):
final_answer = sample_output

if stopping_criteria is not None and len(stopping_criteria) > 0:
final_answer = re.split('|'.join(stopping_criteria), final_answer)[0]

if cot_delimiter is not None and len(cot_delimiter) > 0:
final_answer = final_answer.split(cot_delimiter)[-1]

if do_normalization:
cleaned_final_answer = self.normalize_answer(final_answer)
cleaned_sample_labels = {self.normalize_answer(label) for label in sample_labels}
else:
cleaned_final_answer = final_answer
cleaned_sample_labels = set(sample_labels)

if any(cleaned_final_answer.startswith(label) for label in cleaned_sample_labels):
self.correct += torch.tensor(1.0)
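To trace the extraction order above (stop-sequence split, then chain-of-thought split, then optional normalization), consider a toy output. Note that the stop strings are joined into a regex alternation, so strings containing regex metacharacters would need escaping before this split:

```python
import re

sample_output = 'Let me think step by step. ### The answer is 42\n\nQuestion: next'
stopping_criteria = ['\n\n', 'Question:']
cot_delimiter = '###'

final_answer = re.split('|'.join(stopping_criteria), sample_output)[0]
# -> 'Let me think step by step. ### The answer is 42'
final_answer = final_answer.split(cot_delimiter)[-1]
# -> ' The answer is 42'; with do_normalization=True this is lowercased and
# stripped of articles/punctuation before the startswith comparison.
```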