Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TPU] Align worker index with node boundary #7932

Merged
merged 4 commits into from
Sep 2, 2024
Merged

[TPU] Align worker index with node boundary #7932

merged 4 commits into from
Sep 2, 2024

Conversation

WoosukKwon
Copy link
Collaborator

Fixes #7485

@WoosukKwon WoosukKwon added the tpu Related to Google TPUs label Aug 28, 2024
@WoosukKwon WoosukKwon requested a review from youkaichao August 28, 2024 00:25
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Copy link
Member

is it enough to remove the following lines?

with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):

@WoosukKwon
Copy link
Collaborator Author

@youkaichao I tried it, but still got gibberish results without the patch. I think this is because the rank IDs used in all gather are assigned by XLA runtime, regardless of IPs.

@youkaichao
Copy link
Member

I will not block this pr for that reason, but it would be better to know in the future, what's the rank used by XLA.

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, please fix the format

@Beomi
Copy link

Beomi commented Aug 28, 2024

Hi, I still get this error due to the mismatched global world size even on this branch(tpu-rank) version.

I use these commands below to build new docker image based on Dockerfile.tpu and use the code of this branch.

# on main node, main vm IP:PORT is 10.130.0.60:6379
sudo docker run -t -d -e HF_TOKEN=INPUT_YOUR_TOKEN --privileged --net host --shm-size=16G -it vllm ray start --head --block
# on other nodes
sudo docker run -t -d -e VLLM_HOST_IP=10.130.0.60 -e HF_TOKEN=YOUR_TOKEN -e  --privileged --net host --shm-size=16G -it vllm ray start --address 10.130.0.60:6379 --block

after checking this ray status -- to confirm all the tpu cores are gathered in ray(32 cores since I use TPUv4-64 for this exp)

======== Autoscaler status: 2024-08-28 04:31:11.442038 ========
Node status
---------------------------------------------------------------
Active:
 1 node_7f47f048ccc1c5455a203b92ef6d2bdd5d0332b18fa6f092d6c6425a
 1 node_be42b0ac8682eacb5da8cee5918de34e28dc90b1041b6522405c1c62
 1 node_13290553e86634224a524ec68c5b3e3f0478e491669681f92723ec49
 1 node_2ee5e9791b60dfaf727341ae8a67910996f0572ba36260df07507668
 1 node_3ad7ae73804a43f3a6acbc4e9b48f92693c8cf9e09e275e49863c056
 1 node_195573b8d1e45ef8e2e27d20cf530f48884e0be3e8ead3bf9c716338
 1 node_a1804b4b12788b7a33e02f0f9c38d18512d6ee4d27a9d9775fce7f5a
 1 node_13d0f73aeba95e5f03f42494a57a9ac1204e273106e7ad109dde11f5
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/1920.0 CPU
 0.0/32.0 TPU
 0.0/1.0 TPU-v4-64-head
 0B/3.00TiB memory
 0B/121.60GiB object_store_memory
 0.0/8.0 v4-64

However when I run this command to run vllm server --

vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 -tp 4 -pp 8 --device tpu --distributed_executor_backend ray

this error comes.

