[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling #11394

Draft · wants to merge 1 commit into base: main
148 changes: 148 additions & 0 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -0,0 +1,148 @@
from typing import Dict

import torch
import torch.nn as nn

from vllm.platforms import current_platform


class TopKTopPSampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        if current_platform.is_cuda():
            return self.forward_cuda(logits, generators, no_top_k, k, no_top_p,
                                     p)
        return self.forward_native(logits, generators, no_top_k, k, no_top_p,
                                   p)

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        # The FlashInfer kernels sample from probabilities, so apply softmax
        # first.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            # No filtering is needed; skip the FlashInfer top-k/top-p kernel.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)


def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    if no_top_k and no_top_p:
        return logits
    # Sort in ascending order so the largest logits are at the end.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get the k-th largest value per row as the threshold.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Keep at least one token per request.
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Undo the sort: scatter the masked logits back to their original positions.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    # Exponential-race (Gumbel-max style) sampling: with q_i ~ Exp(1) i.i.d.,
    # argmax_i probs_i / q_i selects index i with probability probs_i.
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        # This may draw exponentials for greedy requests unnecessarily,
        # but it keeps the common (seedless) case fully batched.
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    assert not (no_top_k and no_top_p)
    # FlashInfer's top-k/top-p kernels are rejection-based and may need
    # several rounds of uniform noise per request.
    max_top_k_round = 32
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        uniform_samples.uniform_()
    if generators:
        for i, generator in generators.items():
            uniform_samples[:, i].uniform_(generator=generator)

    # FlashInfer is an optional dependency, so import it lazily.
    import flashinfer.sampling
    if no_top_k:
        # Top-p only.
        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
            probs, uniform_samples, p, deterministic=True)
    elif no_top_p:
        # Top-k only.
        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
            probs, uniform_samples, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))

    # The kernels report per-request success. For any request that failed,
    # renormalize the probabilities (applying top-k/top-p exactly) and fall
    # back to sampling from the renormalized distribution.
    if not success.all():
        if not no_top_k:
            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
        if not no_top_p:
            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)
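
For reference, a minimal sketch of how the new sampler could be exercised on its own (assuming a vLLM checkout that includes this module). The batch size, vocabulary size, k/p values, and the seeded generator for request 0 are made up for illustration, and forward_native is called directly so that neither CUDA nor FlashInfer is required; the PR itself wires TopKTopPSampler into Sampler.sample() instead.

import torch

from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

batch_size, vocab_size = 4, 128
logits = torch.randn(batch_size, vocab_size)
k = torch.tensor([50, 50, 20, 20])        # per-request top-k
p = torch.tensor([0.9, 0.9, 1.0, 1.0])    # per-request top-p
generators = {0: torch.Generator().manual_seed(42)}  # only request 0 is seeded

sampler = TopKTopPSampler()
token_ids = sampler.forward_native(
    logits, generators, no_top_k=False, k=k, no_top_p=False, p=p)
print(token_ids.shape)  # torch.Size([4]), one sampled token id per request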
119 changes: 29 additions & 90 deletions vllm/v1/sample/sampler.py
@@ -1,32 +1,38 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict

import torch
import torch.nn as nn

from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

def __init__(self):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()

def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_top_k_top_p(logits, sampling_metadata)
# Use float32 for the logits.
logits = logits.to(torch.float32)
orig_logits = logits

probs = self.get_probs(logits)
sampled = self.sample(probs, sampling_metadata)
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)

if sampling_metadata.max_num_logprobs > 0:
logprobs = self.get_logprobs(logits)
logprobs = self.get_logprobs(orig_logits)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs, topk_indices = torch.topk(
@@ -52,108 +58,41 @@ def apply_temperature(
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Use float32 to apply temperature scaling.
-        logits = logits.to(torch.float32)
         # Avoid division by zero.
         temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
-        # Use in-place division to avoid creating a new tensor.
-        logits.div_(temp.unsqueeze(dim=1))
-        return logits
+        return logits / temp.unsqueeze(dim=1)
 
+    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.argmax(dim=-1).view(-1)
+
-    def apply_top_k_top_p(
+    def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        return _apply_top_k_top_p(
+        assert not (sampling_metadata.all_greedy
+                    and sampling_metadata.all_random)
+        if sampling_metadata.all_greedy:
+            return self.greedy_sample(logits)
+
+        random_sampled = self.topk_topp_sampler(
             logits,
+            sampling_metadata.generators,
             sampling_metadata.no_top_k,
             sampling_metadata.top_k,
             sampling_metadata.no_top_p,
             sampling_metadata.top_p,
         )
 
-    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.softmax(logits, dim=-1, dtype=torch.float32)
-
-    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
-
-    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
-        return probs.argmax(dim=-1).view(-1)
-
-    def random_sample(
-        self,
-        probs: torch.Tensor,
-        generators: Dict[int, torch.Generator],
-    ) -> torch.Tensor:
-        q = torch.empty_like(probs)
-        # NOTE(woosuk): To batch-process the requests without their own seeds,
-        # which is the common case, we first assume that every request does
-        # not have its own seed. Then, we overwrite the values for the requests
-        # that have their own seeds.
-        if len(generators) != probs.shape[0]:
-            # This might still be done here unnecessarily if there are greedies
-            q.exponential_()
-        if generators:
-            # TODO(woosuk): This can be slow because we handle each request
-            # one by one. Optimize this.
-            for i, generator in generators.items():
-                q[i].exponential_(generator=generator)
-        return probs.div_(q).argmax(dim=-1).view(-1)
-
-    def sample(
-        self,
-        probs: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
-        assert not (sampling_metadata.all_greedy
-                    and sampling_metadata.all_random)
-        if sampling_metadata.all_greedy:
-            return self.greedy_sample(probs)
         if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators)
+            return random_sampled
 
-        greedy_sampled = self.greedy_sample(probs)
-        random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators)
+        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
             random_sampled,
         )
         return sampled
 
-
-# TODO(woosuk): Optimize this with a custom kernel.
-def _apply_top_k_top_p(
-    logits: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    if no_top_k and no_top_p:
-        return logits
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-
-    if not no_top_k:
-        # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)
-        # Get all the top_k values.
-        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
-        top_k_mask = logits_sort < top_k_mask
-        logits_sort.masked_fill_(top_k_mask, -float("inf"))
-
-    if not no_top_p:
-        # Apply top-p.
-        probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = probs_sort.cumsum(dim=-1)
-        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
-        # at least one
-        top_p_mask[:, -1] = False
-        logits_sort.masked_fill_(top_p_mask, -float("inf"))
-
-    # Re-sort the probabilities.
-    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
-    return logits
+    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
+        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
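
To make the new control flow in Sampler.sample() concrete, below is a small standalone sketch with made-up tensors; plain torch.multinomial stands in for the self.topk_topp_sampler(...) call. Rows whose temperature is below _SAMPLING_EPS take the argmax, while all other rows take the randomly sampled token.

import torch

_SAMPLING_EPS = 1e-5
temperature = torch.tensor([0.0, 0.7, 1.0, 0.0])  # requests 0 and 3 are greedy
logits = torch.randn(4, 32)

greedy_sampled = logits.argmax(dim=-1).view(-1)
# Stand-in for self.topk_topp_sampler(...): sample one token per row.
random_sampled = torch.multinomial(logits.softmax(dim=-1),
                                   num_samples=1).view(-1)

sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled,
                      random_sampled)
print(sampled)  # rows 0 and 3 equal greedy_sampled; rows 1 and 2 are random

As in the PR, both the greedy and the random path are computed for the whole batch and torch.where selects per request, which avoids splitting the batch.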