[Bug]: vLLM on TPU does not support --pipeline-parallel-size with Ray #11260

Open
totorochina opened this issue Dec 17, 2024 · 1 comment
Labels
bug (Something isn't working), ray (anything related with ray), tpu (Related to Google TPUs)

Comments

@totorochina

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.6.0.dev20241126+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.35

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-1013-gcp-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             112
On-line CPU(s) list:                0-111
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7B13
CPU family:                         25
Model:                              1
Thread(s) per core:                 2
Core(s) per socket:                 56
Socket(s):                          1
Stepping:                           0
BogoMIPS:                           4899.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          1.8 MiB (56 instances)
L1i cache:                          1.8 MiB (56 instances)
L2 cache:                           28 MiB (56 instances)
L3 cache:                           224 MiB (7 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-111
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pyzmq==26.2.0
[pip3] torch==2.6.0.dev20241126+cpu
[pip3] torch-xla==2.6.0+git39e67b5
[pip3] torchvision==0.20.0.dev20241126+cpu
[pip3] transformers==4.47.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torch                     2.6.0.dev20241126+cpu          pypi_0    pypi
[conda] torch-xla                 2.6.0+git39e67b5          pypi_0    pypi
[conda] torchvision               0.20.0.dev20241126+cpu          pypi_0    pypi
[conda] transformers              4.47.0                   pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.4.post2.dev378+g69ba344d
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

LD_LIBRARY_PATH=/home/ext_hzchen_google_com/miniconda3/envs/tpu/lib/python3.10/site-packages/cv2/../../lib64:

Model Input Dumps

No response

🐛 Describe the bug

I set up a v5e-32 (8 hosts, 4 chips each) and started a Ray cluster.

Output of `ray status`:

======== Autoscaler status: 2024-12-17 11:18:22.317629 ========
Node status
---------------------------------------------------------------
Active:
 1 node_719fd8c930dcd8b932914ebb34d70d16323b468e1f93094a78d50f75
 1 node_ac79abf69175ce6f927b5317fe292d22aea9ac4e3c224a7ba42ca6d3
 1 node_23a15d4de28063c4a04865b963c54c0a8f29d8e928c298e4021a146b
 1 node_153c21365890656c54742efe22e675b89127db7998084e482c8260c6
 1 node_cedcf4265be52d297b3d95bab832675688d8241360c819bd9bd63de7
 1 node_8993f15ab992801a04a70b5b7c4c691158b949bd7da98c35154aa372
 1 node_3723c73cee4bc754265308f5a9f384b493e06b7d76860a739e9ddfb4
 1 node_e2823d348a9370bfe108d635330b339b48e4847ce0608af311cd75e2
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/880.0 CPU
 0.0/32.0 TPU
 0.0/1.0 TPU-v5litepod-32-head
 0B/1.01TiB memory
 0B/449.11GiB object_store_memory
 0.0/8.0 tpuvm-01

Demands:
 (no resource demands)

I start serving with the command below:

vllm serve mistralai/Pixtral-Large-Instruct-2411 --config-format mistral --load-format mistral --tokenizer-mode mistral --num-scheduler-steps 2 --swap-space 4 --max-model-len=1024 --limit_mm_per_prompt 'image=10' --tensor-parallel-size 4 --pipeline-parallel-size 8 --disable-log-requests --dtype=bfloat16

and get the error messages below:

INFO 12-17 09:37:22 api_server.py:643] vLLM API server version 0.6.4.post2.dev378+g69ba344d
INFO 12-17 09:37:22 api_server.py:644] args: Namespace(subparser='serve', model_tag='mistralai/Pixtral-Large-Instruct-2411', config='', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='mistralai/Pixtral-Large-Instruct-2411', task='auto', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='mistral', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='mistral', config_format='mistral', dtype='bfloat16', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=1024, guided_decoding_backend='xgrammar', logits_processor_pattern=None, distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=8, tensor_parallel_size=4, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=0, swap_space=4.0, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt={'image': 10}, mm_processor_kwargs=None, mm_cache_preprocessor=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=2, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', disable_log_requests=True, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, dispatch_function=<function serve at 0x7fa9d7e42320>)
INFO 12-17 09:37:24 config.py:1938] Downcasting torch.float32 to torch.bfloat16.
INFO 12-17 09:37:31 config.py:451] This model supports multiple tasks: {'embed', 'reward', 'score', 'classify', 'generate'}. Defaulting to 'generate'.
WARNING 12-17 09:37:31 config.py:569] Async output processing can not be enabled with pipeline parallel
2024-12-17 09:37:31,835	INFO worker.py:1636 -- Connecting to existing Ray cluster at address: 10.164.0.34:6379...
2024-12-17 09:37:31,898	INFO worker.py:1812 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265
INFO 12-17 09:37:32 llm_engine.py:249] Initializing an LLM engine (v0.6.4.post2.dev378+g69ba344d) with config: model='mistralai/Pixtral-Large-Instruct-2411', speculative_config=None, tokenizer='mistralai/Pixtral-Large-Instruct-2411', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=mistral, tensor_parallel_size=4, pipeline_parallel_size=8, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=None, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Pixtral-Large-Instruct-2411, num_scheduler_steps=2, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=False, mm_cache_preprocessor=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":2,"backend":"openxla","candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, 
(RayWorkerWrapper pid=150332, ip=10.164.0.35) WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
(RayWorkerWrapper pid=148818, ip=10.164.0.37) INFO 12-17 09:40:07 tpu.py:27] Cannot use _Backend.FLASH_ATTN backend on TPU.
(RayWorkerWrapper pid=148818, ip=10.164.0.37) INFO 12-17 09:40:07 selector.py:163] Using Pallas backend.
INFO 12-17 09:40:08 tpu.py:27] Cannot use _Backend.FLASH_ATTN backend on TPU.
INFO 12-17 09:40:08 selector.py:163] Using Pallas backend.
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467] Error executing method init_device. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467] Traceback (most recent call last):
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/worker/worker_base.py", line 459, in execute_method
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]     return executor(*args, **kwargs)
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/worker/tpu_worker.py", line 67, in init_device
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]     ensure_model_parallel_initialized(
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1101, in ensure_model_parallel_initialized
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]     initialize_model_parallel(tensor_model_parallel_size,
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1045, in initialize_model_parallel
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]     _TP = init_model_parallel_group(group_ranks,
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 876, in init_model_parallel_group
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]     return GroupCoordinator(
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 233, in __init__
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 43, in __init__
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467]     local_rank = global_rank % local_world_size
(RayWorkerWrapper pid=148818, ip=10.164.0.37) ERROR 12-17 09:40:09 worker_base.py:467] ZeroDivisionError: integer division or modulo by zero
ERROR 12-17 09:40:09 worker_base.py:467] Error executing method init_device. This might cause deadlock in distributed execution.
ERROR 12-17 09:40:09 worker_base.py:467] Traceback (most recent call last):
ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/worker/worker_base.py", line 459, in execute_method
ERROR 12-17 09:40:09 worker_base.py:467]     return executor(*args, **kwargs)
ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/worker/tpu_worker.py", line 67, in init_device
ERROR 12-17 09:40:09 worker_base.py:467]     ensure_model_parallel_initialized(
ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1101, in ensure_model_parallel_initialized
ERROR 12-17 09:40:09 worker_base.py:467]     initialize_model_parallel(tensor_model_parallel_size,
ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1045, in initialize_model_parallel
ERROR 12-17 09:40:09 worker_base.py:467]     _TP = init_model_parallel_group(group_ranks,
ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 876, in init_model_parallel_group
ERROR 12-17 09:40:09 worker_base.py:467]     return GroupCoordinator(
ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 233, in __init__
ERROR 12-17 09:40:09 worker_base.py:467]     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 43, in __init__
ERROR 12-17 09:40:09 worker_base.py:467]     local_rank = global_rank % local_world_size
ERROR 12-17 09:40:09 worker_base.py:467] ZeroDivisionError: integer division or modulo by zero
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ext_hzchen_google_com/miniconda3/envs/tpu/bin/vllm", line 33, in <module>
[rank0]:     sys.exit(load_entry_point('vllm', 'console_scripts', 'vllm')())
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/scripts.py", line 201, in main
[rank0]:     args.dispatch_function(args)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/scripts.py", line 42, in serve
[rank0]:     uvloop.run(run_server(args))
[rank0]:   File "/home/ext_hzchen_google_com/miniconda3/envs/tpu/lib/python3.10/site-packages/uvloop/__init__.py", line 82, in run
[rank0]:     return loop.run_until_complete(wrapper())
[rank0]:   File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
[rank0]:   File "/home/ext_hzchen_google_com/miniconda3/envs/tpu/lib/python3.10/site-packages/uvloop/__init__.py", line 61, in wrapper
[rank0]:     return await main
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/entrypoints/openai/api_server.py", line 667, in run_server
[rank0]:     async with build_async_engine_client(args) as engine_client:
[rank0]:   File "/home/ext_hzchen_google_com/miniconda3/envs/tpu/lib/python3.10/contextlib.py", line 199, in __aenter__
[rank0]:     return await anext(self.gen)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/entrypoints/openai/api_server.py", line 117, in build_async_engine_client
[rank0]:     async with build_async_engine_client_from_engine_args(
[rank0]:   File "/home/ext_hzchen_google_com/miniconda3/envs/tpu/lib/python3.10/contextlib.py", line 199, in __aenter__
[rank0]:     return await anext(self.gen)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/entrypoints/openai/api_server.py", line 150, in build_async_engine_client_from_engine_args
[rank0]:     engine_client = build_engine()
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/engine/async_llm_engine.py", line 707, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/engine/async_llm_engine.py", line 594, in __init__
[rank0]:     self.engine = self._engine_class(*args, **kwargs)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/engine/async_llm_engine.py", line 267, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/engine/llm_engine.py", line 288, in __init__
[rank0]:     self.model_executor = executor_class(vllm_config=vllm_config, )
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/executor/ray_tpu_executor.py", line 306, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/executor/ray_tpu_executor.py", line 39, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/executor/executor_base.py", line 36, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/executor/ray_tpu_executor.py", line 51, in _init_executor
[rank0]:     self._init_workers_ray(placement_group)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/executor/ray_tpu_executor.py", line 184, in _init_workers_ray
[rank0]:     self._run_workers("init_device")
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/executor/ray_tpu_executor.py", line 249, in _run_workers
[rank0]:     driver_worker_output = self.driver_worker.execute_method(
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/worker/worker_base.py", line 468, in execute_method
[rank0]:     raise e
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/worker/worker_base.py", line 459, in execute_method
[rank0]:     return executor(*args, **kwargs)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/worker/tpu_worker.py", line 67, in init_device
[rank0]:     ensure_model_parallel_initialized(
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1101, in ensure_model_parallel_initialized
[rank0]:     initialize_model_parallel(tensor_model_parallel_size,
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1045, in initialize_model_parallel
[rank0]:     _TP = init_model_parallel_group(group_ranks,
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 876, in init_model_parallel_group
[rank0]:     return GroupCoordinator(
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 233, in __init__
[rank0]:     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
[rank0]:   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 43, in __init__
[rank0]:     local_rank = global_rank % local_world_size
[rank0]: ZeroDivisionError: integer division or modulo by zero
(RayWorkerWrapper pid=140502, ip=10.164.0.40) WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. [repeated 30x across cluster]
(RayWorkerWrapper pid=144474, ip=10.164.0.36) INFO 12-17 09:40:09 tpu.py:27] Cannot use _Backend.FLASH_ATTN backend on TPU. [repeated 30x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=144474, ip=10.164.0.36) INFO 12-17 09:40:09 selector.py:163] Using Pallas backend. [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467] Error executing method init_device. This might cause deadlock in distributed execution. [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467] Traceback (most recent call last): [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/worker/worker_base.py", line 459, in execute_method [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]     return executor(*args, **kwargs) [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/worker/tpu_worker.py", line 67, in init_device [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]     ensure_model_parallel_initialized( [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1101, in ensure_model_parallel_initialized [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]     initialize_model_parallel(tensor_model_parallel_size, [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 1045, in initialize_model_parallel [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]     _TP = init_model_parallel_group(group_ranks, [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 876, in init_model_parallel_group [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]     return GroupCoordinator( [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/parallel_state.py", line 233, in __init__ [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]     self.tpu_communicator = TpuCommunicator(group=self.cpu_group) [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]   File "/home/ext_hzchen_google_com/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 43, in __init__ [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467]     local_rank = global_rank % local_world_size [repeated 30x across cluster]
(RayWorkerWrapper pid=149225, ip=10.164.0.38) ERROR 12-17 09:40:09 worker_base.py:467] ZeroDivisionError: integer division or modulo by zero [repeated 30x across cluster]

In addition, serving starts without error only when I use --tensor-parallel-size 32 by itself (no pipeline parallelism), which may have a performance impact.
I would like to know whether this behavior is working as intended.
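
For reference, the variant that does come up on this cluster looks like the following (assuming the other flags stay the same as above; only the parallelism flags differ):

vllm serve mistralai/Pixtral-Large-Instruct-2411 --config-format mistral --load-format mistral --tokenizer-mode mistral --num-scheduler-steps 2 --swap-space 4 --max-model-len=1024 --limit_mm_per_prompt 'image=10' --tensor-parallel-size 32 --disable-log-requests --dtype=bfloat16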

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@ruisearch42
Collaborator

Hi @totorochina, I don't have an environment to conveniently repro and debug this, but it looks like the issue is caused by a ZeroDivisionError in `local_rank = global_rank % local_world_size`.

And `local_world_size = global_world_size // num_nodes`.

What are the corresponding values of global_world_size, num_nodes, and local_world_size? Do they make sense?
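
In case it helps, here is a minimal sketch of the arithmetic I suspect is failing, assuming the group handed to TpuCommunicator is the TP group (size 4) while num_nodes ends up counting the 8 hosts of the Ray deployment. These values are guesses for illustration, not taken from actual debug output:

```python
# Hypothetical values for this setup (assumptions): TP=4, PP=8 on a v5e-32 with 8 hosts.
global_world_size = 4   # size of the cpu_group passed to TpuCommunicator (tensor-parallel group)
num_nodes = 8           # hosts counted for the Ray deployment
local_world_size = global_world_size // num_nodes   # 4 // 8 == 0
global_rank = 0
local_rank = global_rank % local_world_size          # ZeroDivisionError: integer modulo by zero
```

If global_world_size really is smaller than num_nodes here, that would explain why the TP-only configuration (32 // 8 == 4) works while TP=4 with PP=8 does not.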
