Skip entire header for llama3 decode #1656

Merged · 6 commits · Oct 10, 2024
Changes from 3 commits
14 changes: 14 additions & 0 deletions tests/torchtune/models/llama3/test_llama3_tokenizer.py
@@ -428,3 +428,17 @@ def test_validate_special_tokens(self):
"<|python_tag|>": 128255,
},
)

    def test_skip_special_tokens(
        self,
        tokenizer,
        user_text_message,
        assistant_text_message,
        user_text_a,
        user_text_b,
        assistant_text,
    ):
        # This should satisfy text = decode(encode(text))
        tokens = user_text_message[1] + assistant_text_message[1]
        text = tokenizer.decode(tokens, skip_special_tokens=True)
        assert text == user_text_a + user_text_b + assistant_text
29 changes: 26 additions & 3 deletions torchtune/models/llama3/_tokenizer.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict, List, Mapping, Optional, Tuple

from torchtune.data import Message, PromptTemplate, truncate
@@ -113,6 +114,12 @@ def __init__(

        self.prompt_template = prompt_template

        # Regex for removing special tokens from the decoded string
        self._special_token_regex = re.compile(r"<\|.*?\|>")
        self._special_token_header_regex = re.compile(
            r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
        )

    def _validate_special_tokens(
        self,
    ):
@@ -131,6 +138,15 @@ def _validate_special_tokens(
            if token not in self.special_tokens:
                raise ValueError(f"{token} missing from special_tokens")

    def _remove_special_tokens(self, text: str) -> str:
        """
        Remove special tokens from the decoded string.
        """
        # First remove the headers, then the remaining special tokens
        return self._special_token_regex.sub(
            "", self._special_token_header_regex.sub("", text)
        )

    @property
    def base_vocab_size(self) -> int:
        return self.tt_model.base_vocab_size
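
For reference, here is a small standalone sketch (not part of this diff) of how the two regexes interact on a decoded Llama3-style string. Stripping the full header span first is what removes the role name and the trailing "\n\n" that a plain special-token pattern would otherwise leave behind; the sample string and variable names below are purely illustrative.

import re

special_token_regex = re.compile(r"<\|.*?\|>")
special_token_header_regex = re.compile(
    r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
)

decoded = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>"

# Stripping only the individual special tokens keeps the role text and whitespace.
print(special_token_regex.sub("", decoded))  # -> "user\n\nHello"

# Removing the header span first, then the remaining special tokens, leaves clean text.
print(special_token_regex.sub("", special_token_header_regex.sub("", decoded)))  # -> "Hello"
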
@@ -166,11 +182,18 @@ def decode(
        Returns:
            str: The decoded string.
        """
-        return self.tt_model.decode(
-            token_ids,
+        # We will remove special tokens manually via regex on the decoded string.
+        # This is because removing all special tokens does not remove the role and
+        # whitespace added from the special tokens, i.e., the "user" and "\n\n" in
+        # "<|start_header_id|>user<|end_header_id|>\n\n"
Comment on lines +185 to +188

Reviewer: I would maybe move this comment up to where you define self._special_token_regex and self._special_token_header_regex.

Author: This is where it actually happens, so it makes more sense to keep it here? No strong opinions.

Reviewer: Yeah, I feel the same, fine to keep it here then.

+        decoded_string = self.tt_model.decode(
+            token_ids=token_ids,
            truncate_at_eos=truncate_at_eos,
-            skip_special_tokens=skip_special_tokens,
+            skip_special_tokens=False,
        )
+        if skip_special_tokens:
felipemello1 (Sep 24, 2024): Will this test pass?

    text = decode(encode(text), skip_special_tokens=True)

Probably not well formulated, but I want to see if after every encode/decode we are adding \n\n.

felipemello1 (Sep 24, 2024): I think that a string here is necessary to explain why you need this extra logic for "skip_special_tokens" if decode already has this flag. In other words: why do we skip special tokens in two different places? Is there a more elegant way to solve this, like adding the special token to the tt_model directly?

Reviewer: So I understand that the goal here is just to fix skip_special_tokens for the Llama3 tokenizer decode, but it seems to me like we are doing something very unexpected here with skip_special_tokens: (a) we have it defined on the base tokenizer and it's basically a no-op now, and (b) we are now inconsistent on whether this needs to be defined on the ModelTokenizer and the BaseTokenizer. If it is a function of tokenize_messages on the ModelTokenizer more so than the BaseTokenizer, maybe we should update the Protocol along with the other callsites?
+            decoded_string = self._remove_special_tokens(decoded_string)
+        return decoded_string

    def _tokenize_header(self, message: Message) -> List[int]:
        """
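
To round out the picture, a rough usage sketch of the behavior this change targets. The llama3_tokenizer builder, Message, and tokenize_messages come from torchtune, but the tokenizer model path, the example messages, and the exact decoded output are assumptions for illustration only.

from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

# Placeholder path to a local Llama3 tiktoken model file.
tokenizer = llama3_tokenizer("/path/to/llama3/tokenizer.model")

messages = [
    Message(role="user", content="Tell me a joke."),
    Message(role="assistant", content="Why did the chicken cross the road?"),
]
tokens, _ = tokenizer.tokenize_messages(messages)

# Before this change, skip_special_tokens=True dropped the special tokens themselves
# but kept the role text and whitespace added by the headers (e.g. a leading "user\n\n").
# With the header regex applied after decoding, only the message content should remain.
print(tokenizer.decode(tokens, skip_special_tokens=True))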