-
Notifications
You must be signed in to change notification settings - Fork 576
/
speculative_sample.py
486 lines (433 loc) · 19.9 KB
/
speculative_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, PreTrainedModel
from transformers import (
LogitsProcessorList,
StoppingCriteriaList,
)
from transformers.generation.streamers import BaseStreamer
import torch
from typing import Tuple, List, Optional
import copy
def norm_logits(
    x: torch.Tensor,
    logits: torch.Tensor,
    logits_processor: "LogitsProcessorList",
    logits_warper: "LogitsProcessorList",
    do_sample: bool = False,
    cur_len=None,
) -> torch.Tensor:
    """Apply logits processors/warpers position by position and return softmax probabilities.

    Position ``i`` of `logits` is processed with the context ``x[:, :cur_len + i]``,
    so `logits` must hold ``x.shape[1] - cur_len + 1`` positions.

    Args:
        x (`torch.Tensor`): input ids, shape (batch, seqlen).
        logits (`torch.Tensor`): shape (batch, n_positions, vocab). It is NOT modified;
            a copy is processed.
        logits_processor (`LogitsProcessorList`): processors applied to every position
            (skipped when empty).
        logits_warper (`LogitsProcessorList`): warpers applied after the processors,
            only when `do_sample` is True (skipped when empty).
        do_sample (`bool`): whether sampling is enabled.
        cur_len (`int`, *optional*): length of the context seen by the first position.
            Defaults to ``x.shape[1] - logits.shape[1] + 1``, i.e. the last position of
            `logits` sees the whole of `x`.
    Returns:
        `torch.Tensor`: probabilities with the same shape as `logits` (batch, n_positions, vocab).
    """
    if cur_len is None:
        cur_len = x.shape[1] - logits.shape[1] + 1
    # Work on a copy: the per-position slice assignments below would otherwise
    # write straight into the caller's tensor (logits[:, :] is only a view).
    new_logits = logits.clone()
    n_positions = x.shape[1] - cur_len + 1
    if len(logits_processor) > 0:
        for i in range(n_positions):
            new_logits[:, i, :] = logits_processor(x[:, :cur_len + i], new_logits[:, i, :])
    if do_sample and len(logits_warper) > 0:
        for i in range(n_positions):
            new_logits[:, i, :] = logits_warper(x[:, :cur_len + i], new_logits[:, i, :])
    return new_logits.softmax(dim=-1)
def sample(probs : torch.Tensor, do_sample : bool = False, num_samples: int = 1):
if do_sample:
new_token = torch.multinomial(probs, num_samples=num_samples)
else:
new_token = torch.argmax(probs, keepdim=True)
return new_token
def max_fn(x):
    """Return norm(max(x, 0)): clamp non-positive entries to zero, then rescale each row along dim 1 to sum to one."""
    # Keep strictly positive entries; everything else (including NaN) becomes 0.
    positive_part = x.masked_fill(~(x > 0), 0)
    row_totals = positive_part.sum(dim=1, keepdim=True)
    return positive_part / row_totals
