Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Introduce SPMD worker execution using Ray accelerated DAG #6032

Merged
merged 2 commits into from
Jul 18, 2024

Conversation

ruisearch42
Copy link
Collaborator

@ruisearch42 ruisearch42 commented Jul 1, 2024

This introduces an SPMD execution mode for Worker. In this mode, there is no longer a driver worker and the rank 0 worker is moved to a separate process. All workers are expected to take an ExecuteModelRequest input, instead of using NCCL as a control plane to receive inputs.

To keep the changes contained, for now, this path needs to be used with the new Ray accelerated DAG feature. Compared to Ray Core, this feature reduces system performance overheads for task execution and args passing, by using an execution loop and shared memory, respectively.

This PR is based on top of #5980 , and added the following:

  • Added e2e correctness tests for VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1
  • Fixed test failures
  • Resolved conflicts with master
  • Update the required Ray version
  • Add some benchmarks

Benchmarking
TP = 4, requests = 500
Latency column format: latency_with_spmd_change / latency_without_spmd_change

GPU Model input_len output_len qps avg latency % comparison median latency % comparison
A10 Mistral-7B-v0.1 128 128 3 22.7 / 23.4 97.0% 19 / 18.4 103.2%
V100 Llama-2-7b-chat-hf 32 128 3 13.8 / 13.7 100.7% 13.7 / 13.7 100.0%
V100 Llama-2-7b-chat-hf 128 128 3 13.9 / 13.7 101.5% 13.9 / 13.7 101.5%
V100 Llama-2-7b-hf 256 128 3 14.7 / 13.7 107.3% 14.2 / 13.7 103.6%
A100 Meta-Llama-3-70B-Instruct 32 32 6 53.8 / 54.1 99.4% 53.5 / 53.6 99.8%

Summary
For smaller input lengths, the latency is better than or the same as before. For larger input lengths, the latency has small overhead. For larger input lengths, it is expected to have better latency when delta optimization is built on top (work starting soon).

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@ruisearch42 ruisearch42 force-pushed the spmd-tp branch 2 times, most recently from 014685e to 1758da9 Compare July 1, 2024 23:28
@ruisearch42 ruisearch42 marked this pull request as ready for review July 8, 2024 15:24
@youkaichao
Copy link
Member

youkaichao commented Jul 9, 2024

I'd like to support this, but currently the problem is we need to serialize ExecuteModelRequest and SamplerOutput. They have redundant data and can contain on-device data that are expensive to serialize.

I think the first step should be simplify these two structure.

@cadedaniel
Copy link
Collaborator

Can you help me understand the problem better @youkaichao ? I want to understand if it's something we can solve with deltas, plus moving the on-device fields to worker state (like what Jamba modeling does).

@youkaichao
Copy link
Member

@cadedaniel I think #6241 should be a starting point.

And, if this PR can achieve the same performance as the main, then I would be glad to accept it. My current impression is this would be slow because of the inefficient serialization overhead.

@cadedaniel
Copy link
Collaborator

OK. @ruisearch42 will collect numbers and report here.

@rkooo567 rkooo567 self-requested a review July 11, 2024 06:09
@rkooo567 rkooo567 self-assigned this Jul 11, 2024
Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes the semantic of existing env var USE_RAY_COMPILED_DAG completely. Maybe we should just deprecate this env var (just raise an exception) and replace it to USE_SPMD_WORKER?

.buildkite/test-pipeline.yaml Outdated Show resolved Hide resolved
vllm/executor/distributed_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Outdated Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Outdated Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Outdated Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
@rkooo567
Copy link
Collaborator

And, if this PR can achieve the same performance as the main, then I would be glad to accept it. My current impression is this would be slow because of the inefficient serialization overhead.

this is correct. Our old fork shows that doing input delta optimization can match the perf with the master. Do you think it makes sense to merge the PR and follow up after given the feature is isolated using an env var?

@ruisearch42 ruisearch42 changed the title [wip][Core] Introduce SPMD worker execution using Ray accelerated DAG [Core] Introduce SPMD worker execution using Ray accelerated DAG Jul 11, 2024
Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM if tests pass!

vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/worker/worker_base.py Outdated Show resolved Hide resolved
vllm/worker/worker_base.py Show resolved Hide resolved
@rkooo567 rkooo567 added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 16, 2024
@youkaichao
Copy link
Member

sorry for the long wait.

I did some benchmarking for this branch on 4 H100:

without spmd (using mp backend):

$ python benchmarks/benchmark_throughput.py --output-len 256 --input 256 --model meta-llama/Llama-2-7b-hf -tp 4
Throughput: 32.98 requests/s, 16883.88 tokens/s

with spmd:

$ VLLM_USE_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 python benchmarks/benchmark_throughput.py --output-len 256 --input 256 --model meta-llama/Llama-2-7b-hf -tp 4 --distributed-executor-backend ray
Throughput: 17.78 requests/s, 9102.25 tokens/s