(pid=614, ip=10.130.15.194) INFO 08-28 04:16:28 importing.py:10] Triton not installed; certain GPU-related functions will not be available. [repeated 2x across cluster]
(RayWorkerWrapper pid=376, ip=10.130.0.62) WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 08-28 04:16:35 selector.py:198] Cannot use _Backend.FLASH_ATTN backend on TPU.
INFO 08-28 04:16:35 selector.py:146] Using Pallas backend.
(RayWorkerWrapper pid=376, ip=10.130.15.193) INFO 08-28 04:16:35 selector.py:198] Cannot use _Backend.FLASH_ATTN backend on TPU.
(RayWorkerWrapper pid=376, ip=10.130.15.193) INFO 08-28 04:16:35 selector.py:146] Using Pallas backend.
(pid=615, ip=10.130.0.59) INFO 08-28 04:16:32 importing.py:10] Triton not installed; certain GPU-related functions will not be available.
ERROR 08-28 04:16:35 worker_base.py:465] Error executing method init_device. This might cause deadlock in distributed execution.
ERROR 08-28 04:16:35 worker_base.py:465] Traceback (most recent call last):
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/worker/worker_base.py", line 457, in execute_method
ERROR 08-28 04:16:35 worker_base.py:465]     return executor(*args, **kwargs)
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/worker/tpu_worker.py", line 83, in init_device
ERROR 08-28 04:16:35 worker_base.py:465]     ensure_model_parallel_initialized(
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 965, in ensure_model_parallel_initialized
ERROR 08-28 04:16:35 worker_base.py:465]     initialize_model_parallel(tensor_model_parallel_size,
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 931, in initialize_model_parallel
ERROR 08-28 04:16:35 worker_base.py:465]     _TP = init_model_parallel_group(group_ranks,
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 773, in init_model_parallel_group
ERROR 08-28 04:16:35 worker_base.py:465]     return GroupCoordinator(
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 175, in __init__
ERROR 08-28 04:16:35 worker_base.py:465]     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 29, in __init__
ERROR 08-28 04:16:35 worker_base.py:465]     local_rank = global_rank % local_world_size
ERROR 08-28 04:16:35 worker_base.py:465] ZeroDivisionError: integer division or modulo by zero
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/workspace/vllm/vllm/entrypoints/openai/rpc/server.py", line 230, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
  File "/workspace/vllm/vllm/entrypoints/openai/rpc/server.py", line 31, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 709, in from_engine_args
    engine = cls(
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 600, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 809, in _init_engine
    return engine_class(*args, **kwargs)
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 261, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/vllm/vllm/engine/llm_engine.py", line 288, in __init__
    self.model_executor = executor_class(
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 304, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 36, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/vllm/vllm/executor/executor_base.py", line 46, in __init__
    self._init_executor()
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 48, in _init_executor
    self._init_workers_ray(placement_group)
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 175, in _init_workers_ray
    self._run_workers("init_device")
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 242, in _run_workers
    driver_worker_output = self.driver_worker.execute_method(
  File "/workspace/vllm/vllm/worker/worker_base.py", line 466, in execute_method
    raise e
  File "/workspace/vllm/vllm/worker/worker_base.py", line 457, in execute_method
    return executor(*args, **kwargs)
  File "/workspace/vllm/vllm/worker/tpu_worker.py", line 83, in init_device
    ensure_model_parallel_initialized(
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 965, in ensure_model_parallel_initialized
    initialize_model_parallel(tensor_model_parallel_size,
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 931, in initialize_model_parallel
    _TP = init_model_parallel_group(group_ranks,
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 773, in init_model_parallel_group
    return GroupCoordinator(
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 175, in __init__
    self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
  File "/workspace/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 29, in __init__
    local_rank = global_rank % local_world_size
ZeroDivisionError: integer division or modulo by zero

Because of the code here: https://github.com/vllm-project/vllm/blob/tpu-rank/vllm/distributed/device_communicators/tpu_communicator.py#L28

it causes

class TpuCommunicator:
        ....
        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
        # must be used together. Therefore, the local rank and world size can
        # be simply calculated as follows.
        global_rank = dist.get_rank(group) # <- this global rank wil be 0 on main node
        global_world_size = dist.get_world_size(group) # <- this size should be 32 but it is 4(single node's TPU cores)
        num_nodes = len(ray.nodes()) # <- correct value, 8 for TPUv4-64
        local_world_size = global_world_size // num_nodes # this line cases err since 4 // 8 == 0, causing later line ZeroDivision Error.
        local_rank = global_rank % local_world_size
        pjrt.initialize_multiprocess(local_rank, local_world_size)
        xr._init_world_size_ordinal()

I think pytorch distributed dist.get_world_size(group) can't gather all the TPU cores information from ray cluster, rather it uses local TPU xla device only.

@Beomi
Copy link

Beomi commented Aug 28, 2024

@WoosukKwon I think this problem comes from these lines:

for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self.device_group = device_group
self.cpu_group = cpu_group

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

I think this codes are supposed to create local cpu group for optmized rank (such as TP in 1 node) but I think this cases cpu_group to have only 4 process(=1 node TPU cores), making Zero Division err to the code above I mentioned.

is there suggested way to workaround this issue?

@DarkLight1337
Copy link
Member

LGTM, please fix the format

#7929 fixes the CI failure related to mypy, please merge from main to resolve it.

@WoosukKwon WoosukKwon merged commit e2b2aa5 into main Sep 2, 2024
19 of 21 checks passed
@WoosukKwon WoosukKwon deleted the tpu-rank branch September 2, 2024 06:09
@WoosukKwon
Copy link
Collaborator Author

@Beomi Thanks for trying out this PR. Currently, vLLM's TPU backend does not support PP. Could you please use smaller TPU pod and retry with the updated main branch?

@Beomi
Copy link

Beomi commented Sep 2, 2024

@WoosukKwon Thanks for clarification! it seems like TP over all nodes works :)
Hope PP comes to TPU as well.
Thank you 😄

@Beomi
Copy link

Beomi commented Sep 2, 2024

@WoosukKwon BTW, there is weird situation related with Multinode situation -- the code works well with TPUv4-8/v4-16/v4-32 but it fails to launch worker on one node if I run on v4-64. I'll open it on the other issue :)

Manikandan-Thangaraj-ZS0321 added a commit to Manikandan-Thangaraj-ZS0321/vllm that referenced this pull request Sep 2, 2024
[TPU] Align worker index with node boundary (vllm-project#7932)
@youkaichao
Copy link
Member

Hope PP comes to TPU as well.

TPU has very good cross-node inter connect. I don't think it needs PP.

triple-Mu pushed a commit to triple-Mu/vllm_official that referenced this pull request Sep 4, 2024
dsikka pushed a commit to neuralmagic/vllm that referenced this pull request Sep 5, 2024
opus24 added a commit to Hyper-Accel/vllm that referenced this pull request Sep 10, 2024
commit a1d8742
Author: Simon Mo <[email protected]>
Date:   Mon Sep 9 23:21:00 2024 -0700

    Add NVIDIA Meetup slides, announce AMD meetup, and add contact info (vllm-project#8319)

commit 6cd5e5b
Author: Dipika Sikka <[email protected]>
Date:   Mon Sep 9 23:02:52 2024 -0400

    [Misc] Fused MoE Marlin support for GPTQ (vllm-project#8217)

commit c7cb5c3
Author: Kyle Sayers <[email protected]>
Date:   Mon Sep 9 16:27:26 2024 -0400

    [Misc] GPTQ Activation Ordering (vllm-project#8135)

commit f9b4a2d
Author: Vladislav Kruglikov <[email protected]>
Date:   Mon Sep 9 21:20:46 2024 +0300

    [Bugfix] Correct adapter usage for cohere and jamba (vllm-project#8292)

commit 58fcc85
Author: Adam Lugowski <[email protected]>
Date:   Mon Sep 9 11:16:37 2024 -0700

    [Frontend] Add progress reporting to run_batch.py (vllm-project#8060)

    Co-authored-by: Adam Lugowski <[email protected]>

commit 08287ef
Author: Kyle Mistele <[email protected]>
Date:   Mon Sep 9 09:45:11 2024 -0500

    [Bugfix] Streamed tool calls now more strictly follow OpenAI's format; ensures Vercel AI SDK compatibility (vllm-project#8272)

commit 4ef41b8
Author: Alexander Matveev <[email protected]>
Date:   Sun Sep 8 00:01:51 2024 -0400

    [Bugfix] Fix async postprocessor in case of preemption (vllm-project#8267)

commit cfe712b
Author: Joe Runde <[email protected]>
Date:   Sat Sep 7 14:03:16 2024 -0600

    [CI/Build] Use python 3.12 in cuda image (vllm-project#8133)

    Signed-off-by: Joe Runde <[email protected]>

commit b962ee1
Author: sumitd2 <[email protected]>
Date:   Sat Sep 7 23:48:40 2024 +0530

    ppc64le: Dockerfile fixed, and a script for buildkite (vllm-project#8026)

commit 36bf815
Author: Isotr0py <[email protected]>
Date:   Sun Sep 8 01:45:44 2024 +0800

    [Model][VLM] Decouple weight loading logic for `Paligemma` (vllm-project#8269)

commit e807125
Author: Isotr0py <[email protected]>
Date:   Sat Sep 7 16:38:23 2024 +0800

    [Model][VLM] Support multi-images inputs for InternVL2 models (vllm-project#8201)

commit 9f68e00
Author: Cyrus Leung <[email protected]>
Date:   Sat Sep 7 16:02:39 2024 +0800

    [Bugfix] Fix broken OpenAI tensorizer test (vllm-project#8258)

commit ce2702a
Author: youkaichao <[email protected]>
Date:   Fri Sep 6 22:40:46 2024 -0700

    [tpu][misc] fix typo (vllm-project#8260)

commit 795b662
Author: Wei-Sheng Chin <[email protected]>
Date:   Fri Sep 6 20:18:16 2024 -0700

    Enable Random Prefix Caching in Serving Profiling Tool (benchmark_serving.py) (vllm-project#8241)

commit 2f707fc
Author: Cyrus Leung <[email protected]>
Date:   Sat Sep 7 10:57:24 2024 +0800

    [Model] Multi-input support for LLaVA (vllm-project#8238)

commit 41e95c5
Author: Kyle Mistele <[email protected]>
Date:   Fri Sep 6 21:49:01 2024 -0500

    [Bugfix] Fix Hermes tool call chat template bug (vllm-project#8256)

    Co-authored-by: Kyle Mistele <[email protected]>

commit 12dd715
Author: William Lin <[email protected]>
Date:   Fri Sep 6 17:48:48 2024 -0700

    [misc] [doc] [frontend] LLM torch profiler support (vllm-project#7943)

commit 29f49cd
Author: Patrick von Platen <[email protected]>
Date:   Sat Sep 7 01:02:05 2024 +0200

    [Model] Allow loading from original Mistral format (vllm-project#8168)

    Co-authored-by: Michael Goin <[email protected]>

commit 23f3222
Author: Dipika Sikka <[email protected]>
Date:   Fri Sep 6 18:29:03 2024 -0400

    [Misc] Remove `SqueezeLLM` (vllm-project#8220)

commit 9db52ea
Author: rasmith <[email protected]>
Date:   Fri Sep 6 17:26:09 2024 -0500

    [Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize, 2x throughput (vllm-project#8248)

commit 1447c97
Author: Alexey Kondratiev(AMD) <[email protected]>
Date:   Fri Sep 6 14:51:03 2024 -0400

    [CI/Build] Increasing timeout for multiproc worker tests (vllm-project#8203)

commit de80783
Author: Rui Qiao <[email protected]>
Date:   Fri Sep 6 09:18:35 2024 -0700

    [Misc] Use ray[adag] dependency instead of cuda (vllm-project#7938)

commit e5cab71
Author: afeldman-nm <[email protected]>
Date:   Fri Sep 6 12:01:14 2024 -0400

    [Frontend] Add --logprobs argument to `benchmark_serving.py` (vllm-project#8191)

commit baa5467
Author: Nick Hill <[email protected]>
Date:   Thu Sep 5 20:39:29 2024 -0700

    [BugFix] Fix Granite model configuration (vllm-project#8216)

commit db3bf7c
Author: Jiaxin Shan <[email protected]>
Date:   Thu Sep 5 18:10:33 2024 -0700

    [Core] Support load and unload LoRA in api server (vllm-project#6566)

    Co-authored-by: Jee Jee Li <[email protected]>

commit 2febcf2
Author: sroy745 <[email protected]>
Date:   Thu Sep 5 13:25:29 2024 -0700

    [Documentation][Spec Decode] Add documentation about lossless guarantees in Speculative Decoding in vLLM (vllm-project#7962)

commit 2ee4528
Author: Michael Goin <[email protected]>
Date:   Thu Sep 5 11:09:46 2024 -0400

    Move verify_marlin_supported to GPTQMarlinLinearMethod (vllm-project#8165)

commit 9da25a8
Author: Alex Brooks <[email protected]>
Date:   Thu Sep 5 06:48:10 2024 -0600

    [MODEL] Qwen Multimodal Support (Qwen-VL / Qwen-VL-Chat) (vllm-project#8029)

    Signed-off-by: Alex-Brooks <[email protected]>
    Co-authored-by: DarkLight1337 <[email protected]>

commit 8685ba1
Author: [email protected] <[email protected]>
Date:   Thu Sep 5 17:03:37 2024 +0530

    Inclusion of InternVLChatModel In PP_SUPPORTED_MODELS(Pipeline Parallelism) (vllm-project#7860)

commit 288a938
Author: Cyrus Leung <[email protected]>
Date:   Thu Sep 5 18:51:53 2024 +0800

    [Doc] Indicate more information about supported modalities (vllm-project#8181)

commit e39ebf5
Author: Elfie Guo <[email protected]>
Date:   Wed Sep 4 22:12:26 2024 -0700

    [Core/Bugfix] Add query dtype as per FlashInfer API requirements. (vllm-project#8173)

commit ba262c4
Author: Kevin H. Luu <[email protected]>
Date:   Wed Sep 4 20:33:12 2024 -0700

    [ci] Mark LoRA test as soft-fail (vllm-project#8160)

    Signed-off-by: kevin <[email protected]>

commit 4624d98
Author: Woosuk Kwon <[email protected]>
Date:   Wed Sep 4 20:31:48 2024 -0700

    [Misc] Clean up RoPE forward_native (vllm-project#8076)

commit 1afc931
Author: William Lin <[email protected]>
Date:   Wed Sep 4 17:35:36 2024 -0700

    [bugfix] >1.43 constraint for openai (vllm-project#8169)

    Co-authored-by: Michael Goin <[email protected]>

commit e01c2be
Author: Maureen McElaney <[email protected]>
Date:   Wed Sep 4 19:50:13 2024 -0400

    [Doc] [Misc] Create CODE_OF_CONDUCT.md (vllm-project#8161)

commit 32e7db2
Author: Simon Mo <[email protected]>
Date:   Wed Sep 4 16:34:27 2024 -0700

    Bump version to v0.6.0 (vllm-project#8166)

commit 008cf88
Author: Harsha vardhan manoj Bikki <[email protected]>
Date:   Wed Sep 4 16:33:43 2024 -0700

    [Neuron] Adding support for adding/ overriding neuron configuration a… (vllm-project#8062)

    Co-authored-by: Harsha Bikki <[email protected]>

commit 77d9e51
Author: Cody Yu <[email protected]>
Date:   Wed Sep 4 13:23:22 2024 -0700

    [MISC] Replace input token throughput with total token throughput (vllm-project#8164)

    Co-authored-by: Michael Goin <[email protected]>

commit e02ce49
Author: Kyle Mistele <[email protected]>
Date:   Wed Sep 4 15:18:13 2024 -0500

    [Feature] OpenAI-Compatible Tools API + Streaming for Hermes & Mistral models (vllm-project#5649)

    Co-authored-by: constellate <[email protected]>
    Co-authored-by: Kyle Mistele <[email protected]>

commit 561d6f8
Author: Woosuk Kwon <[email protected]>
Date:   Wed Sep 4 13:05:50 2024 -0700

    [CI] Change test input in Gemma LoRA test (vllm-project#8163)

commit d1dec64
Author: alexeykondrat <[email protected]>
Date:   Wed Sep 4 14:57:54 2024 -0400

    [CI/Build][ROCm] Enabling LoRA tests on ROCm (vllm-project#7369)

    Co-authored-by: Simon Mo <[email protected]>

commit 2ad2e56
Author: Cody Yu <[email protected]>
Date:   Wed Sep 4 11:53:25 2024 -0700

    [MISC] Consolidate FP8 kv-cache tests (vllm-project#8131)

commit d331156
Author: wnma <[email protected]>
Date:   Wed Sep 4 18:55:37 2024 +0800

    [Bugfix] remove post_layernorm in siglip (vllm-project#8106)

commit ccd7207
Author: TimWang <[email protected]>
Date:   Wed Sep 4 14:17:05 2024 +0800

    chore: Update check-wheel-size.py to read MAX_SIZE_MB from env (vllm-project#8103)

commit 855c262
Author: Cyrus Leung <[email protected]>
Date:   Wed Sep 4 13:22:17 2024 +0800

    [Frontend] Multimodal support in offline chat (vllm-project#8098)

commit 2be8ec6
Author: Peter Salas <[email protected]>
Date:   Tue Sep 3 21:38:21 2024 -0700

    [Model] Add Ultravox support for multiple audio chunks (vllm-project#7963)

commit e16fa99
Author: Dipika Sikka <[email protected]>
Date:   Tue Sep 3 22:12:41 2024 -0400

    [Misc] Update fbgemmfp8 to use `vLLMParameters` (vllm-project#7972)

    Co-authored-by: Michael Goin <[email protected]>

commit 61f4a93
Author: Woosuk Kwon <[email protected]>
Date:   Tue Sep 3 18:35:33 2024 -0700

    [TPU][Bugfix] Use XLA rank for persistent cache path (vllm-project#8137)

commit d4db9f5
Author: Nick Hill <[email protected]>
Date:   Tue Sep 3 17:57:41 2024 -0700

    [Benchmark] Add `--async-engine` option to benchmark_throughput.py (vllm-project#7964)

commit 2188a60
Author: Dipika Sikka <[email protected]>
Date:   Tue Sep 3 17:21:44 2024 -0400

    [Misc] Update `GPTQ` to use `vLLMParameters` (vllm-project#7976)

commit dc0b606
Author: Simon Mo <[email protected]>
Date:   Tue Sep 3 14:11:42 2024 -0700

    [CI] Change PR remainder to avoid at-mentions (vllm-project#8134)

commit 0af3abe
Author: Woosuk Kwon <[email protected]>
Date:   Tue Sep 3 13:29:24 2024 -0700

    [TPU][Bugfix] Fix next_token_ids shape (vllm-project#8128)

commit f1575dc
Author: Kevin H. Luu <[email protected]>
Date:   Tue Sep 3 13:25:09 2024 -0700

    [ci] Fix GHA workflow  (vllm-project#8129)

    Signed-off-by: kevin <[email protected]>

commit c02638e
Author: tomeras91 <[email protected]>
Date:   Tue Sep 3 22:37:08 2024 +0300

    [CI/Build] make pip install vllm work in macos (for import only) (vllm-project#8118)

commit 652c83b
Author: Antoni Baum <[email protected]>
Date:   Tue Sep 3 12:28:25 2024 -0700

    [Misc] Raise a more informative exception in add/remove_logger (vllm-project#7750)

commit 6d646d0
Author: Alexander Matveev <[email protected]>
Date:   Tue Sep 3 14:50:29 2024 -0400

    [Core] Optimize Async + Multi-step (vllm-project#8050)

commit 95a178f
Author: Kevin H. Luu <[email protected]>
Date:   Tue Sep 3 11:32:27 2024 -0700

    [CI] Only PR reviewers/committers can trigger CI on PR (vllm-project#8124)

    Signed-off-by: kevin <[email protected]>

commit bd852f2
Author: Cody Yu <[email protected]>
Date:   Tue Sep 3 10:49:18 2024 -0700

    [Performance] Enable chunked prefill and prefix caching together (vllm-project#8120)

    Co-authored-by: Tao He <[email protected]>
    Co-authored-by: Juelianqvq <[email protected]>

commit ec26653
Author: Isotr0py <[email protected]>
Date:   Tue Sep 3 21:37:52 2024 +0800

    [Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (vllm-project#8061)

commit 0fbc669
Author: Woosuk Kwon <[email protected]>
Date:   Mon Sep 2 20:35:42 2024 -0700

    [Bugfix] Fix single output condition in output processor (vllm-project#7881)

commit 6e36f4f
Author: wang.yuqi <[email protected]>
Date:   Tue Sep 3 05:20:12 2024 +0800

    improve chunked prefill performance

    [Bugfix] Fix vllm-project#7592 vllm 0.5.4 enable_chunked_prefill throughput is slightly lower than 0.5.3~0.5.0. (vllm-project#7874)

commit dd2a6a8
Author: Isotr0py <[email protected]>
Date:   Mon Sep 2 23:48:56 2024 +0800

    [Bugfix] Fix internlm2 tensor parallel inference (vllm-project#8055)

commit 4ca65a9
Author: Isotr0py <[email protected]>
Date:   Mon Sep 2 20:43:26 2024 +0800

    [Core][Bugfix] Accept GGUF model without .gguf extension (vllm-project#8056)

commit e2b2aa5
Author: Woosuk Kwon <[email protected]>
Date:   Sun Sep 1 23:09:46 2024 -0700

    [TPU] Align worker index with node boundary (vllm-project#7932)

commit e6a26ed
Author: Lily Liu <[email protected]>
Date:   Sun Sep 1 21:23:29 2024 -0700

    [SpecDecode][Kernel] Flashinfer Rejection Sampling (vllm-project#7244)

commit f8d6014
Author: Shawn Tan <[email protected]>
Date:   Sun Sep 1 21:37:18 2024 -0400

    [Model] Add Granite model (vllm-project#7436)

    Co-authored-by: Nick Hill <[email protected]>

commit 5b86b19
Author: Roger Wang <[email protected]>
Date:   Sun Sep 1 14:46:57 2024 -0700

    [Misc] Optional installation of audio related packages (vllm-project#8063)

commit 5231f08
Author: Roger Wang <[email protected]>
Date:   Sat Aug 31 16:35:53 2024 -0700

    [Frontend][VLM] Add support for multiple multi-modal items (vllm-project#8049)

commit 8423aef
Author: Robert Shaw <[email protected]>
Date:   Sat Aug 31 15:44:03 2024 -0400

    [BugFix][Core] Multistep Fix Crash on Request Cancellation (vllm-project#8059)

commit 4f5d844
Author: Nicolò Lucchesi <[email protected]>
Date:   Sat Aug 31 09:27:58 2024 +0200

    [Bugfix] Fix ModelScope models in v0.5.5 (vllm-project#8037)

commit d05f0a9
Author: Cyrus Leung <[email protected]>
Date:   Sat Aug 31 13:26:55 2024 +0800

    [Bugfix] Fix import error in Phi-3.5-MoE (vllm-project#8052)

commit 622f8ab
Author: Pavani Majety <[email protected]>
Date:   Fri Aug 30 22:18:50 2024 -0700

    [Bugfix] bugfix and add model test for flashinfer fp8 kv cache. (vllm-project#8013)

commit 1248e85
Author: Wenxiang <[email protected]>
Date:   Sat Aug 31 03:42:57 2024 +0800

    [Model] Adding support for MSFT Phi-3.5-MoE (vllm-project#7729)

    Co-authored-by: Your Name <[email protected]>
    Co-authored-by: Zeqi Lin <[email protected]>
    Co-authored-by: Zeqi Lin <[email protected]>

commit 2684efc
Author: Woosuk Kwon <[email protected]>
Date:   Fri Aug 30 09:01:26 2024 -0700

    [TPU][Bugfix] Fix tpu type api (vllm-project#8035)

commit 058344f
Author: Kaunil Dhruv <[email protected]>
Date:   Fri Aug 30 08:21:02 2024 -0700

    [Frontend]-config-cli-args (vllm-project#7737)

    Co-authored-by: Cyrus Leung <[email protected]>
    Co-authored-by: Kaunil Dhruv <[email protected]>

commit 98cef6a
Author: Cyrus Leung <[email protected]>
Date:   Fri Aug 30 23:20:34 2024 +0800

    [Core] Increase default `max_num_batched_tokens` for multimodal models (vllm-project#8028)

commit f97be32
Author: Jungho Christopher Cho <[email protected]>
Date:   Sat Aug 31 00:19:27 2024 +0900

    [VLM][Model] TP support for ViTs (vllm-project#7186)

    Co-authored-by: Roger Wang <[email protected]>
    Co-authored-by: Roger Wang <[email protected]>

commit afd39a4
Author: Cyrus Leung <[email protected]>
Date:   Fri Aug 30 23:03:28 2024 +0800

    [Bugfix] Fix import error in Exaone model (vllm-project#8034)

commit 2148441
Author: Richard Liu <[email protected]>
Date:   Fri Aug 30 00:27:40 2024 -0700

    [TPU] Support single and multi-host TPUs on GKE (vllm-project#7613)

commit dc13e99
Author: Yohan Na <[email protected]>
Date:   Fri Aug 30 15:34:20 2024 +0900

    [MODEL] add Exaone model support (vllm-project#7819)

commit 34a0e96
Author: Avshalom Manevich <[email protected]>
Date:   Fri Aug 30 11:11:39 2024 +0700

    [Kernel] changing fused moe kernel chunk size default to 32k (vllm-project#7995)

commit 80c7b08
Author: Woosuk Kwon <[email protected]>
Date:   Thu Aug 29 19:35:29 2024 -0700

    [TPU] Async output processing for TPU (vllm-project#8011)

commit 428dd14
Author: afeldman-nm <[email protected]>
Date:   Thu Aug 29 22:19:08 2024 -0400

    [Core] Logprobs support in Multi-step (vllm-project#7652)

commit 4abed65
Author: Cyrus Leung <[email protected]>
Date:   Fri Aug 30 08:49:04 2024 +0800

    [VLM] Disallow overflowing `max_model_len` for multimodal models (vllm-project#7998)

commit 0c785d3
Author: Wei-Sheng Chin <[email protected]>
Date:   Thu Aug 29 16:48:11 2024 -0700

    Add more percentiles and latencies (vllm-project#7759)

commit 4664cea
Author: chenqianfzh <[email protected]>
Date:   Thu Aug 29 16:09:08 2024 -0700

    support bitsandbytes 8-bit and FP4 quantized models (vllm-project#7445)

commit 257afc3
Author: Harsha vardhan manoj Bikki <[email protected]>
Date:   Thu Aug 29 13:58:14 2024 -0700

    [Neuron] Adding support for context-lenght, token-gen buckets. (vllm-project#7885)

    Co-authored-by: Harsha Bikki <[email protected]>

commit 86a677d
Author: Dipika Sikka <[email protected]>
Date:   Thu Aug 29 16:46:55 2024 -0400

    [misc] update tpu int8 to use new vLLM Parameters (vllm-project#7973)

commit d78789a
Author: Isotr0py <[email protected]>
Date:   Fri Aug 30 03:54:49 2024 +0800

    [Bugfix] Fix incorrect vocal embedding shards for GGUF model in tensor parallelism (vllm-project#7954)

commit c334b18
Author: kushanam <[email protected]>
Date:   Thu Aug 29 12:15:04 2024 -0700

    extend cuda graph size for H200 (vllm-project#7894)

    Co-authored-by: youkaichao <[email protected]>

commit 6b34215
Author: Pavani Majety <[email protected]>
Date:   Thu Aug 29 11:53:11 2024 -0700

    [Core][Kernels] Enable FP8 KV Cache with Flashinfer backend.  + BugFix for kv_cache_dtype=auto (vllm-project#7985)

    Co-authored-by: Simon Mo <[email protected]>
    Co-authored-by: Cody Yu <[email protected]>

commit 3f60f22
Author: Alexander Matveev <[email protected]>
Date:   Thu Aug 29 14:18:26 2024 -0400

    [Core] Combine async postprocessor and multi-step (vllm-project#7921)

commit f205c09
Author: Jonas M. Kübler <[email protected]>
Date:   Thu Aug 29 07:18:13 2024 +0200

    [Bugfix] Unify rank computation across regular decoding and speculative decoding (vllm-project#7899)

commit ef99a78
Author: youkaichao <[email protected]>
Date:   Wed Aug 28 21:27:06 2024 -0700

    Revert "[Core][Kernels] Use FlashInfer backend for FP8 KV Cache when available." (vllm-project#7982)

commit 74d5543
Author: Peter Salas <[email protected]>
Date:   Wed Aug 28 20:24:31 2024 -0700

    [VLM][Core] Fix exceptions on ragged NestedTensors (vllm-project#7974)

commit a7f65c2
Author: youkaichao <[email protected]>
Date:   Wed Aug 28 17:32:26 2024 -0700

    [torch.compile] remove reset (vllm-project#7975)

commit 4289cad
Author: Nick Hill <[email protected]>
Date:   Wed Aug 28 17:22:43 2024 -0700

    [Frontend] Minor optimizations to zmq decoupled front-end (vllm-project#7957)

    Co-authored-by: Robert Shaw <rshaw@neuralmagic>

commit af59df0
Author: Michael Goin <[email protected]>
Date:   Wed Aug 28 19:19:17 2024 -0400

    Remove faulty Meta-Llama-3-8B-Instruct-FP8.yaml lm-eval test (vllm-project#7961)

commit ce6bf3a
Author: youkaichao <[email protected]>
Date:   Wed Aug 28 16:10:12 2024 -0700

    [torch.compile] avoid Dynamo guard evaluation overhead (vllm-project#7898)

    Co-authored-by: Woosuk Kwon <[email protected]>

commit 3cdfe1f
Author: bnellnm <[email protected]>
Date:   Wed Aug 28 18:11:49 2024 -0400

    [Bugfix] Make torch registration of punica ops optional (vllm-project#7970)

commit fdd9daa
Author: Mor Zusman <[email protected]>
Date:   Thu Aug 29 01:06:52 2024 +0300

    [Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (vllm-project#7651)

commit 8c56e57
Author: Stas Bekman <[email protected]>
Date:   Wed Aug 28 13:54:23 2024 -0700

    [Doc] fix 404 link (vllm-project#7966)

commit eeffde1
Author: Woosuk Kwon <[email protected]>
Date:   Wed Aug 28 13:10:21 2024 -0700

    [TPU] Upgrade PyTorch XLA nightly (vllm-project#7967)

commit e5697d1
Author: rasmith <[email protected]>
Date:   Wed Aug 28 14:37:47 2024 -0500

    [Kernel] [Triton] [AMD] Adding Triton implementations awq_dequantize and awq_gemm to support AWQ (vllm-project#7386)

commit b98cc28
Author: Pavani Majety <[email protected]>
Date:   Wed Aug 28 10:01:22 2024 -0700

    [Core][Kernels] Use FlashInfer backend for FP8 KV Cache when available. (vllm-project#7798)

    Co-authored-by: Simon Mo <[email protected]>

commit ef9baee
Author: Cyrus Leung <[email protected]>
Date:   Wed Aug 28 23:11:18 2024 +0800

    [Bugfix][VLM] Fix incompatibility between vllm-project#7902 and vllm-project#7230 (vllm-project#7948)

commit 98c12cf
Author: Stas Bekman <[email protected]>
Date:   Wed Aug 28 05:12:32 2024 -0700

    [Doc] fix the autoAWQ example (vllm-project#7937)

commit f52a43a
Author: youkaichao <[email protected]>
Date:   Wed Aug 28 01:27:07 2024 -0700

    [ci][test] fix pp test failure (vllm-project#7945)

commit e358053
Author: Cody Yu <[email protected]>
Date:   Wed Aug 28 00:36:31 2024 -0700

    [Performance] Enable chunked prefill and prefix caching together (vllm-project#7753)
Jeffwan pushed a commit to aibrix/vllm that referenced this pull request Sep 19, 2024
siddharth9820 pushed a commit to axonn-ai/vllm that referenced this pull request Sep 30, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tpu Related to Google TPUs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[TPU] Make sure worker index aligns with node boundary
4 participants