Skip to content

Commit

Permalink
Address reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chamikaramj committed Oct 24, 2022
1 parent 89fe0ea commit 1dde786
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public static ExpansionServiceSchemaTransformProvider of() {
return transformProvider;
}

/**
* Currently {@link PCollectionRowTuple} is a Java only concept. We may get a different input type
* when other SDKs use a Java schema-aware transform in a pipeline, hence we use this transform to
* convert input/output types.
*/
static class RowTransform extends PTransform {

private PTransform<PCollectionRowTuple, PCollectionRowTuple> rowTuplePTransform;
Expand Down Expand Up @@ -118,6 +123,11 @@ public POutput expand(PInput input) {
return PDone.in(input.getPipeline());
}
}

@Override
public String getName() {
return "RowTransform_of_" + this.rowTuplePTransform.getName();
}
}

@Override
Expand Down
99 changes: 36 additions & 63 deletions sdks/python/apache_beam/transforms/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import functools
import glob
import logging
import re
import threading
from collections import OrderedDict
from collections import namedtuple
from typing import Dict

import grpc
Expand Down Expand Up @@ -105,25 +105,10 @@ def payload(self):
"""
return self.build().SerializeToString()

def _get_schema_proto_and_payload(self, ignored_arg_format, *args, **kwargs):
def _get_schema_proto_and_payload(self, **kwargs):
named_fields = []
fields_to_values = OrderedDict()
next_field_id = 0
if args and not ignored_arg_format:
raise ValueError(
'If args are provided an ignored_arg_format must'
'also be provided')
for value in args:
if value is None:
raise ValueError(
'Received value None. None values are currently not supported')
named_fields.append(
((JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT % next_field_id),
convert_to_typing_type(instance_to_type(value))))
fields_to_values[(
JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT %
next_field_id)] = value
next_field_id += 1

for key, value in kwargs.items():
if not key:
raise ValueError('Parameter name cannot be empty')
Expand Down Expand Up @@ -200,8 +185,7 @@ def __init__(self, identifier, **kwargs):
self._kwargs = kwargs

def build(self):
schema_proto, payload = self._get_schema_proto_and_payload(
ignored_arg_format=None, **self._kwargs)
schema_proto, payload = self._get_schema_proto_and_payload(**self._kwargs)
payload = external_transforms_pb2.SchemaTransformPayload(
identifier=self._identifier,
configuration_schema=schema_proto,
Expand Down Expand Up @@ -230,13 +214,26 @@ def __init__(self, class_name):
self._constructor_param_kwargs = None
self._builder_methods_and_params = OrderedDict()

def _args_to_named_fields(self, args):
next_field_id = 0
named_fields = OrderedDict()
for value in args:
if value is None:
raise ValueError(
'Received value None. None values are currently not supported')
named_fields[(
JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT %
next_field_id)] = value
next_field_id += 1
return named_fields

def build(self):
constructor_param_args = self._constructor_param_args or []
constructor_param_kwargs = self._constructor_param_kwargs or {}
all_constructor_param_kwargs = self._args_to_named_fields(
self._constructor_param_args)
if self._constructor_param_kwargs:
all_constructor_param_kwargs.update(self._constructor_param_kwargs)
constructor_schema, constructor_payload = (
self._get_schema_proto_and_payload(
JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT,
*constructor_param_args, **constructor_param_kwargs))
self._get_schema_proto_and_payload(**all_constructor_param_kwargs))
payload = external_transforms_pb2.JavaClassLookupPayload(
class_name=self._class_name,
constructor_schema=constructor_schema,
Expand All @@ -246,10 +243,12 @@ def build(self):

for builder_method_name, params in self._builder_methods_and_params.items():
builder_method_args, builder_method_kwargs = params
all_builder_method_kwargs = self._args_to_named_fields(
builder_method_args)
if builder_method_kwargs:
all_builder_method_kwargs.update(builder_method_kwargs)
builder_method_schema, builder_method_payload = (
self._get_schema_proto_and_payload(
JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT,
*builder_method_args, **builder_method_kwargs))
self._get_schema_proto_and_payload(**all_builder_method_kwargs))
builder_method = external_transforms_pb2.BuilderMethod(
name=builder_method_name,
schema=builder_method_schema,
Expand Down Expand Up @@ -311,31 +310,10 @@ def _has_constructor(self):
self._constructor_param_kwargs)


class SchemaTransformsConfig(object):
"""
Information regarding a SchemaTransform available in an external SDK.
"""
def __init__(self, identifier, schema, named_inputs, named_outputs):
self._identifier = identifier
self._configuration_schema = schema
self._named_inputs = named_inputs
self._named_outputs = named_outputs

@property
def identifier(self):
return self._identifier

@property
def configuration_schema(self):
return self._configuration_schema

@property
def named_inputs(self):
return self._named_inputs

@property
def named_outputs(self):
return self._named_outputs
# Information regarding a SchemaTransform available in an external SDK.
SchemaTransformsConfig = namedtuple(
'SchemaTransformsConfig',
['identifier', 'configuration_schema', 'inputs', 'outputs'])


class SchemaAwareExternalTransform(ptransform.PTransform):
Expand Down Expand Up @@ -374,31 +352,26 @@ def expand(self, pcolls):
self._expansion_service)

@staticmethod
def discover(expansion_service, regex=None):
def discover(expansion_service):
"""Discover all SchemaTransforms available to the given expansion service.
:return: a list of SchemaTransformsConfigs that represent the discovered
SchemaTransforms.
"""

matcher = re.compile(regex) if regex else None

with ExternalTransform.service(expansion_service) as service:
discover_response = service.DiscoverSchemaTransform(
beam_expansion_api_pb2.DiscoverSchemaTransformRequest())

for identifier in discover_response.schema_transform_configs:
if matcher and not matcher.match(identifier):
continue

proto_config = discover_response.schema_transform_configs[identifier]
schema = named_tuple_from_schema(proto_config.config_schema)

yield SchemaTransformsConfig(
identifier,
schema,
proto_config.input_pcollection_names,
proto_config.output_pcollection_names)
identifier=identifier,
configuration_schema=schema,
inputs=proto_config.input_pcollection_names,
outputs=proto_config.output_pcollection_names)


class JavaExternalTransform(ptransform.PTransform):
Expand Down

0 comments on commit 1dde786

Please sign in to comment.