the throughput is only a half. I might be wrong in the benchmarking, please help me investigate or reproduce.

there is also a shutdown error, although it is benign:

Exception ignored in: <function RayGPUExecutor.del at 0x7fb0ba6d6160>
Traceback (most recent call last):
File "/data/youkaichao/vllm/vllm/executor/ray_gpu_executor.py", line 373, in del
self.forward_dag.teardown()
File "/data/youkaichao/miniconda/envs/vllm/lib/python3.9/site-packages/ray/dag/compiled_dag_node.py", line 1402, in teardown
monitor.teardown(wait=True)
File "/data/youkaichao/miniconda/envs/vllm/lib/python3.9/site-packages/ray/dag/compiled_dag_node.py", line 1204, in teardown
outer._dag_submitter.close()
File "/data/youkaichao/miniconda/envs/vllm/lib/python3.9/site-packages/ray/experimental/channel/common.py", line 383, in close
self._output_channel.close()
File "/data/youkaichao/miniconda/envs/vllm/lib/python3.9/site-packages/ray/experimental/channel/shared_memory_channel.py", line 629, in close
channel.close()
File "/data/youkaichao/miniconda/envs/vllm/lib/python3.9/site-packages/ray/experimental/channel/shared_memory_channel.py", line 512, in close
self._worker.core_worker.experimental_channel_set_error(self._writer_ref)
AttributeError: 'Worker' object has no attribute 'core_worker'

In general, this is the direction I want to push in the future. However, I would say this implementation is quick and dirty. It is too specialized, and would leave much tech debit for the future. We have two control-plane execution pattern in the same codebase, and the code can be very confusing.

By "quick and dirty", I mean, this PR only specializes to execute_model, and a lot of methods are left untouched. For example, in spmd worker, the driver (engine) does not hold the model anymore, but if we call add_lora, it will still call the driver (engine), which will lead to error. For a full spmd style worker, we should consider all possible functions.

My original plan, is to analyze which objects should live in the engine process and which objects should live in the worker process, and then minimize the data transfer between engine process and worker process. Then we can confidently remove the non-spmd style code completely.

@rkooo567
Copy link
Collaborator

@youkaichao we will take a look at the benchmark. I am 99% sure it is due to that we send all tokens to workers at each batch. The overhead increases with more batch size. So this requires delta input optimization.

@youkaichao
Copy link
Member

why the benchmark of latency shown in #6032 (comment) is so different from benchmark of throughput then?

"we send all tokens to workers at each batch"

I assume this would also affect benchmark of latency.

@rkooo567
Copy link
Collaborator

rkooo567 commented Jul 16, 2024

Btw, we are confirming the theory now! Latency benchmark has lower batch size in general compared to throughput benchmark, and I assume that's why. (so with higher batch, serialization overhead is much higher without delta optimization). But 2X is pretty big, and rui is taking a look at this.

@ruisearch42
Copy link
Collaborator Author

Looking into the benchmarks. Some quick responses:

there is also a shutdown error, although it is benign

Thanks for reporting. This is likely some Ray/config issue, I happen to see the same error yesterday where ADAG is not used. I didn't run into it last time in testing. Will take a look.

By "quick and dirty", I mean, this PR only specializes to execute_model, and a lot of methods are left untouched. For example, in spmd worker, the driver (engine) does not hold the model anymore, but if we call add_lora, it will still call the driver (engine), which will lead to error. For a full spmd style worker, we should consider all possible functions.

Hmm, I think in SPMD mode add_lora will be called on the driver worker (which holds the model), not the driver itself. And it looks straightforward to adapt the code if there is a need.

My original plan, is to analyze which objects should live in the engine process and which objects should live in the worker process, and then minimize the data transfer between engine process and worker process. Then we can confidently remove the non-spmd style code completely.

Great thought. We are probably moving towards the same direction. In this PR, SPMD is config guarded and the plan is to remove non-SPMD path in future without being blocked.

@rkooo567
Copy link
Collaborator

thanks for another review @youkaichao !

Signed-off-by: Rui Qiao <[email protected]>
@rkooo567 rkooo567 enabled auto-merge (squash) July 17, 2024 21:55
auto-merge was automatically disabled July 17, 2024 22:52

Head branch was pushed to by a user without write access

@ruisearch42 ruisearch42 force-pushed the spmd-tp branch 3 times, most recently from dc0e6bb to 90e358f Compare July 18, 2024 00:03
Signed-off-by: Rui Qiao <[email protected]>
@rkooo567 rkooo567 merged commit 61e5927 into vllm-project:main Jul 18, 2024
72 checks passed
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 19, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
gnpinkert pushed a commit to gnpinkert/vllm that referenced this pull request Jul 26, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants