-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TFX image classification example (#23456)
* TFX image classification example * TFX image classification with serving model * Add TF Model Wrapper * Clean up code * Refactoring * Add a test for tfx_bsl RunInference * fixup lint * Refactoring * Fixup lint * Add brief summary for the example * Apply suggestions from code review Co-authored-by: Andy Ye <[email protected]> * Refactoring code and add comments * Update help description * reorder imports * Reorder imports again * Add docstring * Refactoring * Add pillow to tfx pipeline requirements * Move inferencePostCommitIT to Python 3.9 suite * Uncomment other postcommit suites Co-authored-by: Andy Ye <[email protected]>
- Loading branch information
1 parent
30b2617
commit 1685251
Showing
8 changed files
with
509 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
16 changes: 16 additions & 0 deletions
16
sdks/python/apache_beam/examples/inference/tfx_bsl/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
147 changes: 147 additions & 0 deletions
147
sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# Intended only for internal testing. | ||
|
||
from typing import Dict | ||
from typing import Optional | ||
|
||
import tensorflow as tf | ||
|
||
|
||
class TFModelWrapperWithSignature(tf.keras.Model): | ||
""" | ||
Helper class used to wrap a based tf.keras.Model object with a serving | ||
signature that can passed to the tfx_bsl RunInference transform. | ||
A TF model saved using this helper class expects inputs as | ||
images serialized to tf.string using tf.io.parse_tensor | ||
and then passing serialized images to the RunInference transform | ||
in the tf.train.Example. More about tf.train.Example at | ||
https://www.tensorflow.org/api_docs/python/tf/train/Example | ||
Usage: | ||
Step 1: | ||
# Save the base TF model with modified signature . | ||
signature_model = TFModelWrapperWithSignature( | ||
model=model, | ||
preprocess_input=preprocess_input, | ||
input_dtype=input_dtype, | ||
feature_description=feature_description, | ||
**kwargs | ||
) | ||
tf.saved_model.save(signature_model, path) | ||
Step 2: | ||
# Load the saved_model in the beam pipeline to create ModelHandler. | ||
saved_model_spec = model_spec_pb2.SavedModelSpec( | ||
model_path=known_args.model_path) | ||
inferece_spec_type = model_spec_pb2.InferenceSpecType( | ||
saved_model_spec=saved_model_spec) | ||
model_handler = CreateModelHandler(inferece_spec_type) | ||
""" | ||
def __init__( | ||
self, | ||
model, | ||
preprocess_input=None, | ||
input_dtype=None, | ||
feature_description=None, | ||
**kwargs): | ||
""" | ||
model: model: Base tensorflow model used for TFX-BSL RunInference transform. | ||
preprocess_input: Preprocess method to be included as part of the | ||
model's serving signature. | ||
input_dtype: tf dtype of the inputs passed to the model. | ||
For eg: tf.int32, tf.uint8. | ||
feature_description: Feature spec to parse inputs from tf.train.Example | ||
using tf.parse_example(). For more details, please take a look at | ||
https://www.tensorflow.org/api_docs/python/tf/io/parse_example | ||
If there are extra arguments(for eg: training=False) that should be | ||
passed to the base tf model during inference, please pass them in kwargs. | ||
""" | ||
super().__init__() | ||
self.model = model | ||
self.preprocess_input = preprocess_input | ||
self.input_dtype = input_dtype | ||
self.feature_description = feature_description | ||
if not feature_description: | ||
self.feature_description = {'image': tf.io.FixedLenFeature((), tf.string)} | ||
self._kwargs = kwargs | ||
|
||
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)]) | ||
def call(self, serialized_examples): | ||
features = tf.io.parse_example( | ||
serialized_examples, features=self.feature_description) | ||
|
||
# Initialize a TensorArray to store the deserialized values. | ||
# For more details, please look at | ||
# https://github.com/tensorflow/tensorflow/issues/39323#issuecomment-627586602 | ||
num_batches = len(features['image']) | ||
deserialized_vectors = tf.TensorArray( | ||
self.input_dtype, size=num_batches, dynamic_size=True) | ||
# Vectorized version of tf.io.parse_tensor is not available. | ||
# Use for loop to vectorize the tensor. For more details, refer | ||
# https://github.com/tensorflow/tensorflow/issues/43706 | ||
for i in range(num_batches): | ||
deserialized_value = tf.io.parse_tensor( | ||
features['image'][i], out_type=self.input_dtype) | ||
# In Graph mode, return value must get assigned in order to | ||
# update the array. More details at | ||
# http://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873 | ||
deserialized_vectors = deserialized_vectors.write(i, deserialized_value) | ||
deserialized_tensor = deserialized_vectors.stack() | ||
if self.preprocess_input: | ||
deserialized_tensor = self.preprocess_input(deserialized_tensor) | ||
return self.model(deserialized_tensor, **self._kwargs) | ||
|
||
|
||
def save_tf_model_with_signature( | ||
path_to_save_model, | ||
model=None, | ||
preprocess_input=None, | ||
input_dtype=tf.float32, | ||
feature_description: Optional[Dict] = None, | ||
**kwargs, | ||
): | ||
""" | ||
Helper function used to save the Tensorflow Model with a serving signature. | ||
This is intended only for internal testing. | ||
Args: | ||
path_to_save_model: Path to save the model with modified signature. | ||
model: model: Base tensorflow model used for TFX-BSL RunInference transform. | ||
preprocess_input: Preprocess method to be included as part of the | ||
model's serving signature. | ||
input_dtype: tf dtype of the inputs passed to the model. | ||
For eg: tf.int32, tf.uint8. | ||
feature_description: Feature spec to parse inputs from tf.train.Example using | ||
tf.parse_example(). For more details, please take a look at | ||
https://www.tensorflow.org/api_docs/python/tf/io/parse_example | ||
If there are extra arguments(for eg: training=False) that should be passed to | ||
the base tf model during inference, please pass them in kwargs. | ||
""" | ||
if not model: | ||
model = tf.keras.applications.MobileNetV2(weights='imagenet') | ||
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input | ||
signature_model = TFModelWrapperWithSignature( | ||
model=model, | ||
preprocess_input=preprocess_input, | ||
input_dtype=input_dtype, | ||
feature_description=feature_description, | ||
**kwargs) | ||
tf.saved_model.save(signature_model, path_to_save_model) |
19 changes: 19 additions & 0 deletions
19
sdks/python/apache_beam/examples/inference/tfx_bsl/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License.s | ||
# | ||
|
||
tfx_bsl | ||
pillow |
194 changes: 194 additions & 0 deletions
194
sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
""" | ||
A sample pipeline illustrating how to use Apache Beam RunInference | ||
with TFX_BSL CreateModelHandler API. For more details, please look at | ||
https://www.tensorflow.org/tfx/tfx_bsl/api_docs/python/tfx_bsl/public/beam/run_inference/CreateModelHandler. | ||
Note: A Tensorflow Model needs to be updated with a @tf.function | ||
signature in order to accept bytes as inputs, and should have logic | ||
to decode bytes to Tensors that is acceptable by the TensorFlow model. | ||
Please take a look at TFModelWrapperWithSignature class in | ||
build_tensorflow_model.py on how to modify TF Model's signature | ||
and the logic to decode the image tensor. | ||
""" | ||
|
||
import argparse | ||
import io | ||
import logging | ||
import os | ||
from typing import Iterable | ||
from typing import Iterator | ||
from typing import Optional | ||
from typing import Tuple | ||
|
||
import apache_beam as beam | ||
import tensorflow as tf | ||
from apache_beam.io.filesystems import FileSystems | ||
from apache_beam.ml.inference.base import KeyedModelHandler | ||
from apache_beam.ml.inference.base import RunInference | ||
from apache_beam.options.pipeline_options import PipelineOptions | ||
from apache_beam.options.pipeline_options import SetupOptions | ||
from apache_beam.runners.runner import PipelineResult | ||
from PIL import Image | ||
from tfx_bsl.public.beam.run_inference import CreateModelHandler | ||
from tfx_bsl.public.beam.run_inference import prediction_log_pb2 | ||
from tfx_bsl.public.proto import model_spec_pb2 | ||
|
||
_IMG_SIZE = (224, 224) | ||
|
||
|
||
def filter_empty_lines(text: str) -> Iterator[str]: | ||
if len(text.strip()) > 0: | ||
yield text | ||
|
||
|
||
def read_and_process_image( | ||
image_file_name: str, | ||
path_to_dir: Optional[str] = None) -> Tuple[str, tf.Tensor]: | ||
if path_to_dir is not None: | ||
image_file_name = os.path.join(path_to_dir, image_file_name) | ||
with FileSystems().open(image_file_name, 'r') as file: | ||
data = Image.open(io.BytesIO(file.read())).convert('RGB') | ||
# Note: Converts the image dtype from uint8 to float32 | ||
# https://www.tensorflow.org/api_docs/python/tf/image/resize | ||
image = tf.keras.preprocessing.image.img_to_array(data) | ||
image = tf.image.resize(image, _IMG_SIZE) | ||
return image_file_name, image | ||
|
||
|
||
def convert_image_to_example_proto(tensor: tf.Tensor) -> tf.train.Example: | ||
""" | ||
This method performs the following: | ||
1. Accepts the tensor as input | ||
2. Serializes the tensor into bytes and pass it through | ||
tf.train.Feature | ||
3. Pass the serialized tensor feature using tf.train.Example | ||
Proto to the RunInference transform. | ||
Args: | ||
tensor: A TF tensor. | ||
Returns: | ||
example_proto: A tf.train.Example containing serialized tensor. | ||
""" | ||
serialized_non_scalar = tf.io.serialize_tensor(tensor) | ||
feature_of_bytes = tf.train.Feature( | ||
bytes_list=tf.train.BytesList(value=[serialized_non_scalar.numpy()])) | ||
features_for_example = {'image': feature_of_bytes} | ||
example_proto = tf.train.Example( | ||
features=tf.train.Features(feature=features_for_example)) | ||
return example_proto | ||
|
||
|
||
class ProcessInferenceToString(beam.DoFn): | ||
def process( | ||
self, element: Tuple[str, | ||
prediction_log_pb2.PredictionLog]) -> Iterable[str]: | ||
""" | ||
Args: | ||
element: Tuple of str, and PredictionLog. Inference can be parsed | ||
from prediction_log | ||
returns: | ||
str of filename and inference. | ||
""" | ||
filename, predict_log = element[0], element[1].predict_log | ||
output_value = predict_log.response.outputs | ||
output_tensor = ( | ||
tf.io.decode_raw( | ||
output_value['output_0'].tensor_content, out_type=tf.float32)) | ||
max_index_output_tensor = tf.math.argmax(output_tensor, axis=0) | ||
yield filename + ',' + str(tf.get_static_value(max_index_output_tensor)) | ||
|
||
|
||
def parse_known_args(argv): | ||
"""Parses args for the workflow.""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
'--input', | ||
dest='input', | ||
required=True, | ||
help='Path to the text file containing image names.') | ||
parser.add_argument( | ||
'--output', | ||
dest='output', | ||
required=True, | ||
help='Path to save output predictions text file.') | ||
parser.add_argument( | ||
'--model_path', | ||
dest='model_path', | ||
required=True, | ||
help="Path to the model.") | ||
parser.add_argument( | ||
'--images_dir', | ||
default=None, | ||
help='Path to the directory where images are stored.' | ||
'Not required if image names in the input file have absolute path.') | ||
return parser.parse_known_args(argv) | ||
|
||
|
||
def run(argv=None, save_main_session=True, pipeline=None) -> PipelineResult: | ||
""" | ||
Args: | ||
argv: Command line arguments defined for this example. | ||
save_main_session: Used for internal testing. | ||
test_pipeline: Used for internal testing. | ||
""" | ||
known_args, pipeline_args = parse_known_args(argv) | ||
pipeline_options = PipelineOptions(pipeline_args) | ||
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session | ||
|
||
saved_model_spec = model_spec_pb2.SavedModelSpec( | ||
model_path=known_args.model_path) | ||
inferece_spec_type = model_spec_pb2.InferenceSpecType( | ||
saved_model_spec=saved_model_spec) | ||
model_handler = CreateModelHandler(inferece_spec_type) | ||
# create a KeyedModelHandler to accommodate image names as keys. | ||
keyed_model_handler = KeyedModelHandler(model_handler) | ||
|
||
if not pipeline: | ||
pipeline = beam.Pipeline(options=pipeline_options) | ||
|
||
filename_value_pair = ( | ||
pipeline | ||
| 'ReadImageNames' >> beam.io.ReadFromText(known_args.input) | ||
| 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines) | ||
| 'ProcessImageData' >> beam.Map( | ||
lambda image_name: read_and_process_image( | ||
image_file_name=image_name, path_to_dir=known_args.images_dir))) | ||
|
||
predictions = ( | ||
filename_value_pair | ||
| 'ConvertToExampleProto' >> | ||
beam.Map(lambda x: (x[0], convert_image_to_example_proto(x[1]))) | ||
| 'TFXRunInference' >> RunInference(keyed_model_handler) | ||
| 'PostProcess' >> beam.ParDo(ProcessInferenceToString())) | ||
_ = ( | ||
predictions | ||
| "WriteOutputToGCS" >> beam.io.WriteToText( | ||
known_args.output, | ||
shard_name_template='', | ||
append_trailing_newlines=True)) | ||
|
||
result = pipeline.run() | ||
result.wait_until_finish() | ||
return result | ||
|
||
|
||
if __name__ == '__main__': | ||
logging.getLogger().setLevel(logging.INFO) | ||
run() |
Oops, something went wrong.