[torch.compile] add dynamo time tracking (vllm-project#11005)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Dec 9, 2024
1 parent af7c4a9 commit d1c2e15
Showing 3 changed files with 16 additions and 5 deletions.
vllm/compilation/backends.py (6 additions, 0 deletions)
@@ -265,7 +265,13 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        # when dynamo calls the backend, it means the bytecode
+        # transform and analysis are done
         compilation_counter.num_graphs_seen += 1
+        from .monitor import torch_compile_start_time
+        dynamo_time = time.time() - torch_compile_start_time
+        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
+        self.compilation_configs.compilation_time += dynamo_time
 
         # we control the compilation process, each instance can only be
         # called once
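The key point in this hunk: Dynamo invokes the backend only after its bytecode transform and analysis are finished, so reading the clock at the top of __call__ yields the time Dynamo itself spent. A minimal standalone sketch of the same technique (the names below are illustrative stand-ins, not vLLM's):

import time

import torch

compile_start_time: float = 0.0  # plays the role of torch_compile_start_time

def timing_backend(graph_module: torch.fx.GraphModule, example_inputs):
    # Dynamo calls the backend only once the bytecode transform and
    # analysis are done, so the elapsed time is the Dynamo overhead
    print(f"Dynamo bytecode transform time: "
          f"{time.time() - compile_start_time:.2f} s")
    # return the captured graph unchanged; a real backend compiles it
    return graph_module.forward

def f(x):
    return x * 2 + 1

compile_start_time = time.time()
compiled_f = torch.compile(f, backend=timing_backend)
compiled_f(torch.randn(4))  # first call triggers Dynamo, then the backend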
vllm/compilation/decorators.py (3 additions, 3 deletions)
@@ -145,6 +145,7 @@ def _support_torch_compile(
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        self.vllm_config = vllm_config
         # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
         # will handle the compilation, so we don't need to do anything here.
         self.do_not_compile = \
@@ -157,9 +158,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         TorchCompileWrapperWithCustomDispatcher.__init__(
             self, compilation_level=vllm_config.compilation_config.level)
 
-        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
-            start_monitoring_torch_compile(vllm_config.compilation_config)
-
     cls.__init__ = __init__
 
     def __call__(self, *args, **kwargs):
@@ -186,6 +184,8 @@ def __call__(self, *args, **kwargs):
                     raise ValueError(
                         "Unsupported dynamic dimensions"
                         f" {dims} for argument {k} with type {type(arg)}.")
+            # here, it is the starting point of the `torch.compile` process
+            start_monitoring_torch_compile(self.vllm_config.compilation_config)
 
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
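This change also moves start_monitoring_torch_compile out of __init__ and into the first __call__, so the gap between constructing the model and actually compiling it is no longer billed to torch.compile. A hedged sketch of that deferred-timer pattern (the wrapper class is a simplification for illustration, not vLLM's):

import time

import torch

torch_compile_start_time: float = 0.0

def start_monitoring() -> None:
    # stamp the moment compilation actually begins
    global torch_compile_start_time
    torch_compile_start_time = time.time()

class LazyCompiledWrapper:
    def __init__(self, fn):
        # no timer here: __init__ may run long before the first forward
        self.compiled = torch.compile(fn)

    def __call__(self, *args, **kwargs):
        # the first call is the real entry point of torch.compile,
        # so the clock starts here rather than in __init__
        if torch_compile_start_time == 0.0:
            start_monitoring()
        return self.compiled(*args, **kwargs)

wrapper = LazyCompiledWrapper(lambda x: x + 1)
wrapper(torch.ones(2))  # starts the timer, then compilation kicks in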
vllm/compilation/monitor.py (7 additions, 2 deletions)
@@ -1,14 +1,19 @@
+import time
+
 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
+torch_compile_start_time: float = 0.0
+
 
 def start_monitoring_torch_compile(compilation_config: CompilationConfig):
-    pass
+    global torch_compile_start_time
+    torch_compile_start_time = time.time()
 
 
 def end_monitoring_torch_compile(compilation_config: CompilationConfig):
     if compilation_config.level == CompilationLevel.PIECEWISE:
-        logger.info("graph compilation takes %.2f s in total",
+        logger.info("torch.compile takes %.2f s in total",
                     compilation_config.compilation_time)
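Taken together, the three files form one accounting loop: start_monitoring_torch_compile stamps a module-global start time, the backend in backends.py adds the measured Dynamo time to compilation_config.compilation_time, and end_monitoring_torch_compile logs the total. A hedged usage sketch (constructing CompilationConfig directly like this is an assumption for illustration):

from vllm.compilation.monitor import (end_monitoring_torch_compile,
                                      start_monitoring_torch_compile)
from vllm.config import CompilationConfig, CompilationLevel

config = CompilationConfig(level=CompilationLevel.PIECEWISE)

start_monitoring_torch_compile(config)  # stamps torch_compile_start_time
# ... the first compiled forward pass runs here; the backend computes
# time.time() - torch_compile_start_time and accumulates it into
# config.compilation_time ...
end_monitoring_torch_compile(config)  # logs "torch.compile takes ... s in total"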