Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YAML] Transform schema introspection. #28478

Merged
merged 3 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from typing import ByteString
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
Expand Down Expand Up @@ -308,6 +309,11 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
return schema_pb2.FieldType(
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str):
element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))

try:
logical_type = LogicalType.from_typing(type_)
except ValueError:
Expand Down
92 changes: 91 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import collections
import hashlib
import inspect
import json
import os
import subprocess
Expand All @@ -45,8 +46,12 @@
from apache_beam.transforms import external
from apache_beam.transforms import window
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.schemas import named_tuple_to_schema
from apache_beam.typehints.schemas import typing_from_runner_api
from apache_beam.typehints.schemas import typing_to_runner_api
from apache_beam.utils import python_callable
from apache_beam.utils import subprocess_server
from apache_beam.version import __version__ as beam_version
Expand All @@ -65,6 +70,9 @@ def provided_transforms(self) -> Iterable[str]:
"""Returns a list of transform type names this provider can handle."""
raise NotImplementedError(type(self))

def config_schema(self, type):
return None

def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool:
"""Returns whether this transform requires inputs.

Expand Down Expand Up @@ -140,6 +148,8 @@ def provided_transforms(self):
return self._urns.keys()

def schema_transforms(self):
if callable(self._service):
self._service = self._service()
if self._schema_transforms is None:
try:
self._schema_transforms = {
Expand All @@ -152,6 +162,11 @@ def schema_transforms(self):
self._schema_transforms = {}
return self._schema_transforms

def config_schema(self, type):
if self._urns[type] in self.schema_transforms():
return named_tuple_to_schema(
self.schema_transforms()[self._urns[type]].configuration_schema)

def requires_inputs(self, typ, args):
if self._urns[type] in self.schema_transforms():
return bool(self.schema_transforms()[self._urns[type]].inputs)
Expand Down Expand Up @@ -392,6 +407,31 @@ def cache_artifacts(self):
def provided_transforms(self):
return self._transform_factories.keys()

def config_schema(self, typ):
factory = self._transform_factories[typ]
if isinstance(factory, type) and issubclass(factory, beam.PTransform):
# https://bugs.python.org/issue40897
params = dict(inspect.signature(factory.__init__).parameters)
del params['self']
else:
params = inspect.signature(factory).parameters

def type_of(p):
t = p.annotation
if t == p.empty:
return Any
else:
return t

names_and_types = [
(name, typing_to_runner_api(type_of(p))) for name, p in params.items()
]
return schema_pb2.Schema(
fields=[
schema_pb2.Field(name=name, type=type) for name,
type in names_and_types
])

def create_transform(self, type, args, yaml_create_transform):
return self._transform_factories[type](**args)

Expand Down Expand Up @@ -490,7 +530,10 @@ def extract_field(x, name):

# Or should this be posargs, args?
# pylint: disable=dangerous-default-value
def fully_qualified_named_transform(constructor, args=(), kwargs={}):
def fully_qualified_named_transform(
constructor: str,
args: Iterable[Any] = (),
kwargs: Mapping[str, Any] = {}):
with FullyQualifiedNamedTransform.with_filter('*'):
return constructor >> FullyQualifiedNamedTransform(
constructor, args, kwargs)
Expand Down Expand Up @@ -662,6 +705,19 @@ def available(self) -> bool:
def provided_transforms(self) -> Iterable[str]:
return self._transforms.keys()

def config_schema(self, type):
underlying_schema = self._underlying_provider.config_schema(
self._transforms[type])
if underlying_schema is None:
return None
underlying_schema_types = {f.name: f.type for f in underlying_schema.fields}
return schema_pb2.Schema(
fields=[
schema_pb2.Field(name=src, type=underlying_schema_types[dest])
for src,
dest in self._mappings[type].items()
])

def requires_inputs(self, typ, args):
return self._underlying_provider.requires_inputs(typ, args)

Expand Down Expand Up @@ -723,8 +779,42 @@ def standard_providers():
with open(os.path.join(os.path.dirname(__file__),
'standard_providers.yaml')) as fin:
standard_providers = yaml.load(fin, Loader=SafeLoader)

return merge_providers(
create_builtin_provider(),
create_mapping_providers(),
io_providers(),
parse_providers(standard_providers))


def list_providers():
def pretty_type(field_type):
if field_type.WhichOneof('type_info') == 'row_type':
return pretty_schema(field_type.row_type.schema)
else:
t = typing_from_runner_api(field_type)
optional_base = native_type_compatibility.extract_optional_type(t)
if optional_base:
t = optional_base
suffix = '?'
else:
suffix = ''
s = str(t)
if s.startswith('<class '):
s = t.__name__
return s + suffix

def pretty_schema(s):
if s is None:
return '[no schema]'
return 'Row(%s)' % ', '.join(
f'{f.name}={pretty_type(f.type)}' for f in s.fields)

for t, providers in sorted(standard_providers().items()):
print(t)
for p in providers:
print('\t', type(p).__name__, pretty_schema(p.config_schema(t)))


if __name__ == '__main__':
list_providers()
Loading