[Bug]: The final reason why you will get a model that cannot stop generation when you fine-tune the Qwen2.5-7b-base use Lora and a non-<|endoftext|> token as eos_token. #1064

Open
4 tasks done
hxs91 opened this issue Nov 8, 2024 · 9 comments

Comments


hxs91 commented Nov 8, 2024

Model Series

Qwen2.5

What are the models used?

Qwen2.5-7b-base

What is the scenario where the problem happened?

sft with huggingface trainer

Is this a known issue?

  • I have followed the GitHub README.
  • I have checked the Qwen documentation and cannot find an answer there.
  • I have checked the documentation of the related framework and cannot find useful information.
  • I have searched the issues and there is not a similar one.

Information about environment

Doesn't matter.

Log output

Doesn't matter.

Description

Steps to reproduce

  1. Use Qwen2.5-7b-base.
  2. Modify its eos_token from <|endoftext|> to <|im_end|> (or any other special token) in tokenizer_config.json.
  3. Use LoRA to fine-tune the model on your downstream task; make sure your LoRA will not fine-tune the lm_head and embedding.
  4. The fine-tuning data follow the input, instruction, output format; I concatenate them as input + instruction + output + eos_token, and the labels are built the same way (a sketch follows this list).
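
A minimal sketch of step 4 (the paths and field names are assumptions, not the exact script used here): the key point is that the eos token id is appended to both input_ids and labels so the model is supervised to emit it.

from transformers import AutoTokenizer

# Assumes tokenizer_config.json has already been edited so that eos_token is <|im_end|>.
tokenizer = AutoTokenizer.from_pretrained("path/to/Qwen2.5-7B-with-modified-eos")

def build_features(example):
    # example is a hypothetical dict with "input", "instruction", and "output" fields.
    prompt_ids = tokenizer(example["input"] + example["instruction"], add_special_tokens=False)["input_ids"]
    output_ids = tokenizer(example["output"], add_special_tokens=False)["input_ids"]
    # Append the eos token explicitly so it appears in both the inputs and the labels.
    input_ids = prompt_ids + output_ids + [tokenizer.eos_token_id]
    labels = list(input_ids)  # masking the prompt part with -100 is a common variant
    return {"input_ids": input_ids, "labels": labels}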

Expected results

Expected: the fine-tuned model produces text in the format presented in the training data.
What happened: the model does generate appropriate text, but it cannot stop generation.

Attempts to fix

  1. Checked that the generate function receives the right eos_token_id.
  2. Checked that the training procedure gets the right input_ids and labels, and that the eos_token_id is included in the trained labels (see the sketch below).
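
A minimal sketch of these two checks, assuming the model and tokenizer are already loaded and reusing the hypothetical build_features helper from the sketch above:

# 1) The generate call should see the intended eos_token_id.
print("tokenizer eos:", tokenizer.eos_token, tokenizer.eos_token_id)
print("generation_config eos:", model.generation_config.eos_token_id)

# 2) The eos id should actually appear at the end of the supervised sequence.
features = build_features(example)
assert features["input_ids"][-1] == tokenizer.eos_token_id
assert features["labels"][-1] == tokenizer.eos_token_id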

Final reason

If I change the eos_token back to <|endoftext|>, the model will have the right behavior.

After careful inspection, I found the reason: Qwen2.5-7b-base has identical lm_head weights (and effectively untrained embedding weights) for the additional tokens such as <|im_end|>, <|object_ref_start|>, and so on, except for <|endoftext|>.

print("lm_head")
print("151643:"+str(model.lm_head.weight[151643]))
print("151644:"+str(model.lm_head.weight[151644]))
print("151645:"+str(model.lm_head.weight[151645]))
print("151646:"+str(model.lm_head.weight[151646]))
print("151647:"+str(model.lm_head.weight[151647]))
print("151648:"+str(model.lm_head.weight[151648]))
print("151649:"+str(model.lm_head.weight[151649]))
print("151650:"+str(model.lm_head.weight[151650]))
print("142333:"+str(model.lm_head.weight[142333]))
print("embedding")
print("151643:"+str(model.get_input_embeddings().weight[151643]))
print("151644:"+str(model.get_input_embeddings().weight[151644]))
print("151645:"+str(model.get_input_embeddings().weight[151645]))
print("151646:"+str(model.get_input_embeddings().weight[151646]))
print("151647:"+str(model.get_input_embeddings().weight[151647]))
print("151648:"+str(model.get_input_embeddings().weight[151648]))
print("151649:"+str(model.get_input_embeddings().weight[151649]))
print("151650:"+str(model.get_input_embeddings().weight[151650]))
print("142333:"+str(model.get_input_embeddings().weight[142333]))

The output:

qwen2_base_7b
lm_head
151643:tensor([-0.0025, -0.0061, -0.0063,  ..., -0.0042, -0.0118,  0.0019],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151644:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151645:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151646:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151647:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151648:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151649:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151650:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
142333:tensor([ 0.0005,  0.0091,  0.0034,  ...,  0.0020,  0.0002, -0.0011],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
embedding
151643:tensor([-0.0186,  0.0347,  0.0092,  ...,  0.0040, -0.0077,  0.0006],
       dtype=torch.bfloat16, grad_fn=<SelectBackward0>)
151644:tensor([ 1.1755e-37, -1.1755e-37,  1.1755e-37,  ...,  1.1755e-37,
        -1.1755e-37, -1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
151645:tensor([-1.1755e-37, -1.1755e-37,  1.1755e-37,  ...,  1.1755e-37,
        -1.1755e-37,  1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
151646:tensor([-1.1755e-37, -1.1755e-37,  1.1755e-37,  ...,  1.1755e-37,
         1.1755e-37,  1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
151647:tensor([ 1.1755e-37,  1.1755e-37, -1.1755e-37,  ...,  1.1755e-37,
         1.1755e-37, -1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
151648:tensor([ 1.1755e-37, -1.1755e-37,  1.1755e-37,  ...,  1.1755e-37,
        -1.1755e-37,  1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
151649:tensor([ 1.1755e-37, -1.1755e-37,  1.1755e-37,  ..., -1.1755e-37,
         1.1755e-37,  1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
151650:tensor([-1.1755e-37, -1.1755e-37,  1.1755e-37,  ..., -1.1755e-37,
         1.1755e-37, -1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
142333:tensor([ 1.1755e-37, -1.1755e-37, -1.1755e-37,  ...,  1.1755e-37,
         1.1755e-37,  1.1755e-37], dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)

151643 is the id of <|endoftext|>, and 142333 is a random id I picked (there may be more ids like this); the other ids are defined in tokenizer_config.json. This explains why the fine-tuned model cannot stop generation: although <|im_end|> (151645) is trained, its weight in lm_head is identical to that of many other ids, so when it should be generated at inference time its logit equals the logits of all those ids, and any of them can be picked during decoding. In fact, I observe that the logit for 151645 does increase when it should be generated, but so do the logits of the ids sharing the same lm_head weight. This does not happen for 151643, because it has a distinct lm_head weight; presumably it was trained during the pre-training stage.
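
A minimal sketch (assuming torch and the model loaded as above) that checks the tie programmatically instead of eyeballing the truncated printout:

import torch

W = model.lm_head.weight
E = model.get_input_embeddings().weight

# <|im_end|> (151645) shares its lm_head row with, e.g., <|im_start|> (151644) and the random id 142333 ...
print(torch.equal(W[151645], W[151644]), torch.equal(W[151645], W[142333]))
# ... while <|endoftext|> (151643) has a distinct row.
print(torch.equal(W[151645], W[151643]))

# The untrained embedding rows sit at a tiny magnitude (around 1e-37), unlike <|endoftext|>.
print(E[151645].abs().max().item(), E[151643].abs().max().item())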

I am quite confused about why this happens, since all lm_head and embedding weights should have been initialized from a normal distribution according to the released code.

@hxs91 hxs91 changed the title [Bug]: The final reason why you will get a model that cannot stop generation when you fine-tune the Qwen2.5-7b-base use Lora and a non-<|endoftext|> token as bos_token. [Bug]: The final reason why you will get a model that cannot stop generation when you fine-tune the Qwen2.5-7b-base use Lora and a non-<|endoftext|> token as eos_token. Nov 8, 2024
jklj077 (Collaborator) commented Nov 14, 2024

"<|im_start|>" and "<|im_end|>" are indeed not trained for the base models in the Qwen2.5 series.

This explains why the fine-tuned model cannot stop generation: although <|im_end|> (151645) is trained, ...

I understand that you "make sure your LoRA will not fine-tune the lm_head and embedding". So,

  • it is trained only in the sense that the parameters of the LoRA adapters do receive gradients and are updated in such a way that "<|im_end|>" becomes more likely to be generated; but
  • it is not trained in the sense that the parameters of the embedding or the LM head for "<|im_end|>" are updated, so "<|im_end|>" still cannot be distinguished from other untrained tokens.

This does not happen for 151643, because it has a distinct lm_head weight; presumably it was trained during the pre-training stage.

This is true. "<|endoftext|>" is trained. See also https://qwen.readthedocs.io/en/latest/getting_started/concepts.html#control-tokens-chat-template.

If you must keep the lm_head and embedding from being fine-tuned, you can

  • regard all untrained tokens as stopping tokens, e.g., by setting eos_token_id (which accepts a list) in generation_config.json for transformers; or
  • randomly re-initialize the embedding and lm_head rows for "<|im_start|>" and "<|im_end|>" after the model is loaded from the checkpoint (a sketch of these first two options follows this list); or
  • use a different chat template that does not rely on special tokens and set stopping criteria accordingly.
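
A minimal sketch of the first two options, assuming a transformers model and already-tokenized inputs; the token ids and the init std are illustrative assumptions:

import torch

# (a) Treat the untrained special tokens as additional stop tokens at generation time.
out = model.generate(
    **inputs,                                # tokenized prompt, assumed to be defined
    max_new_tokens=256,
    eos_token_id=[151643, 151644, 151645],   # <|endoftext|>, <|im_start|>, <|im_end|>
)

# (b) Re-initialize the embedding and lm_head rows of <|im_start|>/<|im_end|> before
#     fine-tuning so they are no longer indistinguishable from other untrained rows.
std = model.config.initializer_range         # assumption: reuse the config's init std
with torch.no_grad():
    for tid in (151644, 151645):
        model.get_input_embeddings().weight[tid].normal_(mean=0.0, std=std)
        model.lm_head.weight[tid].normal_(mean=0.0, std=std)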

However, the best way is still to train the embedding and the lm_head of the new tokens.
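
For that recommendation, a minimal sketch with PEFT; the target_modules and hyperparameters are assumptions, not the setup used in this issue:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    # Fully train (and save) the output head and the input embeddings so the new
    # eos token gets its own trained, distinguishable weights.
    modules_to_save=["lm_head", "embed_tokens"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()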


hxs91 commented Nov 14, 2024

@jklj077 Thanks for the patient explanation. It took me quite a lot of time to figure out this problem. Maybe it should not be called a "bug"; I wrote it down in case others encounter the same problem.

BTW, if all embedding and lm_head weights were initialized from a normal distribution, every token would have different embedding and lm_head weights, which is what I expected; then the phenomenon I described would not happen, even for tokens that were not trained in the pre-training stage.


echoht commented Nov 15, 2024

For a multi-turn QA dataset trained with LoRA, will there be any problem with setting eos_token to <|endoftext|>?


hxs91 commented Nov 15, 2024

For a multi-turn QA dataset trained with LoRA, will there be any problem with setting eos_token to <|endoftext|>?

Setting it to <|endoftext|> should not cause any problem.


echoht commented Nov 15, 2024

I read this doc: https://qwen.readthedocs.io/zh-cn/latest/getting_started/concepts.html#control-tokens-chat-template. There, <|endoftext|> is placed at the end of the last turn of a multi-turn conversation, which seems to conflict with this, but I have not verified it experimentally.


echoht commented Nov 15, 2024

If I change the eos_token back to <|endoftext|>, the model will have the right behavior.

Where did you make this change? Before LoRA fine-tuning, or at generation time after LoRA fine-tuning?


hxs91 commented Nov 15, 2024

I read this doc: https://qwen.readthedocs.io/zh-cn/latest/getting_started/concepts.html#control-tokens-chat-template. There, <|endoftext|> is placed at the end of the last turn of a multi-turn conversation, which seems to conflict with this, but I have not verified it experimentally.

I fine-tuned the base model, so this has nothing to do with the chat template.

If I change the eos_token back to <|endoftext|>, the model will have the right behavior.

Where did you make this change? Before LoRA fine-tuning, or at generation time after LoRA fine-tuning?

In tokenizer_config.json, before fine-tuning.


echoht commented Nov 15, 2024

In tokenizer_config.json, before fine-tuning.

Aren't you fine-tuning with the llama-factory training framework?


hxs91 commented Nov 15, 2024

Aren't you fine-tuning with the llama-factory training framework?

No, I used the huggingface trainer.
