multi instances of infer #744

Draft · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling.py
@@ -139,6 +139,7 @@ def to(self, device: str):
         if isinstance(device, str):
             self._device = device.upper()
             self.request = None
+            self.request_dict.clear()
         else:
             logger.debug(f"device must be of type {str} but got {type(device)} instead")

8 changes: 6 additions & 2 deletions optimum/intel/openvino/modeling_base.py
@@ -88,7 +88,9 @@ def __init__(
         self.output_names = output_names

         self.model = model
+        self.compiled_model = None
         self.request = None
+        self.request_dict = {}
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

         self._openvino_config = None
@@ -457,11 +459,11 @@ def compile(self):
                 cache_dir = Path(self.model_save_dir).joinpath("model_cache")
                 ov_config["CACHE_DIR"] = str(cache_dir)
                 logger.info(f"Setting OpenVINO CACHE_DIR to {str(cache_dir)}")
-            self.request = core.compile_model(self.model, self._device, ov_config)
+            self.compiled_model = core.compile_model(self.model, self._device, ov_config)
Collaborator:
compiled_model does make more sense, but we have a lot of existing code that expects the request property here, including integrations in other products, so if this is going to be changed there has to be a deprecation period with a DeprecationWarning for a while.
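One way to provide that transition period (a minimal sketch, not part of this PR; the class and attribute names are illustrative) is to keep `request` as a property that forwards to the new `compiled_model` attribute and emits a `DeprecationWarning` on access:

```python
import warnings


class OVBaseModelSketch:
    """Illustrative only: keeps the old `request` name working while `compiled_model` becomes canonical."""

    def __init__(self):
        self.compiled_model = None  # new canonical attribute
        self.request_dict = {}

    @property
    def request(self):
        warnings.warn(
            "`request` is deprecated and will be removed in a future release; use `compiled_model` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.compiled_model

    @request.setter
    def request(self, value):
        # Existing code that assigns to `request` keeps working during the transition.
        self.compiled_model = value
```

With both spellings resolving to the same object, downstream integrations see a warning instead of an AttributeError while they migrate.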

             # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html
             if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2:
                 logger.info(f"{self._device} SUPPORTED_PROPERTIES:")
-                _print_compiled_model_properties(self.request)
+                _print_compiled_model_properties(self.compiled_model)

     def _reshape(
         self,
@@ -500,6 +502,7 @@ def reshape(self, batch_size: int, sequence_length: int, height: int = None, wid
         self.is_dynamic = True if batch_size == -1 and sequence_length == -1 else False
         self.model = self._reshape(self.model, batch_size, sequence_length, height, width)
         self.request = None
+        self.request_dict.clear()
         return self

     def half(self):
@@ -509,6 +512,7 @@ def half(self):
         apply_moc_transformations(self.model, cf=False)
         compress_model_transformation(self.model)
         self.request = None
+        self.request_dict.clear()
         return self

     def eval(self):
33 changes: 26 additions & 7 deletions optimum/intel/openvino/modeling_decoder.py
@@ -210,6 +210,7 @@ def update_pkv_precision(self, force_fp32=False):
             if self.is_dynamic:
                 self.model = self._reshape(self.model, -1, -1)
             self.request = None
+            self.request_dict.clear()

     def _save_pretrained(self, save_directory: Union[str, Path]):
         """
@@ -345,7 +346,7 @@ def normalized_config(self):
     def compile(self):
         if self.request is None:
             super().compile()
-            self.request = self.request.create_infer_request()
+            self.request = self.compiled_model.create_infer_request()

     def _make_stateful(self):
         patch_stateful(self.config, self.model)
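The per-thread dispatch introduced by this PR builds on the fact that a single OpenVINO `CompiledModel` can create several independent `InferRequest` objects, each with its own state, which may run concurrently. A minimal plain-OpenVINO sketch of that idea (the IR path and input shape below are illustrative):

```python
import threading

import numpy as np
import openvino as ov

core = ov.Core()
compiled = core.compile_model("model.xml", "CPU")  # illustrative IR path


def worker(batch: np.ndarray) -> None:
    # One InferRequest per thread: requests created from the same CompiledModel
    # hold separate state and can be scheduled in parallel.
    request = compiled.create_infer_request()
    request.infer({0: batch})
    print(threading.get_ident(), request.get_output_tensor(0).data.shape)


batch = np.random.rand(1, 3, 224, 224).astype(np.float32)
threads = [threading.Thread(target=worker, args=(batch,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

In the PR itself, `compile()` still keeps one default `request` for the single-threaded path, while `forward()` lazily fills `request_dict` with one request per `tid`.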
@@ -424,9 +425,14 @@ def prepare_inputs(
         else:
             # past_key_values are not used explicitly, instead they are handled inside the model
             if past_key_values is None:
+                infer_req = self.request
+                if 'kwargs' in kwargs.keys():
+                    tid = kwargs['kwargs']['tid']
+                    if tid in self.request_dict:
+                        infer_req = self.request_dict[tid]
                 # This is the first iteration in a sequence, reset all states
-                if self.request is not None:
-                    self.request.reset_state()
+                if infer_req is not None:
+                    infer_req.reset_state()
                 # Set initial value for the next beam_idx input that will be used at the current iteration
                 # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                 self.next_beam_idx = np.arange(batch_size, dtype=int)
@@ -473,6 +479,17 @@ def forward(
     ) -> CausalLMOutputWithPast:
         self.compile()

+        if 'kwargs' in kwargs.keys():
+            tid = kwargs['kwargs']['tid']
+            if tid in self.request_dict:
+                infer_req = self.request_dict[tid]
+            else:
+                infer_req = self.compiled_model.create_infer_request()
+                self.request_dict[tid] = infer_req
+        else:
+            tid = -1
+            infer_req = self.request
+
         inputs = self.prepare_inputs(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -484,9 +501,11 @@
         if self._first_iter_beam_search:
             inputs, duplication_indices = self._deduplicate_inputs(inputs)
         # Run inference
-        self.request.start_async(inputs, share_inputs=True)
-        self.request.wait()
-        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
+        print(f'.... {tid} infer start ....\n')
+        infer_req.start_async(inputs, share_inputs=True)
+        infer_req.wait()
+        print(f'..... {tid} infer end .....\n')
+        logits = torch.from_numpy(infer_req.get_tensor("logits").data).to(self.device)
         if self.stateful:
             # Need a marker to differentiate the first generate iteration from the others in
             # the first condition at the function beginning above.
@@ -497,7 +516,7 @@
         if not self.stateful:
             if self.use_cache:
                 # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
-                past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
+                past_key_values = tuple(infer_req.get_tensor(key).data for key in self.key_value_output_names)
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
                     self.config.model_type == "falcon" and self.config.new_decoder_architecture
                 ):
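For reference, a sketch of how the per-thread dispatch above appears intended to be used: every thread passes its own `tid` through the extra `kwargs` argument, so `forward` looks up (or creates) a dedicated infer request in `request_dict` instead of sharing `self.request`. This assumes the PR branch; the model ID and prompts are illustrative, and the calls go through `forward` directly rather than `generate`:

```python
import threading

from transformers import AutoTokenizer

from optimum.intel.openvino import OVModelForCausalLM

model_id = "gpt2"  # illustrative; any causal LM exported to OpenVINO works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
model.compile()  # builds compiled_model and the default shared request


def run(prompt: str) -> None:
    # threading.get_ident() keys the per-thread InferRequest cached in model.request_dict.
    tid = threading.get_ident()
    encoded = tokenizer(prompt, return_tensors="pt")
    outputs = model(**encoded, kwargs={"tid": tid})
    print(tid, outputs.logits.shape)


threads = [threading.Thread(target=run, args=(p,)) for p in ("Hello there", "OpenVINO is")]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Whether the surrounding pipeline (tokenization, sampling, state handling) is itself thread-safe is outside the scope of this diff, so treat this as an illustration of the `tid` plumbing rather than a production pattern.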