[V1] Optimize block table transfer from CPU to GPU #11401

Draft. Wants to merge 7 commits into base: main.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -193,13 +193,15 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/copy_subranges.cu"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
8 changes: 8 additions & 0 deletions csrc/cuda_compat.h
@@ -47,3 +47,11 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

// #ifndef USE_ROCM
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
// #else
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// hipHostGetDevicePointer(device_ptr, host_ptr, flags)
// #endif
43 changes: 43 additions & 0 deletions csrc/cuda_view.cu
@@ -0,0 +1,43 @@
#include <torch/all.h>
#include <torch/cuda.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");

// Get raw host pointer from CPU tensor
void* host_ptr = cpu_tensor.data_ptr();

// Get a device pointer corresponding to the pinned host memory
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

// Construct a CUDA tensor from the device pointer.
// We'll use the same sizes, strides, and dtype as the CPU tensor.
auto sizes = cpu_tensor.sizes();
auto strides = cpu_tensor.strides();
auto options =
cpu_tensor.options().device(torch::kCUDA); // Change device to CUDA

// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto deleter = [](void*) {
// no-op, since the memory is owned by the original CPU tensor
};

torch::Tensor cuda_tensor =
torch::from_blob(device_ptr, sizes, strides, deleter, options);

TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");

return cuda_tensor;
}
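For context, a minimal usage sketch of the UVA view from the Python side (not part of the diff; it goes through the get_cuda_view_from_cpu_tensor helper added to vllm/utils.py below, and the tensor shapes are made up):

import torch
from vllm.utils import get_cuda_view_from_cpu_tensor

# Pinned (page-locked) host memory is required for the UVA mapping.
cpu_tensor = torch.zeros(4, 8, dtype=torch.int32, pin_memory=True)

# The result is a CUDA tensor aliasing the same physical memory,
# so no explicit host-to-device copy is issued.
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)

# A write on the CPU side is visible to kernels that read the CUDA view.
cpu_tensor[0, :3] = torch.tensor([1, 2, 3], dtype=torch.int32)
total = cuda_view.sum()  # GPU kernel reading host memory over the interconnect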
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -115,6 +115,11 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n);

torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
75 changes: 75 additions & 0 deletions csrc/prepare_inputs/copy_subranges.cu
@@ -0,0 +1,75 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace vllm {
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
const int* __restrict__ matrix_diff,
int* __restrict__ matrix_tgt, int64_t M) {
int row_id = blockIdx.x;
int row_offset = row_id * M;

int start = matrix_diff[row_id * 2];
int length = matrix_diff[row_id * 2 + 1];
int end = start + length;
int thread_idx = threadIdx.x;
for (int i = start + thread_idx; i < end; i += blockDim.x) {
Reviewer comment (Member): most threads in the block would be idle; e.g. for decoding, there is only one changed entry, or even none, in the block table.

int idx = row_offset + i;
matrix_tgt[idx] = matrix_src[idx];
}
}
} // namespace vllm

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n) {
// NOTE(woosuk): Here, we skip most of the error checking to minimize the
// CPU overheads. We assume that the caller will pass the correct inputs.

// Check tensor properties
// TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
// TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
// TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
// TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
// TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
// TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");

auto src_sizes = matrix_src.sizes();
auto diff_sizes = matrix_diff.sizes();
auto tgt_sizes = matrix_tgt.sizes();

// TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
// TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
// TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");

int64_t N = src_sizes[0];
int64_t M = src_sizes[1];

// TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
// TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
// TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
// "matrix_tgt must have same shape as matrix_src");

// TORCH_CHECK(n <= N, "n must be <= N");

const int* d_matrix_src = matrix_src.data_ptr<int>();
const int* d_matrix_diff = matrix_diff.data_ptr<int>();
int* d_matrix_tgt = matrix_tgt.data_ptr<int>();

// One thread block per row.
int blocks = n;
Reviewer comment (Member): it seems this can easily oversubscribe GPU SMs.

int threads;
if (blocks < 128) {
threads = 1024;
} else if (blocks < 256) {
threads = 512;
} else if (blocks < 512) {
threads = 256;
} else {
threads = 128;
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
}
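To make the kernel's contract concrete, here is a small CPU reference (an illustration, not part of the PR), assuming each row of matrix_diff holds a [start, length] pair describing the changed subrange of the corresponding source row:

import numpy as np

def copy_subranges_reference(matrix_src, matrix_diff, matrix_tgt, n):
    # For each of the first n rows, copy only src[row, start:start+length]
    # into the target, mirroring copy_subranges_kernel.
    for row in range(n):
        start, length = matrix_diff[row]
        matrix_tgt[row, start:start + length] = \
            matrix_src[row, start:start + length]

src = np.arange(12, dtype=np.int32).reshape(3, 4)
diff = np.array([[0, 2], [1, 3], [0, 0]], dtype=np.int32)  # [start, length] per row
tgt = np.zeros_like(src)
copy_subranges_reference(src, diff, tgt, n=3)
# tgt now holds only the requested subranges; untouched entries stay zero.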
9 changes: 9 additions & 0 deletions csrc/torch_bindings.cpp
@@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);

// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
@@ -98,6 +102,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

ops.def(
"copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
"matrix_tgt, int n) -> ()");
ops.impl("copy_subranges", torch::kCUDA, &copy_subranges);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
14 changes: 14 additions & 0 deletions vllm/_custom_ops.py
@@ -249,6 +249,20 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
block_table_bound)


# copy subrange op. Used for input preparation in the vLLM V1 GPU backend.
def copy_subranges(
src_matrix: torch.Tensor,
diff_matrix: torch.Tensor,
tgt_matrix: torch.Tensor,
num_subranges: int,
) -> None:
# NOTE(woosuk): We use `torch.ops._C.copy_subranges.default` instead of
# `torch.ops._C.copy_subranges` to avoid unnecessary CPU overheads from
# the dispatcher.
torch.ops._C.copy_subranges.default(src_matrix, diff_matrix, tgt_matrix,
num_subranges)


# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,
7 changes: 7 additions & 0 deletions vllm/utils.py
@@ -1523,6 +1523,13 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")


def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)


def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule
116 changes: 116 additions & 0 deletions vllm/v1/worker/gpu_block_table.py
@@ -0,0 +1,116 @@
from typing import List

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.utils import get_cuda_view_from_cpu_tensor


class BlockTable:

def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
pin_memory: bool,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.pin_memory = pin_memory
self.device = device

self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()

# Pinned memory is required to use UVA.
# TODO(woosuk): Add other requirements for UVA.
self.use_uva = pin_memory
if self.use_uva:
self.block_table_diff = torch.zeros((max_num_reqs, 2),
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.block_table_diff_np = self.block_table_diff.numpy()

self.block_table_cpu_cuda_view = get_cuda_view_from_cpu_tensor(
self.block_table_cpu)
self.block_table_diff_cuda_view = get_cuda_view_from_cpu_tensor(
self.block_table_diff)

def add_row(self, row_idx: int, block_ids: List[int]) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, :num_blocks] = block_ids
if self.use_uva:
self.block_table_diff_np[row_idx, 0] = 0
self.block_table_diff_np[row_idx, 1] = num_blocks

def append_row(
self,
row_idx: int,
start: int,
block_ids: List[int],
) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
if self.use_uva:
self.block_table_diff_np[row_idx, 0] = start
# Move-and-append is not allowed.
assert self.block_table_diff_np[row_idx, 1] == 0
self.block_table_diff_np[row_idx, 1] = num_blocks
Reviewer comment (Member): for the non-UVA case, we still need to keep track of the max block-table length, so that apply_diff only needs to copy that many columns.

Reply (WoosukKwon, Collaborator Author, Dec 23, 2024): Good point. The problem is that the memcpy API requires the data to be in a contiguous memory space: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1g85073372f776b4c4d5f89f7124b7bf79

So when the block table tensor has the shape [batch_size, max_model_len] and we slice over the second dimension, we have to call the memcpy API batch_size times instead of once.

def move_row(self, src: int, tgt: int) -> None:
self.block_table_np[tgt] = self.block_table_np[src]
if self.use_uva:
# Append-and-move is allowed.
self.block_table_diff_np[tgt] = self.block_table_diff_np[src]
# Clear the source row.
self.block_table_diff_np[src].fill(0)

def apply_diff(self, num_reqs: int) -> None:
if self.use_uva:
# Only copy the diff to the GPU.
ops.copy_subranges(
self.block_table_cpu_cuda_view,
self.block_table_diff_cuda_view,
self.block_table,
num_reqs,
)
else:
# Copy the entire block table to the GPU.
# NOTE(woosuk): This can be a performance bottleneck when the block
# table is large.
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)

def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
if self.use_uva:
self.block_table_diff.fill_(0)

def clear_diff(self) -> None:
if self.use_uva:
self.block_table_diff_np.fill(0)

def cuda(self) -> torch.Tensor:
return self.block_table

def cpu(self) -> torch.Tensor:
return self.block_table_cpu

def numpy(self) -> np.ndarray:
return self.block_table_np
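A short usage sketch of the BlockTable class above (illustrative sizes and block IDs, not part of the diff), showing the intended per-step flow of recording changes on the CPU side and flushing them to the GPU:

import torch
from vllm.v1.worker.gpu_block_table import BlockTable

block_table = BlockTable(
    max_num_reqs=32,
    max_model_len=4096,
    max_num_blocks_per_req=256,
    pin_memory=True,  # pinned memory enables the UVA fast path
    device=torch.device("cuda"),
)

# Step 1: a new request lands at row 0 with three KV-cache blocks.
block_table.add_row(0, [10, 11, 12])
block_table.apply_diff(num_reqs=1)  # UVA path: copy only the changed subranges
block_table.clear_diff()

# Step 2: the request grows by one block.
block_table.append_row(0, start=3, block_ids=[13])
block_table.apply_diff(num_reqs=1)
block_table.clear_diff()

gpu_table = block_table.cuda()  # GPU-resident table used to build attention metadata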
26 changes: 10 additions & 16 deletions vllm/v1/worker/gpu_input_batch.py
@@ -9,6 +9,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
+ from vllm.v1.worker.gpu_block_table import BlockTable

if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
@@ -64,19 +65,14 @@ def __init__(
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

- # Attention-related.
- self.block_table = torch.zeros(
- (max_num_reqs, max_num_blocks_per_req),
- device=self.device,
- dtype=torch.int32,
- )
- self.block_table_cpu_tensor = torch.zeros(
- (max_num_reqs, max_num_blocks_per_req),
- device="cpu",
- dtype=torch.int32,
+ # Block table.
+ self.block_table = BlockTable(
+ max_num_reqs=max_num_reqs,
+ max_model_len=max_model_len,
+ max_num_blocks_per_req=max_num_blocks_per_req,
pin_memory=pin_memory,
+ device=device,
)
- self.block_table_cpu = self.block_table_cpu_tensor.numpy()

# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
@@ -141,8 +137,7 @@ def add_request(
start_idx:end_idx] = request.output_token_ids

self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
- num_blocks = len(request.block_ids)
- self.block_table_cpu[req_index, :num_blocks] = request.block_ids
+ self.block_table.add_row(req_index, request.block_ids)

sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
@@ -221,13 +216,12 @@ def condense(self, empty_req_indices: List[int]) -> None:
self.req_id_to_index[req_id] = empty_index

# TODO(woosuk): Optimize the copy of token_ids_cpu and
- # block_table_cpu.
+ # block_table.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
- self.block_table_cpu[empty_index] = self.block_table_cpu[
- last_req_index]
+ self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]