
vLLM's V1 Engine Architecture #8779

Open
simon-mo opened this issue Sep 24, 2024 · 10 comments
@simon-mo (Collaborator) commented Sep 24, 2024

This issue describes the high-level directions to "create LLM Engine V1". We want the design to be as transparent as possible, and we created this issue to track progress and solicit feedback.

Goal:

  • The new engine will be simple and performant. We found the first iteration of the engine to be simple and the multi-step engine to be performant, but we want the best of both worlds. For it to be performant, we want to minimize GPU idle time.
  • The new architecture will be extensible and modular. We found the current codebase difficult to extend with new features (both production and experimental) because different features are tightly entangled. In the new design, features should be compatible with each other.
  • Tech debt will be cleaned up. We will remove optimizations that compromise code readability, and we will redo the ad-hoc implementations used to support certain features/models.

Non-goals (the following are important but orthogonal):

  • Optimize GPU time/kernels
  • Add new features/optimizations
  • Performance in rare cases

The scope is limited to the scheduler, memory manager, and distributed architecture. We will not touch APIs, models, kernels, or most parts of the model runner.

Highlights of the new design:

  • Driver process + SPMD workers
    • When TP=n & PP=m, vLLM engine will have n*m + 1 processes in total.
      • Corollary: even when using a single GPU, we will have 2 processes.
    • The driver process will have the scheduler, memory manager, etc.
    • The workers are stateful, maintaining most of the request states.
      • The driver will only send the “diffs” (see the sketch after this list)
        • New request: input token IDs & block tables & sampling params, etc.
        • In-flight request: scheduled request IDs, new block IDs (no token IDs, sampling params, etc.)
    • Clean up data structures like SeqGroupMetadata
  • Async single-step scheduling, instead of multi-step scheduling
    • The scheduler will schedule the (n+1)-th step while the worker is executing the n-th step.
    • We will reuse the code from multi-step scheduling to incrementally update the model inputs.
    • This needs special care for PP, since the output token IDs from the last stage should be sent back to the first stage.
  • De-tokenizer moves to the driver process
    • Async de-tokenization can be regarded as part of async scheduling
  • Native support for different types of model states
    • Regular KV cache, Mamba cache, encoder cache, etc.
    • Dedicated memory manager & block table for each type of cache
  • Drop beam search from vLLM engine
    • Provide a solution to emulate beam search outside vLLM engine
  • Prefix-caching as a first-class feature
    • Implement parallel sampling via prefix caching
    • Remove the concept of SequenceGroup
    • Optimize prefix caching overheads
  • Remove/minimize PyObjectCache
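
To make the “diffs” idea above concrete, here is a minimal sketch of the two kinds of messages the driver could send to the stateful workers each step. The names (NewRequest, CachedRequestDelta, ScheduleDiff) are hypothetical illustrations, not the actual vLLM V1 data structures:

```python
# Hypothetical sketch of driver -> stateful-worker "diff" messages.
# These names are illustrative only, not vLLM's real V1 classes.
from dataclasses import dataclass, field


@dataclass
class NewRequest:
    """Sent once, when a request is scheduled for the first time."""
    request_id: str
    prompt_token_ids: list[int]   # full prompt is sent only once
    block_ids: list[int]          # initial block table
    sampling_params: dict         # temperature, top_p, ...


@dataclass
class CachedRequestDelta:
    """Sent for requests the worker already knows about."""
    request_id: str
    new_block_ids: list[int] = field(default_factory=list)  # only newly allocated blocks


@dataclass
class ScheduleDiff:
    """One scheduling step: broadcast instead of the full request state."""
    new_requests: list[NewRequest]
    cached_requests: list[CachedRequestDelta]
    finished_request_ids: list[str]  # lets workers free their cached state
```

Because the workers keep a request_id -> state map and apply each ScheduleDiff to it, per-step traffic is proportional to what changed rather than to the full batch.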

Lessons we learned from V1:

  • To achieve high GPU utilization, we should care about everything happening on the CPU.
    • Python is slow.
    • Fast GPUs like H100 do not necessarily have fast CPUs. They may have hundreds of CPU cores, but each with low clock speed.
    • Moreover, GPUs will get faster and faster, while CPUs will not.
  • Scheduling is not cheap.
    • For every step, the vLLM scheduler goes over the whole self.running queue and performs some operations for each request (e.g., allocating a new block). And this is written in Python.
  • Input broadcasting is expensive.
    • Instead of sending request information from scheduler to workers every step, the workers should be stateful and maintain most of the request states.
  • Preparing the model & sampler inputs (e.g., block table) is expensive.
    • We should cache the inputs of the previous steps and build new inputs incrementally from the cached inputs, if possible (see the sketch after this list).
    • However, not every state should be kept in GPU memory. It’s OK to cache & incrementally build some inputs in CPU memory, and send them to GPU every step.
  • De-tokenization is expensive.
    • For every step, vLLM de-tokenizes the generated output token IDs and checks the stop criteria.
    • The overhead becomes significant for large batch sizes.
  • Sampler is expensive.
    • The GPU operations themselves are not very expensive.
    • However, “pythonizing” the sampler outputs is expensive.
    • Plus, the sampler can launch many small GPU kernels with CPU-GPU synchronizations.
  • Supporting different types of model states (e.g., KV cache, Mamba cache, encoder cache) is challenging.
    • We need native cache managers for these different types of caches.
    • We need to deal with memory fragmentation due to the different sizes of the different states.
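
As a concrete illustration of the “cache and incrementally build inputs” lesson, below is a minimal sketch in which a per-request block table lives in a pre-allocated, pinned CPU tensor, is extended in place each step, and is copied to the GPU once per step. The class and method names are hypothetical and assume a CUDA device; this is not vLLM's actual implementation:

```python
# Minimal sketch: cache model inputs on the CPU and extend them incrementally,
# instead of rebuilding them from scratch every step. Assumes a CUDA device.
import torch


class CachedBlockTable:
    def __init__(self, max_requests: int, max_blocks_per_request: int):
        # Pre-allocated pinned CPU buffer keeps the per-step H2D copy cheap.
        self.cpu = torch.zeros(
            (max_requests, max_blocks_per_request), dtype=torch.int32, pin_memory=True
        )
        self.gpu = torch.zeros_like(self.cpu, device="cuda")
        self.num_blocks = [0] * max_requests  # valid entries per row

    def append_blocks(self, row: int, new_block_ids: list[int]) -> None:
        """Incrementally extend one request's block table (CPU-side only)."""
        start = self.num_blocks[row]
        end = start + len(new_block_ids)
        self.cpu[row, start:end] = torch.tensor(new_block_ids, dtype=torch.int32)
        self.num_blocks[row] = end

    def commit(self) -> torch.Tensor:
        """Copy the already-up-to-date CPU state to the GPU once per step."""
        self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu
```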

Timeline-wise, we plan to execute the changes incrementally. Over time, we will add PRs and issues related to the new architecture here.

The design is led by the vLLM maintainers @WoosukKwon @zhuohan123 @youkaichao @simon-mo @LiuXiaoxuanPKU @comaniac @alexm-neuralmagic @njhill @robertgshaw2-neuralmagic @rkooo567 and many others!

@youkaichao (Member)

I want to highlight that the re-arch will only affect vLLM developers who need to change vLLM's code, and in a positive way that makes their lives easier. For users who use vLLM directly, there will be no breaking changes except for beam search. We hope to bring better performance for users as well as an extensible architecture for developers.

@simon-mo simon-mo pinned this issue Sep 24, 2024
@simon-mo simon-mo changed the title from "[Draft] vLLM's V2 Engine Architecture" to "vLLM's V2 Engine Architecture" Sep 24, 2024
@noooop (Contributor) commented Sep 25, 2024

As vLLM supports more and more models and functions, they require different attention implementations, schedulers, executors, and input/output processors. These modules are becoming increasingly complex, and sometimes new features must be compromised for compatibility, ultimately leading to suboptimal results.

Take support for encode-only models as an example.

Although encode-only models are much simpler than decoder models, they are very different from them.

The simplest way to support encode-only models is to implement different modules for different model architectures and load the required modules on demand.

I call this architecture Workflow Defined Engine, or WDE for short.

PTAL #8453 #8452


I'm implementing an async scheduler (async single-step scheduling).
Beam search and SeqGroupMetadata drive me crazy, so it's awesome to hear that beam search and SeqGroupMetadata are being removed.

@lixiaolx

mark

@noooop (Contributor) commented Sep 27, 2024

The Workflow Defined Engine draft pull request is almost complete, at nearly 10,000 lines of code.

as @DarkLight1337 said:

Hi, as mentioned earlier, there is basically no way we can merge all these changes at once. You should break up this refactoring into multiple stages.

Therefore, we hope to invite more people to participate, including but not limited to: providing suggestions, joining discussions, aligning with vLLM's V2 engine architecture goals, discussing how to break the work into stages, and helping review code for future PRs.

Let me briefly introduce the content of this PR, including:

  1. what new models need to be supported,
  2. what new features these new models have, and
  3. how the engine architecture needs to support these features flexibly and efficiently.

What new models need to be supported

These models all come from issues and are also very well known:

  • xlm_roberta
  • bge-m3
  • bge-reranker-v2-m3
  • bert
  • bge v1.5 family
  • Snowflake Arctic Embed (Family)
  • gte-Qwen2
  • This list is still growing

These models fall roughly into three categories:

  • Encode-only models (bidirectional transformers, causal=False), often fine-tuned as retrievers, rerankers, etc.
  • Decode-only models (masked multi-head attention, causal=True). There are two interesting uses:
    • Outputting the last hidden states as a feature extractor
    • Decode-only retrievers (I don't know of a better name), e.g. e5-mistral-7b (the only embedding model currently supported by vLLM)
    • Whether it has been fine-tuned or not, there is almost no difference in the code.
  • Bidirectional-enabled models. LLM2Vec proposes a simple unsupervised approach that can transform any decoder-only LLM into a strong text encoder.

What new features these new models have

What the above three categories have in common is that there is only a prefill stage. To make the terminology more precise, "prefill only" is used below.

You can think of "prefill only" as a fancier way of writing "encode only".

New features:

  1. attention
    • Prefill-only models require simpler attention implementations: there is no KV cache and no decoding phase.
    • We need to support enabling bidirectional attention, either via a manual enable_bidirectional flag or by reading the HF config automatically.
  2. scheduler
    • Prefill-only models require a simpler scheduler: there is no KV cache and no preemption.
    • For prefill-only models, there is no correlation between tasks, so it is easy to implement async scheduling (see the sketch after this list).
  3. executor
    • To support async scheduling, model_input_builder needs to be separated from the runner.
    • The main thread executes scheduling and all CPU processing, while the GPU thread only executes H2D copies, model execution, and D2H copies.
    • If async scheduling and async execution are implemented, data parallelism is also easy to implement. Data parallelism is more efficient for small models.
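
As a sketch of the async-scheduling point above, the split can be expressed as a simple producer/consumer: the main thread builds inputs for the next batch while a dedicated thread runs the H2D copy, model execution, and D2H copy for the current one. The batch format and the build_inputs/run_model callables below are hypothetical placeholders, not the WDE code from the linked PRs:

```python
# Hypothetical sketch of async scheduling for prefill-only models: the main
# thread schedules and builds inputs for batch n+1 while a dedicated "GPU
# thread" runs h2d -> execute -> d2h for batch n. build_inputs/run_model are
# placeholders for the real components.
import queue
import threading
from typing import Any, Callable, Iterable


def gpu_worker(run_model: Callable[[Any], Any],
               batch_queue: queue.Queue,
               result_queue: queue.Queue) -> None:
    """Dedicated GPU thread: H2D copy, model execution, D2H copy."""
    while True:
        batch = batch_queue.get()
        if batch is None:               # sentinel: shut down
            break
        result_queue.put(run_model(batch))


def serve(batches: Iterable[Any],
          build_inputs: Callable[[Any], Any],
          run_model: Callable[[Any], Any]) -> list:
    """Main thread: CPU-side scheduling/input building overlaps GPU execution."""
    batch_queue: queue.Queue = queue.Queue(maxsize=2)  # stay ~1 step ahead of the GPU
    result_queue: queue.Queue = queue.Queue()
    t = threading.Thread(target=gpu_worker, args=(run_model, batch_queue, result_queue))
    t.start()

    n = 0
    for batch in batches:
        batch_queue.put(build_inputs(batch))  # CPU work for batch n+1 overlaps batch n
        n += 1
    batch_queue.put(None)
    t.join()
    return [result_queue.get() for _ in range(n)]
```

Because prefill-only requests are independent, no per-step feedback from the model output is needed to schedule the next batch, which is what makes this overlap easy.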

How the engine architecture needs to support these features flexibly and efficiently

If we directly add new functions to existing modules, these modules become increasingly complex, and sometimes new features must be compromised for compatibility, ultimately leading to suboptimal results.

The most flexible and efficient way to support prefill-only models is to implement different modules for different model architectures and load the required modules on demand.

I call this architecture Workflow Defined Engine, or WDE for short.

I divided the engine into the following modules:

  • InputProcessor: LLMs take strings as input, rerankers take pairs, and multimodal model inputs are more complex...
  • OutputProcessor: retriever (embedding) models output embeddings; reranker and classification models output scores...
  • ModelInputBuilder: builds model inputs and attention metadata
  • AttnBackend: supports different attention backends and enabling bidirectional attention
  • Tokenizer: there may be different tokenizers
  • Executor: sync/async/TP/PP/DP, maybe more
  • Worker & runner: support different devices, maybe more
  • EngineArgs: different models and different configs may accept different parameters
  • maybe more

With WDE, there is no need for one module to be compatible with all functions. You can use Python's dynamic loading at the highest level to load different modules for different models and different needs (see the sketch after the list below).

  • Modules can be configured through a Workflow: plug and play.
  • Plug-ins are flexibly supported, and developers can load their own modules.
  • The Workflow is really the best place to hide dirty code.
    Some models cannot use the common Workflow. When you don't know where to put the dirty code, you can always create a new Workflow and link the model architecture to it, avoiding dirty code scattered everywhere for the sake of compatibility.
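
To illustrate the Workflow idea, here is a minimal sketch in which a workflow maps each engine role to an import path and the modules are loaded on demand. All class names, module paths, and the registry are hypothetical and only gesture at the WDE design in the linked PRs:

```python
# Minimal sketch of "Workflow Defined Engine" module selection: a workflow maps
# each engine role to an import path, loaded on demand per model architecture.
# All names and paths are hypothetical, not the actual WDE code.
import importlib
from dataclasses import dataclass


@dataclass(frozen=True)
class Workflow:
    input_processor: str
    output_processor: str
    model_input_builder: str
    attn_backend: str
    executor: str


# Each model architecture links to the workflow it needs; supporting a new kind
# of model means adding a workflow instead of special-casing shared modules.
WORKFLOWS = {
    "XLMRobertaModel": Workflow(
        input_processor="wde.encode_only.InputProcessor",
        output_processor="wde.encode_only.EmbeddingOutputProcessor",
        model_input_builder="wde.encode_only.ModelInputBuilder",
        attn_backend="wde.attention.BidirectionalBackend",
        executor="wde.executor.AsyncExecutor",
    ),
}


def load_class(path: str):
    """Dynamically import 'package.module.ClassName'."""
    module_name, class_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


def build_engine_components(architecture: str) -> dict:
    """Resolve a model architecture to its concrete engine modules."""
    workflow = WORKFLOWS[architecture]
    return {role: load_class(path) for role, path in vars(workflow).items()}
```

A model that needs special handling gets its own workflow entry, which keeps the dirty code localized instead of spread across shared modules.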

PTAL #8453 #8452

@Yang-x-Zhao

Given the driver process + SPMD workers design, is there a chance to separate the LLMEngine process and the worker processes onto different nodes (servers)? To be more concrete, the OpenAPI server process and LLMEngine process would live on a CPU-only node with a high-performance CPU, while the worker processes would live on normal GPU node(s).

I guess this idea is somewhat related to the Ray SPMD worker (#6556), even though I suspect the current implementation does not support a distributed LLMEngine process.

@Venkat2811

Lessons we learned from V1:

To achieve high GPU utilization, we should care about everything happening on the CPU.

  • Python is slow.

Scheduling is not cheap.

  • For every step, the vLLM scheduler goes over the whole self.running queue and performs some operations for each request (e.g., allocating a new block). And this is written in Python.

Sampler is expensive.

  • However, “pythonizing” the sampler outputs is expensive.

@simon-mo is the team considering moving away from Python?

@yuki252111

mark

@wedobetter

(quoting @Venkat2811) @simon-mo is the team considering moving away from Python?

Probably easier to cythonize critical bits and wait for PY3.13 support in torch

@sleepwalker2017

We notice that when input lengths are short, for example less than 200 tokens, the prefill stage leaves the GPU idle too much.
If Python code is too slow to keep the GPU busy, can we capture prefills for short sequences into CUDA graphs?

@Stonesjtu commented Dec 20, 2024

Drop beam search from vLLM engine
Provide a solution to emulate beam search outside vLLM engine

@simon-mo v1 seems like a huge performance bump in terms of sampling and multi-modality support.

However, beam search provides flexibility for users who don't care about overall speed. Do we have a solution for stand-alone beam search right now?
