Skip to content

Commit

Permalink
test: BOS occurrence in formatted conversation
Browse files Browse the repository at this point in the history
Signed-off-by: Toshiki Kataoka <[email protected]>
  • Loading branch information
toslunar committed Dec 24, 2024
1 parent 1a7f9e3 commit 13f56c1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
Empty file.
47 changes: 47 additions & 0 deletions tests/examples_tests/test_jinja.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pathlib import Path

import pytest
import transformers

jinja_paths = [
pytest.param(path, id=path.stem)
for path in sorted((Path(__name__).parent.parent /
"examples").glob("*.jinja"))
]


@pytest.mark.parametrize("path", jinja_paths)
@pytest.mark.parametrize("num_messages", [1, 3])
def test_bos(path: Path, num_messages: int) -> None:
with path.open("r", encoding="utf-8") as f:
chat_template = f.read()
# We might guess an appropriate tokenizer model from the file name but we
# don't maintain such list.
# Use arbitrary BOS for testing. It doesn't have to match the str in the
# correct tokenizer.
bos_token = "=BOS="
tokenizer = transformers.PreTrainedTokenizerBase(
chat_template=chat_template, bos_token=bos_token, eos_token="=EOS=")
conversation = [
{
"role": "user",
"content": "1"
},
{
"role": "assistant",
"content": "2"
},
{
"role": "user",
"content": "3"
},
][:num_messages]
try:
prompt: str = tokenizer.apply_chat_template(conversation=conversation,
tokenize=False)
except Exception as e:
if str(e) == "Embedding models should only embed one message at a time":
pytest.skip(reason=str(e))
raise
assert prompt.startswith(bos_token)
assert prompt.count(bos_token) == 1

0 comments on commit 13f56c1

Please sign in to comment.