We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hi all, I've recently been trying to run the LLaMA-2-70B model on a single GPU, with a lot of help from this project.
However, I found that the default non_blocking=True setting is very dangerous, for example here:
non_blocking=True
https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L328 https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L332
main code:
tokenizer = LlamaTokenizer.from_pretrained(args.model_dir)
model = LlamaForCausalLM.from_pretrained(
    args.model_dir, low_cpu_mem_usage=True, torch_dtype=DTYPE
).eval()

# Swap the stock decoder for the offloaded one. The wrapper keeps references
# to the original decoder layers, so dropping this local name is safe.
origin_llama_model = model.get_decoder()
model.set_decoder(
    OffloadLlamaModel(origin_llama_model, device=device, num_slices=args.num_slices)
)
del origin_llama_model

# lm_head is not part of the offloaded decoder, so place it explicitly.
# Use `device` (the same device the decoder and the inputs use) rather than
# the implicit default CUDA device that `.cuda()` would pick.
model.lm_head.to(device)

prompt = "Give me some suggestions on how to lose weight."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device=device)
logging.info("Generating response...")
s = time.time()
# Greedy decoding (no sampling, a single beam): the generated text should not
# depend on the num_slices setting.
generate_ids = model.generate(
    input_ids, do_sample=False, num_beams=1, max_length=200
)
The OffloadLlamaModel code:
class DecodeOutput(object):
    """Bundle of the ten values threaded through the offloaded decoder stack.

    Elements, in order: hidden_states, attention_mask, position_ids,
    past_key_values, output_attentions, use_cache, output_hidden_states,
    all_hidden_states, next_decoder_cache, all_self_attns.

    ``cuda()`` / ``cpu()`` move only the elements that expose a callable
    ``cuda`` / ``cpu`` attribute (tensors). ``None``, booleans and tuples
    such as ``past_key_values`` or ``next_decoder_cache`` are left as-is.

    NOTE(review): presumably fairscale's ``OffloadModel`` calls
    ``.cuda()`` / ``.cpu()`` on each activation it stores between shards,
    which is why this object provides them. Confirm against the installed
    fairscale.
    """

    def __init__(self, hidden_states, attention_mask, position_ids, past_key_values,
                 output_attentions, use_cache, output_hidden_states, all_hidden_states,
                 next_decoder_cache, all_self_attns):
        super().__init__()
        self.elements = [
            hidden_states,
            attention_mask,
            position_ids,
            past_key_values,
            output_attentions,
            use_cache,
            output_hidden_states,
            all_hidden_states,
            next_decoder_cache,
            all_self_attns,
        ]

    def items(self):
        """Return the underlying element list in constructor order."""
        return self.elements

    def _moved(self, method_name):
        """Return a new element list with ``item.<method_name>()`` applied where callable."""
        return [
            getattr(item, method_name)() if callable(getattr(item, method_name, None)) else item
            for item in self.elements
        ]

    def cuda(self):
        """Move movable elements with ``.cuda()`` (replacing the list) and return ``self``."""
        self.elements = self._moved('cuda')
        return self

    def cpu(self):
        """Move movable elements with ``.cpu()`` (replacing the list) and return ``self``."""
        self.elements = self._moved('cpu')
        return self

    def __str__(self):
        return "DecodeOutput(" + str(self.elements) + ")"

    def __getitem__(self, index: int):
        return self.elements[index]


class WrappedLlamaDecoderLayer(nn.Module):
    """Adapt one ``LlamaDecoderLayer`` to take and return a single ``DecodeOutput``.

    The layers are chained in an ``nn.Sequential``, so each layer's output is
    the next layer's only input. The bookkeeping that ``LlamaModel.forward``
    normally keeps in local variables (hidden-state history, new KV cache,
    attention maps) therefore travels inside the ``DecodeOutput``.

    Inference only: the gradient-checkpointing branch of the original layer
    loop was removed.
    """

    def __init__(self, index: int, decoder: LlamaDecoderLayer):
        super().__init__()
        # Position of this layer in the stack; selects its past_key_values entry.
        self.idx = index
        self.decoder = decoder

    def forward(self, inputs: DecodeOutput):
        """Run the wrapped layer and return an updated ``DecodeOutput``."""
        (hidden_states, attention_mask, position_ids, past_key_values,
         output_attentions, use_cache, output_hidden_states, all_hidden_states,
         next_decoder_cache, all_self_attns) = inputs.items()

        # Record the input to this layer, as LlamaModel.forward does.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[self.idx] if past_key_values is not None else None

        layer_outputs = self.decoder(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = layer_outputs[0]

        if use_cache:
            # The present key/value comes after the attention weights when those are returned.
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
        if output_attentions:
            all_self_attns += (layer_outputs[1],)

        return DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values,
            output_attentions, use_cache, output_hidden_states, all_hidden_states,
            next_decoder_cache, all_self_attns,
        )


