[BUG] Using DeepSpeed gives slower inference time #6818

williamlin0518 opened this issue Dec 4, 2024 · 0 comments
Labels: bug (Something isn't working), inference

Describe the bug
Using DeepSpeed inference with tensor parallelism and ZeRO optimization for the NVLM model shows slower performance than the HuggingFace baseline implementation.

Also, no matter whether tp_size is set to 2, 4, or 6, the DeepSpeed inference time stays the same.

I ask the same 6 questions (I have 6 A100 GPUs) and confirm the answers match, but the DeepSpeed run is slower, even though it can answer the 6 questions in parallel (launched with deepspeed --num_gpus=6 main.py).
Left: baseline; right: DeepSpeed.
[Image: inference-time comparison]

DeepSpeed's GPU utilization is very high (averaging near 80%) and it uses more memory (figure below).
[Image: DeepSpeed GPU utilization and memory usage]

Baseline GPU utilization and memory usage:
[Image: baseline GPU utilization and memory usage]
To Reproduce

  1. Run the identical inference workload (6 questions) on both implementations.

  2. Use tensor parallelism and ZeRO optimization in the DeepSpeed version (deepspeed --num_gpus=6 main.py).

  3. Use the default HuggingFace implementation as the baseline.

Expected behavior
DeepSpeed with tensor parallelism and ZeRO optimization should show faster inference times than the HuggingFace baseline.

DeepSpeed code

from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, AutoModel
from transformers.integrations import HfDeepSpeedConfig
import deepspeed
import os
import torch
import time
import torch.distributed as dist
from collections import defaultdict

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To avoid warnings about parallelism in tokenizers
# distributed setup
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()

model_name = "nvidia/NVLM-D-72B"

ds_config = {
    # "replace_with_kernel_inject": True
    "bf16": {
        "enabled": True
    },
    "tensor_parallel": {
        "tp_size": 6
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "steps_per_print": 2000,
    # "train_batch_size": train_batch_size,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False
}
# fmt: on
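
# NOTE: the "tensor_parallel"/"tp_size" entry above is how I request tensor parallelism here;
# as described in the bug report, setting tp_size to 2, 4 or 6 gives the same inference time,
# so I am not sure deepspeed.initialize actually honours this key.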

# The next line instructs transformers to partition the model directly over multiple GPUs using
# deepspeed.zero.Init when the model's `from_pretrained` method is called.
#
# **It has to run before the model is loaded with AutoModel.from_pretrained(model_name).**
#
# Otherwise the model is first loaded normally and only partitioned at forward time, which is
# less efficient and may fail when there is little CPU RAM.
dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

# now a model can be loaded.

model = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval()

# initialise Deepspeed ZeRO and store only the engine object
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()  # inference

prompts = [
    "Is this review positive or negative? Review: The customer service was absolutely terrible and I'll never shop here again",
    "Is this review positive or negative? Review: Great product, exactly what I needed and arrived on time",
    "Is this review positive or negative? Review: Don't waste your money, it broke after two uses",
    "Is this review positive or negative? Review: Amazing quality and the price can't be beat, highly recommend",
    "Is this review positive or negative? Review: Mediocre at best, there are better options out there",
    "Is this review positive or negative? Review: Beyond disappointed with this purchase, complete garbage"
]
rank = torch.distributed.get_rank()
text_in = prompts[rank]

tokenizer = AutoTokenizer.from_pretrained(model_name)



# Synchronize before starting timing
torch.distributed.barrier()
start_time = time.time()

response, _ = ds_engine.module.chat(
    tokenizer,
    None,
    text_in,
    {"max_new_tokens": 1024},
    history=None,
    return_history=True
)
if rank == 0:
    end_time = time.time()
    total_time = end_time - start_time
# Gather responses from all ranks
all_responses = [None] * world_size
dist.all_gather_object(all_responses, response)

# Calculate on rank 0
if rank == 0:    
    # Calculate total characters generated
    total_chars = sum(len(resp) for resp in all_responses)
    throughput = total_chars / total_time
    
    print(f"\nTotal characters generated: {total_chars}")
    print(f"Total time taken: {total_time:.2f} seconds")
    print(f"Throughput: {throughput:.2f} characters/second")
    
    # Print individual responses
    for i, resp in enumerate(all_responses):
        print(f"\nGPU {i} response ({len(resp)} chars):\n{resp}")

Baseline code template

import torch
from transformers import AutoTokenizer, AutoModel
import math
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode


def split_model():
    device_map = {}
    world_size = torch.cuda.device_count()
    num_layers = 80
    # Since the first GPU will be used for ViT, treat it as half a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    device_map['language_model.model.rotary_emb'] = 0
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

    return device_map

path = "nvidia/NVLM-D-72B"
device_map = split_model()
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=False,
    trust_remote_code=True,
    device_map=device_map).eval()

print(model)

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=1024, do_sample=False)

# pure-text conversation
question = 'Hello, who are you?'
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')
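
The template above only shows a single pure-text turn; for the timing comparison in the screenshots, the baseline answers the same 6 review prompts as the DeepSpeed script, one after another. A rough sketch of that timing loop, continuing from the model, tokenizer and generation_config defined above (the prompt list is the one from the DeepSpeed script):

import time

prompts = [
    "Is this review positive or negative? Review: The customer service was absolutely terrible and I'll never shop here again",
    # ... the other five review prompts from the DeepSpeed script ...
]

start_time = time.time()
responses = []
for question in prompts:
    # Same chat() call as in the model card example above, one prompt at a time.
    response, _ = model.chat(tokenizer, None, question, generation_config,
                             history=None, return_history=True)
    responses.append(response)
total_time = time.time() - start_time

total_chars = sum(len(r) for r in responses)
print(f"Total characters generated: {total_chars}")
print(f"Total time taken: {total_time:.2f} seconds")
print(f"Throughput: {total_chars / total_time:.2f} characters/second")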

System info (please complete the following information):

  • OS: Ubuntu 22.04.3 LTS
  • GPU count and types: one node with 6x A100 GPUs