Strictly check the number of placeholder tokens
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 23, 2024
1 parent 4873ff8 commit 4dbb5a3
Showing 1 changed file with 14 additions and 12 deletions.
vllm/multimodal/processing.py (26 changes: 14 additions & 12 deletions)
@@ -1128,18 +1128,20 @@ def get_dummy_data(
         prompt_token_ids = mm_inputs["prompt_token_ids"]
         placeholders_by_modality = mm_inputs["mm_placeholders"]

-        total_placeholders_by_modality = dict[str, int]()
-        for modality, placeholders in placeholders_by_modality.items():
-            num_placeholders = sum(item["length"] for item in placeholders)
-            max_tokens = mm_max_tokens[modality]
-
-            if num_placeholders != max_tokens:
-                logger.warning(
-                    "The processed dummy data has a total of %d placeholder "
-                    "tokens for the '%s' modality, which is not the expected "
-                    "%d tokens.", num_placeholders, modality, max_tokens)
-
-            total_placeholders_by_modality[modality] = num_placeholders
+        total_placeholders_by_modality = {
+            modality: sum(item["length"] for item in placeholders)
+            for modality, placeholders in placeholders_by_modality.items()
+        }
+        expected_placeholders_by_modality = {
+            modality: mm_max_tokens[modality]
+            for modality in placeholders_by_modality
+        }
+        if total_placeholders_by_modality != expected_placeholders_by_modality:
+            raise AssertionError(
+                f"The processed dummy data has a total of "
+                f"{total_placeholders_by_modality} placeholder tokens, which "
+                f"is not the expected {expected_placeholders_by_modality} "
+                "tokens.")

         total_len = len(prompt_token_ids)
         if total_len > seq_len:
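For readers following along outside the vLLM codebase, below is a minimal, self-contained sketch of the stricter check this commit introduces. The check_placeholder_counts helper and the literal data are illustrative stand-ins, not part of vLLM's API; they only mirror the shapes of placeholders_by_modality and mm_max_tokens as they appear in the diff.

# Minimal sketch of the strict check introduced above; the helper name and
# the example data are illustrative, not part of vLLM's API.

def check_placeholder_counts(
    placeholders_by_modality: dict[str, list[dict[str, int]]],
    mm_max_tokens: dict[str, int],
) -> None:
    # Total placeholder tokens actually produced, per modality.
    total_placeholders_by_modality = {
        modality: sum(item["length"] for item in placeholders)
        for modality, placeholders in placeholders_by_modality.items()
    }
    # Expected placeholder token count, per modality.
    expected_placeholders_by_modality = {
        modality: mm_max_tokens[modality]
        for modality in placeholders_by_modality
    }
    # Unlike the old code, which only logged a per-modality warning,
    # any mismatch is now a hard failure.
    if total_placeholders_by_modality != expected_placeholders_by_modality:
        raise AssertionError(
            f"Got {total_placeholders_by_modality} placeholder tokens, "
            f"expected {expected_placeholders_by_modality}.")


# Matching counts pass silently.
check_placeholder_counts({"image": [{"length": 32}]}, {"image": 32})

# A mismatch (16 produced vs. 32 expected) now raises AssertionError
# instead of merely emitting a warning.
try:
    check_placeholder_counts({"image": [{"length": 16}]}, {"image": 32})
except AssertionError as exc:
    print(exc)

The practical effect of the commit, as far as the diff shows, is to fail fast: a disagreement between the processed dummy data and the expected per-modality token counts now surfaces immediately as an AssertionError rather than as a warning that could be missed.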