class OffloadLlamaModel(nn.Module):
    """``LlamaModel`` replacement whose decoder layers run through fairscale's ``OffloadModel``.

    The embedding and final norm stay on ``device``. The decoder layers are
    grouped into ``num_slices`` shards that ``OffloadModel`` moves between
    ``offload_device`` and ``device``.

    Args:
        llama_model: the model whose config, embedding, layers and norm are reused.
        device: compute device for the embedding, norm and active shard.
        offload_device: where inactive shards are kept.
        num_slices, checkpoint_activation, num_microbatches: forwarded to ``OffloadModel``.
        blocking_transfers: when true (default), force shard parameter
            transfers to use ``non_blocking=False``. With fairscale's default
            asynchronous transfers, outputs diverged from the un-offloaded
            model in the last layers of each shard and depended on
            ``num_slices``; blocking transfers removed the divergence.
    """

    def __init__(self, llama_model: LlamaModel, device=torch.device('cuda'),
                 offload_device=torch.device("cpu"), num_slices=3,
                 checkpoint_activation=False, num_microbatches=1,
                 blocking_transfers: bool = True):
        logging.info("OffloadLlamaModel Initializing.")
        super().__init__()
        self.config = llama_model.config
        self.padding_idx = llama_model.padding_idx
        self.vocab_size = llama_model.vocab_size
        # Honour the `device` argument instead of the implicit default CUDA device.
        self.embed_tokens = llama_model.embed_tokens.to(device)

        logging.info("Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders.")
        _sequential = nn.Sequential()
        for idx, decoder in enumerate(llama_model.layers):
            _sequential.add_module("layer_%d" % idx, WrappedLlamaDecoderLayer(idx, decoder))
        self.layers = OffloadModel(
            model=_sequential,
            device=device,
            offload_device=offload_device,
            num_slices=num_slices,
            checkpoint_activation=checkpoint_activation,
            num_microbatches=num_microbatches,
        )
        if blocking_transfers:
            self._force_blocking_transfers(self.layers)

        for sid, slc in enumerate(self.layers.model_slices):
            logging.debug(
                f"Shard {sid:d} holds WrappedLlamaDecodeLayer [{','.join(str(m.idx) for m in slc.model_shard)}]"
            )
        self.norm = llama_model.norm.to(device)

    @staticmethod
    def _force_blocking_transfers(offload_model):
        """Make every shard of ``offload_model`` move its parameters with ``non_blocking=False``.

        Each existing transfer method is wrapped on the shard instance so the
        ``non_blocking`` keyword is always overridden, even if the caller
        passes it explicitly.

        NOTE(review): presumably the asynchronous copies let a shard's layers
        run before all of their weights had arrived. This assumes fairscale's
        ``ModelShard`` exposes these methods with a ``non_blocking`` keyword;
        confirm against the installed fairscale version. Missing methods are
        skipped.
        """
        import functools

        def _blocking(transfer):
            @functools.wraps(transfer)
            def wrapper(*args, **kwargs):
                kwargs["non_blocking"] = False
                return transfer(*args, **kwargs)
            return wrapper

        for shard in offload_model.model_slices:
            for name in ("forward_load", "forward_drop", "backward_load", "backward_drop"):
                transfer = getattr(shard, name, None)
                if callable(transfer):
                    setattr(shard, name, _blocking(transfer))

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine the causal mask (for multi-token inputs) with the expanded padding mask."""
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embedding, the offloaded decoder layers and the final norm.

        Flags left as ``None`` fall back to ``self.config``. Returns a
        ``BaseModelOutputWithPast``. When ``return_dict`` is false, it instead
        returns a tuple of the non-``None`` values among (last hidden state,
        cache, all hidden states, attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            # Cached length is read from dim 2 of the first layer's cached key.
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        inputs = DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values,
            output_attentions, use_cache, output_hidden_states, all_hidden_states,
            next_decoder_cache, all_self_attns,
        )
        # Call the module (not .forward) so nn.Module hooks still run.
        layer_outputs = self.layers(inputs)
        (hidden_states, _attention_mask, _position_ids, _past_key_values,
         _output_attentions, _use_cache, _output_hidden_states, all_hidden_states,
         next_decoder_cache, all_self_attns) = layer_outputs.items()

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
I found that the model generated different responses with different num_slices settings, even with a fixed random seed.
num_slices
The pairwise_distance between each decoder layer's outputs in the original model and in the offloaded model was:
2023-10-27 17:46:28,544 - INFO: Loading LLaMA model and tokenizer. You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565 Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:17<00:00, 5.98s/it] 2023-10-27 17:46:47,599 - INFO: Running model natively. 2023-10-27 17:47:15,826 - INFO: Running OffloadModel using given num_slices setting. 2023-10-27 17:47:15,826 - INFO: OffloadLlamaModel Initializing. 2023-10-27 17:47:16,049 - INFO: Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders. 2023-10-27 17:47:16,052 - INFO: This model has 12688.18M parameters, aiming for 6344.09M parameters per shard 2023-10-27 17:47:39,404 - INFO: Shard 0 holds 6344.09M parameters 2023-10-27 17:47:39,405 - INFO: Shard 1 holds 6344.09M parameters 2023-10-27 17:47:39,412 - DEBUG: Shard 0 holds WrappedLlamaDecodeLayer [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19] 2023-10-27 17:47:39,412 - DEBUG: Shard 1 holds WrappedLlamaDecodeLayer [20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39] Embedding is same: True RMSNorm is same: True Checking layers: Layer 00 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 01 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 02 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 03 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 04 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 05 obj is same: True, attention diff: 0.0000, hidden_state diff: 
0.0001 Layer 06 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 07 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 08 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 09 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 10 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 11 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 12 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 13 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 14 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 15 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 16 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 17 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001 Layer 18 obj is same: True, attention diff: 0.7080, hidden_state diff: 0.0001 Layer 19 obj is same: True, attention diff: 0.7104, hidden_state diff: 13.1719 Layer 20 obj is same: True, attention diff: 0.0465, hidden_state diff: 19.6094 Layer 21 obj is same: True, attention diff: 0.0314, hidden_state diff: 19.7031 Layer 22 obj is same: True, attention diff: 0.0166, hidden_state diff: 19.9062 Layer 23 obj is same: True, attention diff: 0.0146, hidden_state diff: 20.4062 Layer 24 obj is same: True, attention diff: 0.0164, hidden_state diff: 20.8125 Layer 25 obj is same: True, attention diff: 0.0153, hidden_state diff: 21.2500 Layer 26 obj is same: True, attention diff: 0.0143, hidden_state diff: 21.8281 Layer 27 obj is same: True, attention diff: 0.0151, hidden_state diff: 22.4375 Layer 28 obj is same: True, attention diff: 0.0112, hidden_state diff: 22.9844 Layer 29 obj is same: True, attention diff: 0.0150, hidden_state diff: 23.4531 Layer 30 obj is same: True, attention diff: 0.0098, hidden_state diff: 24.0781 Layer 31 obj is same: True, 
attention diff: 0.0129, hidden_state diff: 24.6562 Layer 32 obj is same: True, attention diff: 0.0098, hidden_state diff: 25.2656 Layer 33 obj is same: True, attention diff: 0.0164, hidden_state diff: 25.8750 Layer 34 obj is same: True, attention diff: 0.0106, hidden_state diff: 26.5000 Layer 35 obj is same: True, attention diff: 0.0133, hidden_state diff: 27.2188 Layer 36 obj is same: True, attention diff: 0.0166, hidden_state diff: 28.0000 Layer 37 obj is same: True, attention diff: 0.0179, hidden_state diff: 28.9688 Layer 38 obj is same: True, attention diff: 0.7056, hidden_state diff: 30.0312 Layer 39 obj is same: True, attention diff: 0.6108, hidden_state diff: 83.8750 ['<s>me a examples for how to improve weight and\n'] ['<s>me a examples on how to improve weight fast I']
Once I manually set non_blocking=False, all of the above differences disappeared.
non_blocking=False
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Hi all, I've recently been trying to run the LLaMA-2-70B model on a single GPU, with a lot of help from this project.
However, I found that it is very dangerous to use the default
non_blocking=True
setting, for example: https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L328
https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L332
main code:
The OffloadLlamaModel code:
I found that the model generated different responses when using different
num_slices
settings, even with a fixed random seed. The pairwise_distance between each decoder layer's outputs in the original model and in the offloaded model was:
Once I manually set
non_blocking=False
, all of the above differences disappeared. The text was updated successfully, but these errors were encountered: