Skip to content

Commit

Permalink
TFX image classification example (#23456)
Browse files Browse the repository at this point in the history
* TFX image classification example

* TFX image classification with serving model

* Add TF Model Wrapper

* Clean up code

* Refactoring

* Add a test for tfx_bsl RunInference

* fixup lint

* Refactoring

* Fixup lint

* Add brief summary for the example

* Apply suggestions from code review

Co-authored-by: Andy Ye <[email protected]>

* Refactoring code and add comments

* Update help description

* reorder imports

* Reorder imports again

* Add docstring

* Refactoring

* Add pillow to tfx pipeline requirements

* Move inferencePostCommitIT to Python 3.9 suite

* Uncomment other postcommit suites

Co-authored-by: Andy Ye <[email protected]>
  • Loading branch information
AnandInguva and yeandy authored Nov 4, 2022
1 parent 30b2617 commit 1685251
Show file tree
Hide file tree
Showing 8 changed files with 509 additions and 2 deletions.
4 changes: 3 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,16 @@ tasks.register("python39PostCommit") {
dependsOn(":sdks:python:test-suites:direct:py39:postCommitIT")
dependsOn(":sdks:python:test-suites:direct:py39:hdfsIntegrationTest")
dependsOn(":sdks:python:test-suites:portable:py39:postCommitPy39")
// TODO (https://github.com/apache/beam/issues/23966)
// Move this to Python 3.10 test suite once tfx-bsl has python 3.10 wheel.
dependsOn(":sdks:python:test-suites:direct:py39:inferencePostCommitIT")
}

// Aggregate task: running it triggers every Python 3.10 post-commit suite
// listed as a dependency below.
tasks.register("python310PostCommit") {
  dependsOn(":sdks:python:test-suites:dataflow:py310:postCommitIT")
  dependsOn(":sdks:python:test-suites:direct:py310:postCommitIT")
  dependsOn(":sdks:python:test-suites:direct:py310:hdfsIntegrationTest")
  dependsOn(":sdks:python:test-suites:portable:py310:postCommitPy310")
  dependsOn(":sdks:python:test-suites:direct:py310:inferencePostCommitIT")
}

task("python37SickbayPostCommit") {
Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/examples/inference/tfx_bsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Intended only for internal testing.

from typing import Dict
from typing import Optional

import tensorflow as tf


class TFModelWrapperWithSignature(tf.keras.Model):
  """Wraps a base tf.keras.Model with a serving signature for tfx_bsl.

  The wrapped model can be passed to the tfx_bsl RunInference transform.
  A TF model saved using this helper class expects each input image to be
  serialized to tf.string using tf.io.serialize_tensor and then passed to
  the RunInference transform inside a tf.train.Example. More about
  tf.train.Example at
  https://www.tensorflow.org/api_docs/python/tf/train/Example

  Usage:
    Step 1:
    # Save the base TF model with the modified signature.
    signature_model = TFModelWrapperWithSignature(
        model=model,
        preprocess_input=preprocess_input,
        input_dtype=input_dtype,
        feature_description=feature_description,
        **kwargs
    )
    tf.saved_model.save(signature_model, path)

    Step 2:
    # Load the saved_model in the beam pipeline to create ModelHandler.
    saved_model_spec = model_spec_pb2.SavedModelSpec(
        model_path=known_args.model_path)
    inference_spec_type = model_spec_pb2.InferenceSpecType(
        saved_model_spec=saved_model_spec)
    model_handler = CreateModelHandler(inference_spec_type)
  """
  def __init__(
      self,
      model,
      preprocess_input=None,
      input_dtype=None,
      feature_description=None,
      **kwargs):
    """Stores the base model and the settings used to decode its inputs.

    Args:
      model: Base tensorflow model used for TFX-BSL RunInference transform.
      preprocess_input: Preprocess method to be included as part of the
        model's serving signature.
      input_dtype: tf dtype of the inputs passed to the model.
        For eg: tf.int32, tf.uint8. Falls back to tf.float32 when None.
      feature_description: Feature spec to parse inputs from tf.train.Example
        using tf.io.parse_example(). For more details, please take a look at
        https://www.tensorflow.org/api_docs/python/tf/io/parse_example
        When empty or None, a single scalar tf.string feature named 'image'
        is used.
      **kwargs: Extra arguments (for eg: training=False) that should be
        passed to the base tf model during inference.
    """
    super().__init__()
    self.model = model
    self.preprocess_input = preprocess_input
    # `call` hands this dtype to tf.TensorArray and tf.io.parse_tensor, which
    # cannot work with None. Use tf.float32, the same default that
    # save_tf_model_with_signature applies, when the caller gives no dtype.
    self.input_dtype = tf.float32 if input_dtype is None else input_dtype
    self.feature_description = feature_description
    if not feature_description:
      self.feature_description = {'image': tf.io.FixedLenFeature((), tf.string)}
    self._kwargs = kwargs

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def call(self, serialized_examples):
    """Decodes serialized tf.train.Examples and runs the base model on them.

    Args:
      serialized_examples: 1-D tf.string tensor of serialized
        tf.train.Example protos (see the input_signature above).

    Returns:
      The output of the base model for the decoded (and, if configured,
      preprocessed) batch.
    """
    features = tf.io.parse_example(
        serialized_examples, features=self.feature_description)

    # NOTE: the 'image' feature is read regardless of the configured
    # feature_description, so a custom description must define that key.
    # Initialize a TensorArray to store the deserialized values.
    # For more details, please look at
    # https://github.com/tensorflow/tensorflow/issues/39323#issuecomment-627586602
    num_batches = len(features['image'])
    deserialized_vectors = tf.TensorArray(
        self.input_dtype, size=num_batches, dynamic_size=True)
    # Vectorized version of tf.io.parse_tensor is not available.
    # Use for loop to vectorize the tensor. For more details, refer
    # https://github.com/tensorflow/tensorflow/issues/43706
    for i in range(num_batches):
      deserialized_value = tf.io.parse_tensor(
          features['image'][i], out_type=self.input_dtype)
      # In Graph mode, return value must get assigned in order to
      # update the array. More details at
      # http://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873
      deserialized_vectors = deserialized_vectors.write(i, deserialized_value)
    deserialized_tensor = deserialized_vectors.stack()
    if self.preprocess_input:
      deserialized_tensor = self.preprocess_input(deserialized_tensor)
    return self.model(deserialized_tensor, **self._kwargs)


def save_tf_model_with_signature(
    path_to_save_model,
    model=None,
    preprocess_input=None,
    input_dtype=tf.float32,
    feature_description: Optional[Dict] = None,
    **kwargs,
):
  """Saves a TensorFlow model wrapped with a serving signature.

  This is intended only for internal testing.

  Args:
    path_to_save_model: Path to save the model with modified signature.
    model: Base tensorflow model used for TFX-BSL RunInference transform.
      When None, an ImageNet-pretrained MobileNetV2 is used and
      `preprocess_input` is replaced by the MobileNetV2 preprocessing
      function, even if the caller supplied one.
    preprocess_input: Preprocess method to be included as part of the
      model's serving signature.
    input_dtype: tf dtype of the inputs passed to the model.
      For eg: tf.int32, tf.uint8.
    feature_description: Feature spec to parse inputs from tf.train.Example
      using tf.io.parse_example(). For more details, please take a look at
      https://www.tensorflow.org/api_docs/python/tf/io/parse_example
    **kwargs: Extra arguments (for eg: training=False) that should be passed
      to the base tf model during inference.
  """
  # Compare against the None sentinel explicitly instead of relying on the
  # truthiness of a model object.
  if model is None:
    model = tf.keras.applications.MobileNetV2(weights='imagenet')
    preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
  signature_model = TFModelWrapperWithSignature(
      model=model,
      preprocess_input=preprocess_input,
      input_dtype=input_dtype,
      feature_description=feature_description,
      **kwargs)
  tf.saved_model.save(signature_model, path_to_save_model)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

tfx_bsl
pillow
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A sample pipeline illustrating how to use Apache Beam RunInference
with TFX_BSL CreateModelHandler API. For more details, please look at
https://www.tensorflow.org/tfx/tfx_bsl/api_docs/python/tfx_bsl/public/beam/run_inference/CreateModelHandler.
Note: A TensorFlow model needs to be updated with a @tf.function
signature in order to accept bytes as inputs, and should have logic
to decode the bytes into tensors that are accepted by the TensorFlow model.
Please take a look at TFModelWrapperWithSignature class in
build_tensorflow_model.py on how to modify TF Model's signature
and the logic to decode the image tensor.
"""

import argparse
import io
import logging
import os
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Tuple

import apache_beam as beam
import tensorflow as tf
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from PIL import Image
from tfx_bsl.public.beam.run_inference import CreateModelHandler
from tfx_bsl.public.beam.run_inference import prediction_log_pb2
from tfx_bsl.public.proto import model_spec_pb2

# Size that every input image is resized to (via tf.image.resize) before it
# is serialized and sent to the model.
_IMG_SIZE = (224, 224)


def filter_empty_lines(text: str) -> Iterator[str]:
  """Yields `text` unchanged unless it is empty or whitespace-only.

  Written as a generator so it can be applied with beam.ParDo: a blank line
  produces no output element, any other line is emitted exactly as received.

  Args:
    text: A single line read from the input file.

  Yields:
    The original `text` when it contains at least one non-whitespace
    character.
  """
  if text.strip():
    yield text


def read_and_process_image(
    image_file_name: str,
    path_to_dir: Optional[str] = None) -> Tuple[str, tf.Tensor]:
  """Reads an image file and converts it to a resized image tensor.

  Args:
    image_file_name: Name or path of the image file to read.
    path_to_dir: Optional directory that is joined in front of
      `image_file_name` when given.

  Returns:
    A tuple of the (possibly joined) image path and the image converted to
    RGB and resized to _IMG_SIZE.
  """
  if path_to_dir is not None:
    image_file_name = os.path.join(path_to_dir, image_file_name)
  # FileSystems.open is a static method whose second positional parameter is
  # a mime type, not a file mode, so only the path is passed here.
  with FileSystems.open(image_file_name) as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
  # Note: Converts the image dtype from uint8 to float32
  # https://www.tensorflow.org/api_docs/python/tf/image/resize
  image = tf.keras.preprocessing.image.img_to_array(data)
  image = tf.image.resize(image, _IMG_SIZE)
  return image_file_name, image


def convert_image_to_example_proto(tensor: tf.Tensor) -> tf.train.Example:
  """Packs a tensor into a tf.train.Example under the 'image' feature.

  The tensor is serialized to bytes with tf.io.serialize_tensor, the bytes
  are stored in a single-element bytes_list feature, and that feature is
  placed in a tf.train.Example that can be handed to the RunInference
  transform.

  Args:
    tensor: A TF tensor.

  Returns:
    A tf.train.Example containing the serialized tensor.
  """
  serialized_bytes = tf.io.serialize_tensor(tensor).numpy()
  image_feature = tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[serialized_bytes]))
  return tf.train.Example(
      features=tf.train.Features(feature={'image': image_feature}))


class ProcessInferenceToString(beam.DoFn):
  """Formats a keyed PredictionLog as a 'filename,index' text line."""
  def process(
      self, element: Tuple[str,
                           prediction_log_pb2.PredictionLog]) -> Iterable[str]:
    """Extracts the arg-max of the model output for one keyed prediction.

    Args:
      element: Tuple of str, and PredictionLog. The inference is read from
        the 'output_0' entry of the PredictionLog's predict response.

    Yields:
      The filename and the index of the largest float32 output value
      (presumably the predicted class index), joined by a comma.
    """
    filename = element[0]
    response_outputs = element[1].predict_log.response.outputs
    # The output tensor arrives as raw bytes; reinterpret them as float32.
    raw_scores = response_outputs['output_0'].tensor_content
    scores = tf.io.decode_raw(raw_scores, out_type=tf.float32)
    top_index = tf.get_static_value(tf.math.argmax(scores, axis=0))
    yield filename + ',' + str(top_index)


def parse_known_args(argv):
  """Parses the command line arguments defined for this workflow.

  Args:
    argv: List of command line arguments to parse.

  Returns:
    A (known_args, pipeline_args) tuple: the namespace holding the
    arguments declared here, and the list of remaining arguments that were
    not recognized.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input',
      dest='input',
      required=True,
      help='Path to the text file containing image names.')
  parser.add_argument(
      '--output',
      dest='output',
      required=True,
      help='Path to save output predictions text file.')
  parser.add_argument(
      '--model_path',
      dest='model_path',
      required=True,
      help="Path to the model.")
  parser.add_argument(
      '--images_dir',
      default=None,
      # Adjacent literals are concatenated verbatim, so the separating space
      # has to be part of the first literal.
      help='Path to the directory where images are stored. '
      'Not required if image names in the input file have absolute path.')
  return parser.parse_known_args(argv)


def run(argv=None, save_main_session=True, pipeline=None) -> PipelineResult:
  """Builds and runs the TFX-BSL image classification pipeline.

  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    pipeline: Used for internal testing. When None, a new beam.Pipeline is
      created from the parsed pipeline options; otherwise the given pipeline
      is used as is.

  Returns:
    The PipelineResult of the pipeline, returned after waiting for it to
    finish.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  saved_model_spec = model_spec_pb2.SavedModelSpec(
      model_path=known_args.model_path)
  inference_spec_type = model_spec_pb2.InferenceSpecType(
      saved_model_spec=saved_model_spec)
  model_handler = CreateModelHandler(inference_spec_type)
  # create a KeyedModelHandler to accommodate image names as keys.
  keyed_model_handler = KeyedModelHandler(model_handler)

  # Compare against the None sentinel explicitly instead of relying on the
  # truthiness of a pipeline object.
  if pipeline is None:
    pipeline = beam.Pipeline(options=pipeline_options)

  filename_value_pair = (
      pipeline
      | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
      | 'ProcessImageData' >> beam.Map(
          lambda image_name: read_and_process_image(
              image_file_name=image_name, path_to_dir=known_args.images_dir)))

  predictions = (
      filename_value_pair
      | 'ConvertToExampleProto' >>
      beam.Map(lambda x: (x[0], convert_image_to_example_proto(x[1])))
      | 'TFXRunInference' >> RunInference(keyed_model_handler)
      | 'PostProcess' >> beam.ParDo(ProcessInferenceToString()))
  # An empty shard_name_template writes the predictions to a single file
  # named exactly known_args.output.
  _ = (
      predictions
      | "WriteOutputToGCS" >> beam.io.WriteToText(
          known_args.output,
          shard_name_template='',
          append_trailing_newlines=True))

  result = pipeline.run()
  result.wait_until_finish()
  return result


# Script entry point: raise the root logger to INFO and run the pipeline
# with the arguments taken from the command line.
if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  run()
Loading

0 comments on commit 1685251

Please sign in to comment.