def _draft_model_serial_forward(
    prefix: torch.Tensor,
    draft_k: int,
    draft_model: torch.nn.Module,
    logits_processor,
    logits_warper,
    do_sample=False,
    past_key_values=None,
    rejected=False,
    eos_token_id_tensor=None,
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]], torch.Tensor or bool]:
    """Run the draft model autoregressively for up to `draft_k` steps.

    Args:
        prefix (`torch.Tensor`): the current (verified) input ids.
        draft_k (`int`): maximum number of tokens to draft in this call.
        draft_model (`torch.nn.Module`): the draft model.
        logits_processor (`LogitsProcessorList`): processors passed to `norm_logits`.
        logits_warper (`LogitsProcessorList`): warpers passed to `norm_logits` (used only when sampling).
        do_sample (`bool`): whether to sample or take the argmax.
        past_key_values: draft KV cache from the previous call, or None on the first call.
        rejected (`bool`): whether a draft token was rejected in the previous round.
        eos_token_id_tensor (`torch.Tensor`, *optional*): EOS token ids.
    Returns:
        x: `prefix` (moved to the draft model's device) followed by the drafted tokens.
            If a drafted token is an EOS id, drafting stops and that token is NOT appended.
        probs: draft distributions produced in this call, one position per predicted token,
            or None if none were produced. On a cache-less call the leading positions are
            the raw logits for the earlier prefix positions, not probabilities.
        past_key_values: the updated draft KV cache.
        last_token_is_eos: whether drafting stopped on an EOS token (False when no EOS ids are given).
    """
    x = prefix.to(draft_model.device)
    input_ids = x
    probs = None
    # Bound up-front so the return statement never hits an unbound name when
    # draft_k <= 0 and the loop below does not execute.
    last_token_is_eos = False

    if past_key_values is not None:
        if not rejected:
            # Every draft token of the previous round was accepted. The cache is
            # expected to stop before x[:, -2] (the last drafted token was never fed),
            # so feed it now; its logits give the draft distribution for x[:, -1],
            # the token the target model appended.
            output = draft_model(input_ids[:, -2:-1], past_key_values=past_key_values, use_cache=True)
            past_key_values = output.past_key_values
            input_ids = input_ids[:, -1:]
            probs = norm_logits(x[:, :-1], output.logits, logits_processor, logits_warper, do_sample, x.shape[1] - 1)
        else:
            # After a rejection the caller crops the cache to all of x except its
            # last token, so only that token needs to be fed.
            input_ids = input_ids[:, -1:]

    for _ in range(draft_k):
        output = draft_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_probs = norm_logits(x, output.logits[:, -1:], logits_processor, logits_warper, do_sample, x.shape[1])
        next_tok = sample(new_probs[:, -1, :], do_sample=do_sample)

        if eos_token_id_tensor is not None:
            # True when next_tok equals any of the EOS ids.
            last_token_is_eos = next_tok.tile(eos_token_id_tensor.shape[0], 1)
            last_token_is_eos = (
                ~last_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
            )
            if last_token_is_eos:
                # Stop before committing this step: neither x, probs nor the
                # returned cache are advanced with the EOS token.
                break
        else:
            last_token_is_eos = False

        past_key_values = output.past_key_values
        if probs is None:
            # Nothing collected yet: keep output.logits[:, :-1] in front (raw logits of
            # earlier input positions on a cache-less call; empty when one token was fed).
            probs = torch.cat((output.logits[:, :-1], new_probs), dim=1)
        else:
            probs = torch.cat((probs, new_probs), dim=1)
        input_ids = next_tok
        x = torch.cat((x, next_tok), dim=1)
    return x, probs, past_key_values, last_token_is_eos
def _crop_past_key_values(past_key_values, length, bloom_layout=False):
    """Keep only the first `length` cached positions of every layer's (key, value) pair.

    Args:
        past_key_values: iterable of per-layer ``(key, value)`` tensor pairs.
        length (`int`): number of leading sequence positions to keep.
        bloom_layout (`bool`): if True, the sequence axis is dim 2 for keys and dim 1
            for values (BLOOM-style cache); otherwise it is dim 2 for both.
    Returns:
        list of cropped ``(key, value)`` tuples.
    """
    if bloom_layout:
        return [(key[:, :, :length], value[:, :length, :]) for key, value in past_key_values]
    return [(key[:, :, :length, :], value[:, :, :length, :]) for key, value in past_key_values]


