Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[yaml] add RunInference support with VertexAI #33406

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions sdks/python/apache_beam/yaml/yaml_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ def underlying_handler(self):
@staticmethod
def default_preprocess_fn():
raise ValueError(
'Handler does not implement a default preprocess '
'Model Handler does not implement a default preprocess '
'method. Please define a preprocessing method using the '
'\'preprocess\' tag.')
'\'preprocess\' tag. This is required in most cases because '
'most models will have a different input shape, so the model '
'cannot generalize how the input Row should be transformed. For '
'an example preprocess method, see VertexAIModelHandlerJSONProvider')

def _preprocess_fn_internal(self):
return lambda row: (row, self._preprocess_fn(row))
Expand Down Expand Up @@ -134,17 +137,34 @@ def __init__(
project: str,
location: str,
preprocess: Dict[str, str],
postprocess: Optional[Dict[str, str]] = None,
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None,
env_vars: Optional[Dict[str, Any]] = None,
postprocess: Optional[Dict[str, str]] = None):
env_vars: Optional[Dict[str, Any]] = None):
"""
ModelHandler for Vertex AI.

This Model Handler can be used with RunInference to load a model hosted
on VertexAI. Every model that is hosted on VertexAI should have three
distinct, required, parameters - `endpoint_id`, `project` and `location`.
These parameters tell the Model Handler how to access the model's endpoint
so that input data can be sent using an API request, and inferences can be
received as a response.

This Model Handler also required a `preprocess` function to be defined.
Polber marked this conversation as resolved.
Show resolved Hide resolved
Preprocessing and Postprocessing are described in more detail in the
RunInference docs:
https://beam.apache.org/releases/yamldoc/current/#runinference

Every model will have a unique input, but all requests should be
JSON-formatted. For example, most language models such as Llama and Gemma
expect a JSON with the key "prompt" (among other optional keys). In Python,
JSON can be expressed as a dictionary.

For example: ::

- type: RunInference
Expand All @@ -159,10 +179,24 @@ def __init__(
preprocess:
callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'

In the above example, which mimics a call to a Llama 3 model hosted on
VertexAI, the preprocess function (in this case a lambda) takes in a Beam
Row with a single field, "prompt", and maps it to a dict with the same
field. It also specifies an optional parameter, "max_tokens", that tells the
model the allowed token size (in this case input + output token size).

Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to query.
project: the GCP project name where the endpoint is deployed.
location: the GCP location where the endpoint is deployed.
preprocess: A python callable, defined either inline, or using a file,
that is invoked on the input row before sending to the model to be
loaded by this ModelHandler. This parameter is required by the
`VertexAIModelHandlerJSON` ModelHandler.
postprocess: A python callable, defined either inline, or using a file,
that is invoked on the PredictionResult output by the ModelHandler
before parsing into the output Beam Row under the field name defined
by the inference_tag.
experiment: Experiment label to apply to the
queries. See
https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
Expand All @@ -183,14 +217,6 @@ def __init__(
max_batch_duration_secs: The maximum amount of time to buffer
a batch before emitting; used in streaming contexts.
env_vars: Environment variables.
preprocess: A python callable, defined either inline, or using a file,
that is invoked on the input row before sending to the model to be
loaded by this ModelHandler. This parameter is required by the
`VertexAIModelHandlerJSON` ModelHandler.
postprocess: A python callable, defined either inline, or using a file,
that is invoked on the PredictionResult output by the ModelHandler
before parsing into the output Beam Row under the field name defined
by the inference_tag.
"""

try:
Expand Down Expand Up @@ -222,10 +248,6 @@ def inference_output_type(self):
return RowTypeConstraint.from_fields([('example', Any), ('inference', Any),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we ask the handler for the type of inference? (Presumably the example type is that of the input as well.) Some handlers may not be able to provide this, but some can.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't generally today, but we could for a very limited set of handlers. The ones I can think of with predictable output types are hugging face pipelines and vLLM; the rest are all dependent on the model.

I think this is probably worth doing when we can, it probably requires some slight modification to the PredictionResult type, though, and might be worth scoping into a follow on PR.

('model_id', Optional[str])])

@staticmethod
def default_postprocess_fn():
return lambda x: beam.Row(**x._asdict())


@beam.ptransform.ptransform_fn
def run_inference(
Expand Down Expand Up @@ -392,7 +414,7 @@ def fn(x: PredictionResult):
'inference'.
inference_args: Extra arguments for models whose inference call requires
extra parameters. Make sure to check the underlying ModelHandler docs to
see which args are allowed.
see which args are allowed.

"""

Expand All @@ -414,16 +436,15 @@ def fn(x: PredictionResult):
typ = model_handler['type']
model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
if model_handler_provider and issubclass(model_handler_provider,
ModelHandlerProvider):
type(ModelHandlerProvider)):
model_handler_provider.validate(model_handler['config'])
else:
raise NotImplementedError(f'Unknown model handler type: {typ}.')

model_handler_provider = ModelHandlerProvider.create_handler(model_handler)
user_type = RowTypeConstraint.from_user_type(pcoll.element_type.user_type)
schema = RowTypeConstraint.from_fields(
list(
RowTypeConstraint.from_user_type(
pcoll.element_type.user_type)._fields) +
list(user_type._fields if user_type else []) +
[(inference_tag, model_handler_provider.inference_output_type())])

return (
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/yaml/yaml_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

import logging
import unittest

import yaml
Expand Down
Loading