Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][BugFix] Mamba/Jamba exceed mamba cache slots #11414

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER: float = 1.5


def get_default_cache_root():
Expand Down Expand Up @@ -466,6 +467,8 @@ def get_default_config_root():
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
"VLLM_DISABLE_COMPILE_CACHE":
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
"VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER":
lambda: float(os.getenv("VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER", "1.5")),
}

# end-env-vars-definition
Expand Down
20 changes: 11 additions & 9 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import nn
from transformers import JambaConfig

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
Expand Down Expand Up @@ -422,17 +423,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \

effective_max_batch_size = int(
self.vllm_config.scheduler_config.max_num_seqs * \
envs.VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER
)
if not self.model_config.enforce_eager \
and effective_max_batch_size <= \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
self.max_batch_size = vllm_config.pad_for_cudagraph(
effective_max_batch_size)
else:
self.max_batch_size = 8192 + 2
self.max_batch_size = effective_max_batch_size

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
Expand Down
20 changes: 11 additions & 9 deletions vllm/model_executor/models/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import nn
from transformers import MambaConfig

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
Expand Down Expand Up @@ -195,17 +196,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \

effective_max_batch_size = int(
self.vllm_config.scheduler_config.max_num_seqs * \
envs.VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER
)
if not self.model_config.enforce_eager \
and effective_max_batch_size <= \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
self.max_batch_size = vllm_config.pad_for_cudagraph(
effective_max_batch_size)
else:
self.max_batch_size = 8192 + 2
self.max_batch_size = effective_max_batch_size

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids)
Expand Down
Loading