From ff7fc63815231d17d9c50be506f17a5618e02fbe Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 15 Sep 2023 09:15:26 -0700 Subject: [PATCH 1/2] [YAML] Transform schema introspection. --- sdks/python/apache_beam/typehints/schemas.py | 6 ++ sdks/python/apache_beam/yaml/yaml_provider.py | 97 ++++++++++++++++++- 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 155ea86219de..896c84e690fe 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -72,6 +72,7 @@ from typing import ByteString from typing import Dict from typing import Generic +from typing import Iterable from typing import List from typing import Mapping from typing import NamedTuple @@ -303,6 +304,11 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: return schema_pb2.FieldType( array_type=schema_pb2.ArrayType(element_type=element_type)) + elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str): + element_type = self.typing_to_runner_api(_get_args(type_)[0]) + return schema_pb2.FieldType( + array_type=schema_pb2.ArrayType(element_type=element_type)) + elif _safe_issubclass(type_, Mapping): key_type, value_type = map(self.typing_to_runner_api, _get_args(type_)) return schema_pb2.FieldType( diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 6e035811d4b9..736b1cab4658 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -21,6 +21,7 @@ import collections import hashlib +import inspect import json import os import subprocess @@ -45,8 +46,12 @@ from apache_beam.transforms import external from apache_beam.transforms import window from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform +from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import schemas from apache_beam.typehints import trivial_inference +from apache_beam.typehints.schemas import named_tuple_to_schema +from apache_beam.typehints.schemas import typing_from_runner_api +from apache_beam.typehints.schemas import typing_to_runner_api from apache_beam.utils import python_callable from apache_beam.utils import subprocess_server from apache_beam.version import __version__ as beam_version @@ -65,6 +70,9 @@ def provided_transforms(self) -> Iterable[str]: """Returns a list of transform type names this provider can handle.""" raise NotImplementedError(type(self)) + def config_schema(self, type): + return None + def create_transform( self, typ: str, @@ -129,7 +137,7 @@ def __init__(self, urns, service): def provided_transforms(self): return self._urns.keys() - def create_transform(self, type, args, yaml_create_transform): + def schema_transforms(self): if callable(self._service): self._service = self._service() if self._schema_transforms is None: @@ -142,8 +150,16 @@ def create_transform(self, type, args, yaml_create_transform): except Exception: # It's possible this service doesn't vend schema transforms. self._schema_transforms = {} + return self._schema_transforms + + def config_schema(self, type): + if self._urns[type] in self.schema_transforms(): + return named_tuple_to_schema( + self.schema_transforms()[self._urns[type]].configuration_schema) + + def create_transform(self, type, args, yaml_create_transform): urn = self._urns[type] - if urn in self._schema_transforms: + if urn in self.schema_transforms(): return external.SchemaAwareExternalTransform( urn, self._service, rearrange_based_on_discovery=True, **args) else: @@ -371,6 +387,31 @@ def cache_artifacts(self): def provided_transforms(self): return self._transform_factories.keys() + def config_schema(self, typ): + factory = self._transform_factories[typ] + if isinstance(factory, type) and issubclass(factory, beam.PTransform): + # https://bugs.python.org/issue40897 + params = dict(inspect.signature(factory.__init__).parameters) + del params['self'] + else: + params = inspect.signature(factory).parameters + + def type_of(p): + t = p.annotation + if t == p.empty: + return Any + else: + return t + + names_and_types = [ + (name, typing_to_runner_api(type_of(p))) for name, p in params.items() + ] + return schema_pb2.Schema( + fields=[ + schema_pb2.Field(name=name, type=type) for name, + type in names_and_types + ]) + def create_transform(self, type, args, yaml_create_transform): return self._transform_factories[type](**args) @@ -461,7 +502,10 @@ def extract_field(x, name): # Or should this be posargs, args? # pylint: disable=dangerous-default-value - def fully_qualified_named_transform(constructor, args=(), kwargs={}): + def fully_qualified_named_transform( + constructor: str, + args: Iterable[Any] = (), + kwargs: Mapping[str, Any] = {}): with FullyQualifiedNamedTransform.with_filter('*'): return constructor >> FullyQualifiedNamedTransform( constructor, args, kwargs) @@ -639,6 +683,19 @@ def available(self) -> bool: def provided_transforms(self) -> Iterable[str]: return self._transforms.keys() + def config_schema(self, type): + underlying_schema = self._underlying_provider.config_schema( + self._transforms[type]) + if underlying_schema is None: + return None + underlying_schema_types = {f.name: f.type for f in underlying_schema.fields} + return schema_pb2.Schema( + fields=[ + schema_pb2.Field(name=src, type=underlying_schema_types[dest]) + for src, + dest in self._mappings[type].items() + ]) + def create_transform( self, typ: str, @@ -695,8 +752,42 @@ def standard_providers(): with open(os.path.join(os.path.dirname(__file__), 'standard_providers.yaml')) as fin: standard_providers = yaml.load(fin, Loader=SafeLoader) + return merge_providers( create_builtin_provider(), create_mapping_provider(), io_providers(), parse_providers(standard_providers)) + + +def list_providers(): + def pretty_type(field_type): + if field_type.WhichOneof('type_info') == 'row_type': + return pretty_schema(field_type.row_type.schema) + else: + t = typing_from_runner_api(field_type) + optional_base = native_type_compatibility.extract_optional_type(t) + if optional_base: + t = optional_base + suffix = '?' + else: + suffix = '' + s = str(t) + if s.startswith(' Date: Fri, 15 Sep 2023 10:34:04 -0700 Subject: [PATCH 2/2] test mapping before iter --- sdks/python/apache_beam/typehints/schemas.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 896c84e690fe..5b900f296688 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -304,16 +304,16 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: return schema_pb2.FieldType( array_type=schema_pb2.ArrayType(element_type=element_type)) - elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str): - element_type = self.typing_to_runner_api(_get_args(type_)[0]) - return schema_pb2.FieldType( - array_type=schema_pb2.ArrayType(element_type=element_type)) - elif _safe_issubclass(type_, Mapping): key_type, value_type = map(self.typing_to_runner_api, _get_args(type_)) return schema_pb2.FieldType( map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type)) + elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str): + element_type = self.typing_to_runner_api(_get_args(type_)[0]) + return schema_pb2.FieldType( + array_type=schema_pb2.ArrayType(element_type=element_type)) + try: logical_type = LogicalType.from_typing(type_) except ValueError: