Strictly check the number of placeholder tokens
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 23, 2024
1 parent 4873ff8 commit a679c5b
Showing 1 changed file with 9 additions and 12 deletions.
vllm/multimodal/processing.py
@@ -1128,18 +1128,15 @@ def get_dummy_data(
         prompt_token_ids = mm_inputs["prompt_token_ids"]
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
-        total_placeholders_by_modality = dict[str, int]()
-        for modality, placeholders in placeholders_by_modality.items():
-            num_placeholders = sum(item["length"] for item in placeholders)
-            max_tokens = mm_max_tokens[modality]
-
-            if num_placeholders != max_tokens:
-                logger.warning(
-                    "The processed dummy data has a total of %d placeholder "
-                    "tokens for the '%s' modality, which is not the expected "
-                    "%d tokens.", num_placeholders, modality, max_tokens)
-
-            total_placeholders_by_modality[modality] = num_placeholders
+        total_placeholders_by_modality = {
+            modality: sum(item["length"] for item in placeholders)
+            for modality, placeholders in placeholders_by_modality.items()
+        }
+        if total_placeholders_by_modality != mm_max_tokens:
+            raise AssertionError(
+                f"The processed dummy data has a total of "
+                f"{total_placeholders_by_modality} placeholder tokens, "
+                f"which is not the expected {mm_max_tokens} tokens.")
 
         total_len = len(prompt_token_ids)
         if total_len > seq_len:
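For context, here is a minimal runnable sketch of the new strict check, using made-up placeholder lengths and per-modality token budgets (the "image" and "audio" entries below are hypothetical illustration values, not vLLM's real numbers):

    # Hypothetical processed placeholders, keyed by modality.
    placeholders_by_modality = {
        "image": [{"length": 576}, {"length": 576}],
        "audio": [{"length": 188}],
    }
    # Hypothetical expected token budgets per modality.
    mm_max_tokens = {"image": 1152, "audio": 200}

    # Sum the placeholder lengths per modality, as the new code does.
    total_placeholders_by_modality = {
        modality: sum(item["length"] for item in placeholders)
        for modality, placeholders in placeholders_by_modality.items()
    }

    # Strict check: any per-modality mismatch (here "audio": 188 != 200)
    # now raises instead of merely logging a warning.
    try:
        if total_placeholders_by_modality != mm_max_tokens:
            raise AssertionError(
                f"The processed dummy data has a total of "
                f"{total_placeholders_by_modality} placeholder tokens, "
                f"which is not the expected {mm_max_tokens} tokens.")
    except AssertionError as exc:
        print(exc)

Where the old loop would have logged one warning for the "audio" modality and continued, the whole-dict comparison aborts on any mismatch, and also catches modalities that are missing from or extra on either side.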
