[Core] Support global prefix caching #11385

Open · wants to merge 1 commit into base: main
122 changes: 122 additions & 0 deletions examples/offline_inference_with_global_prefix.py
@@ -0,0 +1,122 @@
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

# NOTE: This is just a running example. For benchmarking purposes,
# please see benchmarks/benchmark_prefix_caching.py

# Common prefix.
prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

generating_prompts = [prefix + prompt for prompt in prompts]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM without prefix caching as a baseline.
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)

print("Results without `enable_prefix_caching`")

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = regular_llm.generate(generating_prompts, sampling_params)

regular_generated_texts = []
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
regular_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Destroy the LLM object and free up the GPU memory.
del regular_llm
cleanup_dist_env_and_memory()

# Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(model="facebook/opt-125m",
enable_prefix_caching=True,
num_global_cache_blocks=5000,
gpu_memory_utilization=0.4)

# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)

# Generate with prefix caching.
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

print("Results with `enable_prefix_caching`")

cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
cached_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Compare the results from the two runs; they should match.
generated_same = all([
regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts))
])
print(f"Generated answers are the same: {generated_same}")

"""
We can simulate the global prefix cache this way:
1. Run the vllm instance with APC for some time, so some prompts may not hit APC as they are old and evicted

Check failure on line 88 in examples/offline_inference_with_global_prefix.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

examples/offline_inference_with_global_prefix.py:88:81: E501 Line too long (108 > 80)
2. Delete the first vllm instance and start a new one. In this case, global kv cache can be hit directly

Check failure on line 89 in examples/offline_inference_with_global_prefix.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

examples/offline_inference_with_global_prefix.py:89:81: E501 Line too long (104 > 80)
Here we demo the second option.
"""
# Destroy the LLM object and free up the GPU memory.
del prefix_cached_llm
cleanup_dist_env_and_memory()

# Create an LLM with global prefix caching enabled.
global_prefix_cached_llm = LLM(model="facebook/opt-125m",
enable_prefix_caching=True,
num_global_cache_blocks=5000,
gpu_memory_utilization=0.4)

# Generate with global prefix caching.
outputs = global_prefix_cached_llm.generate(generating_prompts, sampling_params)

print("Results with `enable_global_prefix_caching`")

global_cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
global_cached_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Compare the results from the two runs; they should match.
generated_same = all([
regular_generated_texts[i] == global_cached_generated_texts[i]
for i in range(len(prompts))
])
print(f"Generated answers are the same: {generated_same}")
7 changes: 7 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -128,6 +128,7 @@ class FlashAttentionMetadata(AttentionMetadata):
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
block_hash_map: Optional[List[Dict[int, int]]]

# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
@@ -234,6 +235,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
block_hash_map=self.block_hash_map,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
@@ -284,6 +286,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
block_hash_map=self.block_hash_map,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
@@ -376,6 +379,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.block_hash_map: List[Dict[int, int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
@@ -440,6 +444,8 @@ def _add_seq_group(
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
if seq_id in inter_data.block_hash_map:
self.block_hash_map.append(inter_data.block_hash_map[seq_id])

# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
@@ -559,6 +565,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
block_hash_map=self.block_hash_map,
use_cuda_graph=use_captured_graph,
)

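For orientation, block_hash_map as built here is one dict per scheduled sequence, mapping a physical KV-cache block id to that block's content hash; the concrete ids and hash values below are made up for illustration:

from typing import Dict, List

block_hash_map: List[Dict[int, int]] = [
    {17: 0x3FA2, 18: 0x91C7, 19: 0x5D01},  # sequence 0
    {17: 0x3FA2, 18: 0x91C7, 44: 0x77AB},  # sequence 1 shares blocks 17, 18
]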
5 changes: 5 additions & 0 deletions vllm/attention/backends/utils.py
@@ -126,6 +126,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.block_hash_map: List[Dict[int, int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
@@ -185,6 +186,8 @@ def _add_seq_group(
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
if seq_id in inter_data.block_hash_map:
self.block_hash_map.append(inter_data.block_hash_map[seq_id])

# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
@@ -275,6 +278,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
block_hash_map=self.block_hash_map,
use_cuda_graph=use_captured_graph,
)

@@ -326,6 +330,7 @@ def graph_capture_get_metadata_for_batch(
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
block_hash_map=[],
use_cuda_graph=True,
)
if is_encoder_decoder_model:
2 changes: 2 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -223,6 +223,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
query_start_loc=query_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
block_hash_map=self.block_hash_map,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
@@ -263,6 +264,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
block_tables=block_tables,
block_hash_map=self.block_hash_map,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
3 changes: 2 additions & 1 deletion vllm/attention/ops/paged_attn.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict

import torch

@@ -28,6 +28,7 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
block_hash_map: Optional[List[Dict[int, int]]]


class PagedAttention:
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -886,6 +886,7 @@ def __init__(
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
num_global_cache_blocks: int = 0,
cpu_offload_gb: float = 0,
) -> None:
self.block_size = block_size
@@ -896,6 +897,7 @@ def __init__(
self.is_attention_free = is_attention_free
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.num_global_cache_blocks = num_global_cache_blocks
self.cpu_offload_gb = cpu_offload_gb

self._verify_args()
17 changes: 17 additions & 0 deletions vllm/core/block/prefix_caching_block.py
@@ -14,6 +14,8 @@
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.sequence import Sequence

from vllm.global_cache import global_cache_instance

PrefixHash = int

# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
@@ -179,6 +181,10 @@
# No cached block => Allocate a new block
block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash)
block.append_token_ids(token_ids)

# The block's content may already be present in the global KV cache; if
# so, mark it as globally computed so its prefill can be skipped later.
# content_hash can still be None for a not-yet-full block, so guard the
# lookup (GlobalCache.checkExist expects an int).
if (block.content_hash is not None
and global_cache_instance.checkExist(block.content_hash)):
block.global_computed = True

return block

def allocate_immutable_blocks(
@@ -703,6 +709,7 @@
block_id: Optional[int] = None,
computed: bool = False,
extra_hash: Optional[int] = None,
global_computed: bool = False,
):
assert isinstance(allocator, PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
@@ -718,6 +725,8 @@
self._computed = computed
self._extra_hash = extra_hash

self._global_computed = global_computed

# On the first time, we create the block object, and next we only
# reinitialize it
if hasattr(self, "_block"):
@@ -759,6 +768,14 @@
def computed(self, value) -> None:
self._computed = value

@property
def global_computed(self) -> bool:
return self._global_computed

@global_computed.setter
def global_computed(self, value) -> None:
self._global_computed = value

@property
def last_accessed(self) -> float:
return self._last_accessed
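The allocator above relies on global_cache_instance.checkExist() from vllm.global_cache, a module that is not among the files shown in this diff. The sketch below is only an assumption about what a minimal registry with that interface could look like, sized by num_global_cache_blocks; the add() method and the LRU eviction policy are guesses, not taken from the PR:

from collections import OrderedDict


class GlobalCache:
    """Hypothetical sketch of a registry of globally computed block hashes."""

    def __init__(self, num_global_cache_blocks: int = 0) -> None:
        self.capacity = num_global_cache_blocks
        self._hashes: "OrderedDict[int, None]" = OrderedDict()

    def checkExist(self, content_hash: int) -> bool:
        # Mirrors the call made by PrefixCachingBlockAllocator above.
        if content_hash in self._hashes:
            self._hashes.move_to_end(content_hash)  # refresh recency
            return True
        return False

    def add(self, content_hash: int) -> None:
        # Assumed registration hook; capacity 0 keeps global caching off.
        if self.capacity <= 0:
            return
        self._hashes[content_hash] = None
        self._hashes.move_to_end(content_hash)
        while len(self._hashes) > self.capacity:
            self._hashes.popitem(last=False)  # evict least recently used


# The diff imports a module-level singleton under this name.
global_cache_instance = GlobalCache(num_global_cache_blocks=5000)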
24 changes: 22 additions & 2 deletions vllm/core/scheduler.py
@@ -298,7 +298,8 @@
is_prompt=False,
seq_data={},
sampling_params=None,
block_tables={})
block_tables={},
block_global_computed_tables={})


def scheduler_running_outputs_builder():
@@ -1314,6 +1315,8 @@
seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}
block_hash_map: Dict[int, Dict[int, int]] = {}
block_global_computed_tables: Dict[int, List[int]] = {}

if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
@@ -1332,6 +1335,21 @@
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)

# For prefill sequences with global prefix caching enabled, record which
# physical blocks are already globally computed and map each block id to
# its content hash for the attention metadata.
if (self.cache_config.enable_prefix_caching
and self.cache_config.num_global_cache_blocks > 0
and seq_group.is_prefill()):
global_computed_list = []
block_hash_dict = {}
for block_id in block_tables[seq_id]:
for block in self.block_manager.block_tables[seq_id].blocks:
if block.block_id == block_id:
if block.global_computed:
global_computed_list.append(block_id)
if block.content_hash is not None:
block_hash_dict[block_id] = block.content_hash
break
block_global_computed_tables[seq_id] = global_computed_list
block_hash_map[seq_id] = block_hash_dict

self.block_manager.access_all_blocks_in_seq(seq, now)

if self.cache_config.enable_prefix_caching:
@@ -1344,7 +1362,7 @@
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill = False
if is_prompt:
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
@@ -1368,6 +1386,8 @@
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
block_global_computed_tables=block_global_computed_tables,
block_hash_map=block_hash_map,
do_sample=do_sample,
pooling_params=seq_group.pooling_params,
token_chunk_size=token_chunk_size,
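The nested loop above re-scans self.block_manager.block_tables[seq_id].blocks once per entry of the block table, which is quadratic in the number of blocks per sequence. A single-pass alternative is sketched below; the helper name is made up and this is a suggestion, not code from this PR:

from typing import Dict, List, Tuple


def collect_global_block_info(
        block_ids: List[int],
        seq_blocks) -> Tuple[List[int], Dict[int, int]]:
    """Single-pass equivalent of the nested lookup in the scheduler diff.

    seq_blocks is expected to be an iterable of block objects exposing
    block_id, global_computed and content_hash (as PrefixCachingBlock does).
    """
    blocks_by_id = {block.block_id: block for block in seq_blocks}
    global_computed: List[int] = []
    hash_by_block_id: Dict[int, int] = {}
    for block_id in block_ids:
        block = blocks_by_id.get(block_id)
        if block is None:
            continue
        if block.global_computed:
            global_computed.append(block_id)
        if block.content_hash is not None:
            hash_by_block_id[block_id] = block.content_hash
    return global_computed, hash_by_block_id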
7 changes: 7 additions & 0 deletions vllm/engine/arg_utils.py
100644 → 100755
@@ -114,6 +114,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None
block_size: Optional[int] = None
enable_prefix_caching: Optional[bool] = None
num_global_cache_blocks: int = 0
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
swap_space: float = 4 # GiB
@@ -438,6 +439,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Enables automatic prefix caching. "
"Use --no-enable-prefix-caching to disable explicitly.",
)
parser.add_argument(
"--num-global-cache-blocks",
type=int,
default=EngineArgs.num_global_cache_blocks,
help="number of global kv cache blocks, 0 for disable. ")
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
@@ -1044,6 +1050,7 @@ def create_engine_config(self,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
num_global_cache_blocks=self.num_global_cache_blocks,
cpu_offload_gb=self.cpu_offload_gb,
)
parallel_config = ParallelConfig(
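Tying the pieces together, the new engine argument can be set from the offline API just as in the example script, and the CLI flag added above is its command-line counterpart. A minimal usage sketch follows; the vllm serve invocation in the comment is an assumption about how the flag would be passed once this PR lands:

from vllm.engine.arg_utils import EngineArgs

# Command-line counterpart (assumption, pending this PR):
#   vllm serve facebook/opt-125m \
#       --enable-prefix-caching --num-global-cache-blocks 5000
engine_args = EngineArgs(
    model="facebook/opt-125m",
    enable_prefix_caching=True,
    num_global_cache_blocks=5000,  # 0 (the default) disables global caching
)
# create_engine_config() forwards this value into
# CacheConfig.num_global_cache_blocks (see the config.py hunk above).
print(engine_args.num_global_cache_blocks)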