Skip to content

Commit

Permalink
Allow multiple inference engines in single script (#4384)
Browse files Browse the repository at this point in the history
* add destroy method to InferenceEngine

Co-authored-by: Jeff Rasley <[email protected]>
  • Loading branch information
mrwyattii and jeffra authored Sep 22, 2023
1 parent 0e0748c commit 4c35880
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from ..module_inject.replace_policy import generic_policies
from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor
from ..ops.transformer.inference.ds_attention import DeepSpeedSelfAttention
from ..model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference

DS_INFERENCE_ENABLED = False
from torch import nn
Expand All @@ -50,6 +52,13 @@ def __init__(self, model, config):

super().__init__()

# Have to import here because inference_module is a global, but python
# globals only work at the module level and will not be updated unless
# we import it each time we init a new inference engine.
from ..model_implementations.transformers.ds_transformer import inference_module
if inference_module is not None:
self.destroy()

self.module = model
self._config = config

Expand Down Expand Up @@ -174,6 +183,17 @@ def __init__(self, model, config):
# Check if local CUDA graphs can be created in replacement modules
self.local_cuda_graph = self._local_cuda_graph_used(self.module)

def destroy(self):
    """Release the inference kernel workspace and reset global layer state.

    Allows a new ``InferenceEngine`` to be constructed in the same process:
    resets the class-level layer counters used during module injection and
    frees the C++/CUDA workspace, then clears the module-level
    ``inference_module`` handle so the next engine re-initializes it.
    """
    # Import the *module object* rather than the name: `from ... import
    # inference_module` would bind a function-local copy, and rebinding that
    # local to None would leave the module-level global pointing at the
    # already-released workspace (so a later __init__ could call
    # release_workspace() twice). Mutating the module attribute updates the
    # one true global.
    from ..model_implementations.transformers import ds_transformer
    DeepSpeedTransformerInference.layer_id = 0
    DeepSpeedSelfAttention.num_layers = 0
    if ds_transformer.inference_module is not None:
        ds_transformer.inference_module.release_workspace()
        ds_transformer.inference_module = None

def profile_model_time(self, use_cuda_events=True):
if not self.model_profile_enabled and not self._config.enable_cuda_graph:
self.module.register_forward_pre_hook(self._pre_forward_hook)
Expand Down

0 comments on commit 4c35880

Please sign in to comment.