-
Notifications
You must be signed in to change notification settings - Fork 457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Skip entire header for llama3 decode #1656
Changes from 3 commits
445286a
fe9dccf
784b527
1e8bcf0
59268d4
e7bbac6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import re | ||
from typing import Any, Dict, List, Mapping, Optional, Tuple | ||
|
||
from torchtune.data import Message, PromptTemplate, truncate | ||
|
@@ -113,6 +114,12 @@ def __init__( | |
|
||
self.prompt_template = prompt_template | ||
|
||
# Regex for removing special tokens from the decoded string | ||
self._special_token_regex = re.compile(r"<\|.*?\|>") | ||
self._special_token_header_regex = re.compile( | ||
r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n" | ||
) | ||
|
||
def _validate_special_tokens( | ||
self, | ||
): | ||
|
@@ -131,6 +138,15 @@ def _validate_special_tokens( | |
if token not in self.special_tokens: | ||
raise ValueError(f"{token} missing from special_tokens") | ||
|
||
def _remove_special_tokens(self, text: str) -> str: | ||
""" | ||
Remove special tokens from the decoded string. | ||
""" | ||
# First remove the headers, then the remaining special tokens | ||
return self._special_token_regex.sub( | ||
"", self._special_token_header_regex.sub("", text) | ||
) | ||
|
||
@property | ||
def base_vocab_size(self) -> int: | ||
return self.tt_model.base_vocab_size | ||
|
@@ -166,11 +182,18 @@ def decode( | |
Returns: | ||
str: The decoded string. | ||
""" | ||
return self.tt_model.decode( | ||
token_ids, | ||
# We will remove special tokens manually via regex on the decoded string. | ||
# This is because removing all special tokens does not remove the role and | ||
# whitespace added from the special tokens, i.e., the "user" and "\n\n" in | ||
# "<|start_header_id|>user<|end_header_id|>\n\n" | ||
decoded_string = self.tt_model.decode( | ||
token_ids=token_ids, | ||
truncate_at_eos=truncate_at_eos, | ||
skip_special_tokens=skip_special_tokens, | ||
skip_special_tokens=False, | ||
) | ||
if skip_special_tokens: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will this test pass?
Probably not well formulated, but I want to see if after every encode/decode we are adding \n\n There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that a string here is necessary to explain why you need this extra logic for "skip_special_tokens" if decode already has this flag. In other words: why do we skip special tokens in two different places? Is there a more elegant way to solve this, like adding the special token to the tt_model directly? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So I understand that the goal here is just to fix |
||
decoded_string = self._remove_special_tokens(decoded_string) | ||
return decoded_string | ||
|
||
def _tokenize_header(self, message: Message) -> List[int]: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would maybe move this comment up to where you define
self._special_token_regex
and self._special_token_header_regex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is where it actually happens, so makes more sense to keep it here? no strong opinions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I feel the same, fine to keep it here then