Vertex AI Model Handler Private Endpoint Support (#27696)
* Vertex AI Model Handler Private Endpoint Support

* linting

* Trailing whitespace

* import order

* Extra context on experiments

* Trailing whitespace
jrmccluskey authored Jul 31, 2023
1 parent 881338e commit e9c81de
Showing 3 changed files with 73 additions and 24 deletions.
@@ -74,11 +74,24 @@ def parse_known_args(argv):
type=str,
required=True,
help='GCP location for the Endpoint')
parser.add_argument(
'--endpoint_network',
dest='vpc_network',
type=str,
required=False,
help='GCP network the endpoint is peered to')
parser.add_argument(
'--experiment',
dest='experiment',
type=str,
required=False,
help='GCP experiment to pass to init')
help='Vertex AI experiment label to apply to queries')
parser.add_argument(
'--private',
dest='private',
action='store_true',
help="True if the Vertex AI endpoint is a private endpoint")
return parser.parse_known_args(argv)
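
For illustration, a minimal sketch of how the new flags might be passed to this parser; the --endpoint and --project flag names (not visible in this hunk) and all values are assumptions:

# Hypothetical invocation; the endpoint, project, and network values are
# placeholders, and --endpoint/--project are assumed flag names.
known_args, pipeline_args = parse_known_args([
    '--endpoint', '123456789',
    '--project', 'my-project',
    '--location', 'us-central1',
    '--endpoint_network', 'projects/12345/global/networks/myVPC',
    '--private',
])
assert known_args.vpc_network == 'projects/12345/global/networks/myVPC'
assert known_args.private  # set by the store_true action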


@@ -130,7 +143,9 @@ def run(
endpoint_id=known_args.endpoint,
project=known_args.project,
location=known_args.location,
experiment=known_args.experiment)
experiment=known_args.experiment,
network=known_args.vpc_network,
private=known_args.private)

pipeline = test_pipeline
if not test_pipeline:
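
For context, a hedged sketch of what run() wires together: constructing the handler against a private endpoint and handing it to RunInference. The endpoint ID, project, and network are placeholders, and the input payload shape depends entirely on the deployed model:

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

# Placeholder identifiers; constructing the handler performs a liveness
# check, so these must point at a live private endpoint deployment.
model_handler = VertexAIModelHandlerJSON(
    endpoint_id='123456789',
    project='my-project',
    location='us-central1',
    network='projects/12345/global/networks/myVPC',
    private=True)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([[1.0, 2.0, 3.0]])  # payload shape depends on the model
      | RunInference(model_handler))
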
60 changes: 41 additions & 19 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -23,7 +23,7 @@
from typing import Optional
from typing import Sequence

from google.api_core.exceptions import ClientError
from google.api_core.exceptions import ServerError
from google.api_core.exceptions import TooManyRequests
from google.cloud import aiplatform

@@ -41,20 +41,21 @@
# pylint: disable=line-too-long


def _retry_on_gcp_client_error(exception):
def _retry_on_appropriate_gcp_error(exception):
"""
Retry filter that returns True if a returned HTTP error code is 4xx. This is
used to retry remote requests that fail, most notably 429 (TooManyRequests.)
This is used for GCP-specific client errors.
Retry filter that returns True if a returned HTTP error code is 5xx or 429.
This is used to retry remote requests that fail, most notably 429
(TooManyRequests.)
Args:
exception: the returned exception encountered during the request/response
loop.
Returns:
boolean indication whether or not the exception is a GCP ClientError.
boolean indication whether or not the exception is a Server Error (5xx) or
a TooManyRequests (429) error.
"""
return isinstance(exception, ClientError)
return isinstance(exception, (TooManyRequests, ServerError))
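
As a quick sanity sketch of the renamed filter's behavior (InternalServerError is one of the ServerError subclasses in google.api_core):

from google.api_core.exceptions import InternalServerError
from google.api_core.exceptions import TooManyRequests

# 429s and 5xx server errors pass the filter and get retried; other
# exceptions, including 4xx client errors, do not.
assert _retry_on_appropriate_gcp_error(TooManyRequests('rate limited'))
assert _retry_on_appropriate_gcp_error(InternalServerError('backend error'))
assert not _retry_on_appropriate_gcp_error(ValueError('not an HTTP error'))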


class VertexAIModelHandlerJSON(ModelHandler[Any,
@@ -67,6 +68,7 @@ def __init__(
location: str,
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
**kwargs):
"""Implementation of the ModelHandler interface for Vertex AI.
**NOTE:** This API and its implementation are under development and
@@ -76,21 +78,33 @@
Vertex AI endpoint. In that way it functions more like a mid-pipeline
IO. Public Vertex AI endpoints have a maximum request size of 1.5 MB.
If you wish to make larger requests and use a private endpoint, provide
the Compute Engine network you wish to use.
the Compute Engine network you wish to use and set `private=True`
Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to query
project: the GCP project name where the endpoint is deployed
location: the GCP location where the endpoint is deployed
experiment: optional. experiment label to apply to the
queries
queries. See
https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
for more information.
network: optional. the full name of the Compute Engine
network the endpoint is deployed on; used for private
endpoints only.
endpoints. The network or subnetwork Dataflow pipeline
option must be set and match this network for pipeline
execution.
Ex: "projects/12345/global/networks/myVPC"
private: optional. if the deployed Vertex AI endpoint is
private, set to true. Requires a network to be provided
as well.
"""

self._env_vars = kwargs.get('env_vars', {})

if private and network is None:
raise ValueError(
"A VPC network must be provided to use a private endpoint.")

# TODO: support the full list of options for aiplatform.init()
# See https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform#google_cloud_aiplatform_init
aiplatform.init(
@@ -102,7 +116,9 @@
# Check for liveness here but don't try to actually store the endpoint
# in the class yet
self.endpoint_name = endpoint_id
_ = self._retrieve_endpoint(self.endpoint_name)
self.is_private = private

_ = self._retrieve_endpoint(self.endpoint_name, self.is_private)

# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
@@ -113,18 +129,27 @@
self.throttler = AdaptiveThrottler(
window_ms=1, bucket_ms=1, overload_ratio=2)

def _retrieve_endpoint(self, endpoint_id: str) -> aiplatform.Endpoint:
def _retrieve_endpoint(
self, endpoint_id: str, is_private: bool) -> aiplatform.Endpoint:
"""Retrieves an AI Platform endpoint and queries it for liveness/deployed
models.
Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
is_private: a boolean indicating if the Vertex AI endpoint is a private
endpoint
Returns:
An aiplatform.Endpoint object
Raises:
ValueError: if endpoint is inactive or has no models deployed to it.
"""
endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
if is_private:
endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
endpoint_name=endpoint_id)
LOGGER.debug("Treating endpoint %s as private", endpoint_id)
else:
endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
LOGGER.debug("Treating endpoint %s as public", endpoint_id)

try:
mod_list = endpoint.list_models()
@@ -133,7 +158,7 @@ def _retrieve_endpoint(self, endpoint_id: str) -> aiplatform.Endpoint:
"Failed to contact endpoint %s, got exception: %s", endpoint_id, e)

if len(mod_list) == 0:
raise ValueError("Endpoint %s has no models deployed to it.")
raise ValueError("Endpoint %s has no models deployed to it.", endpoint_id)

return endpoint
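
In isolation, the public-endpoint liveness check reduces to roughly the following sketch; the endpoint ID is a placeholder and aiplatform.init() is assumed to have been called already:

from google.cloud import aiplatform

# list_models() raises if the endpoint is unreachable; an empty list
# means the endpoint is live but has nothing deployed to it.
endpoint = aiplatform.Endpoint(endpoint_name='123456789')
if len(endpoint.list_models()) == 0:
  raise ValueError('no models deployed')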

@@ -143,11 +168,11 @@ def load_model(self) -> aiplatform.Endpoint:
"""
# Check to make sure the endpoint is still active since pipeline
# construction time
ep = self._retrieve_endpoint(self.endpoint_name)
ep = self._retrieve_endpoint(self.endpoint_name, self.is_private)
return ep

@retry.with_exponential_backoff(
num_retries=5, retry_filter=_retry_on_gcp_client_error)
num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def get_request(
self,
batch: Sequence[Any],
@@ -170,9 +195,6 @@ def get_request(
except TooManyRequests as e:
LOGGER.warning("request was limited by the service with code %i", e.code)
raise
except ClientError as e:
LOGGER.warning("request failed with error code %i", e.code)
raise
except Exception as e:
LOGGER.error("unexpected exception raised as part of request, got %s", e)
raise
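
The same decorator pattern works for any standalone callable; a sketch in which flaky_predict is a hypothetical helper:

from apache_beam.utils import retry

# Only exceptions admitted by the filter (429s and 5xx) trigger the
# exponential backoff, up to five retries; everything else propagates.
@retry.with_exponential_backoff(
    num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def flaky_predict(endpoint, instances):
  return endpoint.predict(instances=instances)
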
18 changes: 15 additions & 3 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
@@ -19,7 +19,8 @@
import unittest

try:
from apache_beam.ml.inference.vertex_ai_inference import _retry_on_gcp_client_error
from apache_beam.ml.inference.vertex_ai_inference import _retry_on_appropriate_gcp_error
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from google.api_core.exceptions import TooManyRequests
except ImportError:
raise unittest.SkipTest('VertexAI dependencies are not installed')
@@ -28,11 +29,22 @@
class RetryOnClientErrorTest(unittest.TestCase):
def test_retry_on_client_error_positive(self):
e = TooManyRequests(message="fake service rate limiting")
self.assertTrue(_retry_on_gcp_client_error(e))
self.assertTrue(_retry_on_appropriate_gcp_error(e))

def test_retry_on_client_error_negative(self):
e = ValueError()
self.assertFalse(_retry_on_gcp_client_error(e))
self.assertFalse(_retry_on_appropriate_gcp_error(e))


class ModelHandlerArgConditions(unittest.TestCase):
def test_exception_on_private_without_network(self):
self.assertRaises(
ValueError,
VertexAIModelHandlerJSON,
endpoint_id="1",
project="testproject",
location="us-central1",
private=True)


if __name__ == '__main__':