def _speculative_sampling(
    prefix: torch.Tensor,
    target_model: torch.nn.Module,
    draft_model: torch.nn.Module,
    max_new_tokens: int,
    draft_k: int = 4,
    logits_processor: LogitsProcessorList = None,
    logits_warper: LogitsProcessorList = None,
    do_sample=False,
    eos_token_id=None,
    stopping_criteria=None,
    streamer: Optional["BaseStreamer"] = None,
) -> torch.Tensor:
    """
    DeepMind version Speculative Sampling.
    Accelerating Large Language Model Decoding with Speculative Sampling
    https://arxiv.org/abs/2302.01318

    Each round the draft model proposes up to `draft_k` tokens and the target model
    scores them in a single forward pass. When sampling, a drafted token is kept with
    probability min(1, p_target / p_draft), and the first rejected one is replaced by a
    sample from norm(max(p_target - p_draft, 0)). When not sampling, a drafted token is
    kept iff it equals the target argmax, and a rejected one is replaced by that argmax,
    so the output matches greedy decoding of the target. If every verified token is
    kept, one extra token is taken from the target distribution.

    Args:
        prefix (torch.Tensor): input sequence, (batch, prefix_seqlen). Only batch size 1 is supported.
        target_model (torch.nn.Module): target model, the large one.
        draft_model (torch.nn.Module): draft model, the small one.
        max_new_tokens (int): maximum number of tokens to generate.
        draft_k (int): number of tokens drafted per round; <= 0 enables adaptive mode
            (start at 4, +2 after a fully accepted round, -1 (min 1) after a rejection).
        logits_processor (`LogitsProcessorList`, *optional*): processors applied to both models' logits.
        logits_warper (`LogitsProcessorList`, *optional*): warpers applied when sampling.
        do_sample (`bool`): whether to sample.
        eos_token_id (`int` or list of `int`, *optional*): EOS token id(s); generation stops after one is produced.
        stopping_criteria (`StoppingCriteriaList`, *optional*): called with (prefix, target_probs) after each round.
        streamer (`BaseStreamer`, *optional*): receives the prompt once via `put`, then each
            round's newly committed tokens, and finally `end()`.
    Returns:
        torch.Tensor: generated tokens (1, target_seqlen), at most prefix_seqlen + max_new_tokens long.
    Raises:
        ValueError: if the batch size is not 1.
    """
    if prefix.shape[0] != 1:
        raise ValueError("input batch size must be 1")
    input_seq_len = prefix.shape[1]
    T = input_seq_len + max_new_tokens

    adaptive_k = draft_k <= 0
    if adaptive_k:
        draft_k = 4

    draft_past_key_values = None
    draft_probs = None
    target_past_key_values = None
    target_probs = None
    rejected = False
    unfinished_sequences = prefix.new(prefix.shape[0]).fill_(1)
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(prefix.device) if eos_token_id is not None else None
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    if streamer is not None:
        # Like `generate()`, hand the prompt to the streamer once up-front.
        streamer.put(prefix.cpu())

    while prefix.shape[1] < T:
        prefix_len = prefix.shape[1]

        # 1) Draft: propose tokens after `prefix`.
        x, new_draft_probs, draft_past_key_values, _ = _draft_model_serial_forward(
            prefix,
            draft_k,
            draft_model,
            logits_processor,
            logits_warper,
            do_sample,
            draft_past_key_values,
            rejected,
            eos_token_id_tensor,
        )
        if new_draft_probs is not None:
            draft_probs = new_draft_probs if draft_probs is None else torch.cat((draft_probs, new_draft_probs), dim=1)

        # 2) Score: one target forward over every position not yet in its cache.
        if target_past_key_values is not None:
            unchecked_token_count = x.shape[1] - target_probs.shape[1] - 1
            outputs = target_model(
                x[:, -(unchecked_token_count + 1):], past_key_values=target_past_key_values, use_cache=True
            )
        else:
            unchecked_token_count = x.shape[1] - prefix_len
            outputs = target_model(x, use_cache=True)
        new_target_probs = norm_logits(
            x, outputs.logits[:, -(unchecked_token_count + 1):], logits_processor, logits_warper, do_sample, prefix_len
        )
        if target_probs is None:
            # First round: earlier prompt positions keep their raw logits; the
            # verification below only reads positions >= prefix_len - 1.
            target_probs = torch.cat((outputs.logits[:, :-(unchecked_token_count + 1)], new_target_probs), dim=1)
        else:
            target_probs = torch.cat((target_probs, new_target_probs), dim=1)
        target_past_key_values = outputs.past_key_values

        # 3) Verify. n_valid is the length of the verified prefix of x. At most
        # T - 1 - prefix_len tokens are checked so that, after the correction/bonus
        # token is appended, the sequence never exceeds T and that token is always
        # drawn from the distribution at the position where it is placed.
        n_checkable = min(unchecked_token_count, T - 1 - prefix_len)
        is_all_accept = True
        n_valid = prefix_len
        for i in range(n_checkable):
            # probs position p holds the distribution of token p + 1.
            cur_pos = prefix_len + i - 1
            cur_token_id = x[:, prefix_len + i]
            target_dist = target_probs[:, cur_pos, :]
            if do_sample:
                r = torch.rand(1, device=target_probs.device)
                ratio = target_probs[:, cur_pos, cur_token_id] / draft_probs[:, cur_pos, cur_token_id]
                accepted = r < torch.clamp(ratio, max=1.0)
            else:
                accepted = torch.argmax(target_dist, dim=-1) == cur_token_id
            if accepted:
                n_valid += 1
                continue
            if do_sample:
                target_new_token = sample(max_fn(target_dist - draft_probs[:, cur_pos, :]), do_sample=True)
            else:
                target_new_token = torch.argmax(target_dist, dim=-1, keepdim=True)
            is_all_accept = False
            rejected = True
            break

        prefix = x[:, :n_valid]
        if is_all_accept:
            # Bonus token from the target distribution right after the last kept token.
            target_new_token = sample(target_probs[:, n_valid - 1, :], do_sample=do_sample)
            rejected = False
        else:
            # Drop everything computed past the verified prefix.
            draft_probs = draft_probs[:, :n_valid, :]
            target_probs = target_probs[:, :n_valid, :]
            architectures = draft_model.config.architectures
            bloom_layout = "bloom" in draft_model.__class__.__name__.lower() or bool(
                architectures and "bloom" in architectures[0].lower()
            )
            # NOTE(review): the layout is detected on the draft model only; this assumes
            # the target model shares the same cache layout.
            draft_past_key_values = _crop_past_key_values(draft_past_key_values, n_valid, bloom_layout)
            target_past_key_values = _crop_past_key_values(target_past_key_values, n_valid, bloom_layout)

        if adaptive_k:
            draft_k = draft_k + 2 if is_all_accept else max(1, draft_k - 1)

        prefix = torch.cat((prefix, target_new_token.to(prefix.device)), dim=1)
        if streamer is not None:
            # Only the tokens committed in this round.
            streamer.put(prefix[:, prefix_len:].cpu())
        if stopping_criteria(prefix, target_probs):
            break
        if eos_token_id_tensor is not None:
            # The draft never appends EOS, so only the last token can be one.
            unfinished_sequences = unfinished_sequences.mul(
                prefix[:, -1]
                .tile(eos_token_id_tensor.shape[0], 1)
                .ne(eos_token_id_tensor.unsqueeze(1))
                .prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                break

    if streamer is not None:
        streamer.end()
    return prefix
@torch.no_grad()
def speculative_sample(
    input_ids,
    target_model: Optional["PreTrainedModel"],
    draft_model: Optional["PreTrainedModel"],
    generation_config: GenerationConfig,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    draft_k: int = 4,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    """Generate from `target_model`, using `draft_model` for speculative sampling.

    Prepares lengths, logits processors/warpers and stopping criteria from
    `generation_config` through the target model's generation helpers, then runs
    `_speculative_sampling`. The whole call runs under `torch.no_grad()`, so the
    model forwards and their KV caches do not record an autograd graph.

    Args:
        input_ids (`torch.Tensor`): prompt token ids; `_speculative_sampling` requires batch size 1.
        target_model: the large model whose distribution is sampled.
        draft_model: the small model that proposes tokens.
        generation_config (`GenerationConfig`): deep-copied, then updated with `**kwargs`.
        logits_processor (`LogitsProcessorList`, *optional*): extra processors merged with those
            built from the config.
        stopping_criteria (`StoppingCriteriaList`, *optional*): extra criteria merged with those
            built from the config.
        draft_k (`int`): tokens drafted per round; <= 0 enables adaptive drafting.
        negative_prompt_ids / negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            forwarded to `_get_logits_processor` when the installed transformers accepts them.
        streamer (`BaseStreamer`, *optional*): receives the generated tokens.
        **kwargs: generation-config overrides; unused ones are treated as model kwargs.
    Returns:
        `torch.Tensor`: the prompt followed by the generated tokens.
    Raises:
        ValueError: if `min_length` exceeds the effective `max_length`.
    """
    import warnings

    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
    generation_config.validate()
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    inputs_tensor, _, model_kwargs = target_model._prepare_model_inputs(
        input_ids, generation_config.bos_token_id, model_kwargs
    )
    model_kwargs["use_cache"] = generation_config.use_cache

    input_ids_seq_length = input_ids.shape[-1]
    if generation_config.max_new_tokens is not None:
        # `max_new_tokens` takes precedence over `max_length`.
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
    if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
        raise ValueError(
            f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
            f" the maximum length ({generation_config.max_length})"
        )
    if input_ids_seq_length >= generation_config.max_length:
        warnings.warn(
            f"Input length of input_ids is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`.",
            UserWarning,
        )
    # Equals `max_new_tokens` when that was given; otherwise it is derived from
    # `max_length` (which may be the config default) instead of passing None on.
    max_new_tokens = generation_config.max_length - input_ids_seq_length

    # prepare logits_processor, stopping_criteria, logits_warper
    try:
        logits_processor = target_model._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=None,
            logits_processor=logits_processor,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
    except TypeError:
        # Please install the latest transformers (commit equal or later than d533465) to enable CFG sampling.
        logits_processor = target_model._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=None,
            logits_processor=logits_processor,
        )
    stopping_criteria = target_model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = target_model._get_logits_warper(generation_config) if generation_config.do_sample else None

    return _speculative_sampling(
        prefix=input_ids,
        target_model=target_model,
        draft_model=draft_model,
        max_new_tokens=max_new_tokens,
        draft_k=draft_k,
        logits_processor=logits_processor,
        logits_warper=logits_warper,
        do_sample=generation_config.do_sample,
        eos_token_id=generation_config.eos_token_id,
        stopping_criteria=stopping_criteria,
        streamer=streamer,
    )
if __name__ == "__main__":
    # A usage example: speculative sampling with a Llama-2-chat style prompt.
    draft_model_name = 'Draft/Model/Path'
    target_model_name = 'Target/Model/Path'
    DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""
    TEMPLATE = (
        "[INST] <<SYS>>\n"
        "{system_prompt}\n"
        "<</SYS>>\n\n"
        "{instruction} [/INST]"
    )

    def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """Fill TEMPLATE with the given instruction and system prompt."""
        return TEMPLATE.format_map({'instruction': instruction, 'system_prompt': system_prompt})

    inputs = ["我能用lightning数据线给安卓手机充电吗?"]
    # Prompt with a different system prompt, used only for the optional CFG
    # (guidance_scale / negative_prompt_*) arguments commented out below.
    negative_text = generate_prompt(inputs[0], system_prompt="回复尽可能多的内容。")
    inputs = [generate_prompt(text) for text in inputs]

    tokenizer = AutoTokenizer.from_pretrained(target_model_name)
    print("begin loading models")
    draft_model = AutoModelForCausalLM.from_pretrained(
        draft_model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map='auto',
        load_in_8bit=False
    )
    # The draft model must share the target tokenizer's vocabulary size.
    draft_model.resize_token_embeddings(len(tokenizer))
    print(f"Load {draft_model_name}")
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map='auto',
        load_in_8bit=False
    )
    print(f"Load {target_model_name}")
    draft_model.eval()
    target_model.eval()
    print("finish loading models")

    torch_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    input_ids = tokenizer.encode(inputs[0], return_tensors='pt').to(torch_device)
    negative_inputs = tokenizer(negative_text, return_tensors="pt")
    negative_prompt_ids = negative_inputs["input_ids"].to(torch_device)
    negative_prompt_attention_mask = negative_inputs["attention_mask"].to(torch_device)

    generation_config = GenerationConfig(
        temperature=0.2,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        num_beams=1,
        repetition_penalty=1.1,
        max_new_tokens=128
    )
    # Inference only: don't let the model forwards record an autograd graph.
    with torch.no_grad():
        outputs = speculative_sample(
            input_ids=input_ids,
            target_model=target_model,
            draft_model=draft_model,
            generation_config=generation_config,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            # draft_k=4,
            # guidance_scale=1.5,
            # negative_prompt_ids=negative_prompt_ids,
            # negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
    # outputs[0] holds the prompt followed by the generated tokens.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(generated_text)