diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 155ea86219de..5b900f296688 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -72,6 +72,7 @@ from typing import ByteString from typing import Dict from typing import Generic +from typing import Iterable from typing import List from typing import Mapping from typing import NamedTuple @@ -308,6 +309,11 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: return schema_pb2.FieldType( map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type)) + elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str): + element_type = self.typing_to_runner_api(_get_args(type_)[0]) + return schema_pb2.FieldType( + array_type=schema_pb2.ArrayType(element_type=element_type)) + try: logical_type = LogicalType.from_typing(type_) except ValueError: diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 0cd9bdcadcc3..6f760f359b06 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -21,6 +21,7 @@ import collections import hashlib +import inspect import json import os import subprocess @@ -45,8 +46,12 @@ from apache_beam.transforms import external from apache_beam.transforms import window from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform +from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import schemas from apache_beam.typehints import trivial_inference +from apache_beam.typehints.schemas import named_tuple_to_schema +from apache_beam.typehints.schemas import typing_from_runner_api +from apache_beam.typehints.schemas import typing_to_runner_api from apache_beam.utils import python_callable from apache_beam.utils import subprocess_server from apache_beam.version import __version__ as beam_version @@ -65,6 +70,9 @@ def provided_transforms(self) -> Iterable[str]: """Returns a list of transform type names this provider can handle.""" raise NotImplementedError(type(self)) + def config_schema(self, type): + return None + def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool: """Returns whether this transform requires inputs. @@ -140,6 +148,8 @@ def provided_transforms(self): return self._urns.keys() def schema_transforms(self): + if callable(self._service): + self._service = self._service() if self._schema_transforms is None: try: self._schema_transforms = { @@ -152,6 +162,11 @@ def schema_transforms(self): self._schema_transforms = {} return self._schema_transforms + def config_schema(self, type): + if self._urns[type] in self.schema_transforms(): + return named_tuple_to_schema( + self.schema_transforms()[self._urns[type]].configuration_schema) + def requires_inputs(self, typ, args): if self._urns[type] in self.schema_transforms(): return bool(self.schema_transforms()[self._urns[type]].inputs) @@ -392,6 +407,31 @@ def cache_artifacts(self): def provided_transforms(self): return self._transform_factories.keys() + def config_schema(self, typ): + factory = self._transform_factories[typ] + if isinstance(factory, type) and issubclass(factory, beam.PTransform): + # https://bugs.python.org/issue40897 + params = dict(inspect.signature(factory.__init__).parameters) + del params['self'] + else: + params = inspect.signature(factory).parameters + + def type_of(p): + t = p.annotation + if t == p.empty: + return Any + else: + return t + + names_and_types = [ + (name, typing_to_runner_api(type_of(p))) for name, p in params.items() + ] + return schema_pb2.Schema( + fields=[ + schema_pb2.Field(name=name, type=type) for name, + type in names_and_types + ]) + def create_transform(self, type, args, yaml_create_transform): return self._transform_factories[type](**args) @@ -490,7 +530,10 @@ def extract_field(x, name): # Or should this be posargs, args? # pylint: disable=dangerous-default-value - def fully_qualified_named_transform(constructor, args=(), kwargs={}): + def fully_qualified_named_transform( + constructor: str, + args: Iterable[Any] = (), + kwargs: Mapping[str, Any] = {}): with FullyQualifiedNamedTransform.with_filter('*'): return constructor >> FullyQualifiedNamedTransform( constructor, args, kwargs) @@ -662,6 +705,19 @@ def available(self) -> bool: def provided_transforms(self) -> Iterable[str]: return self._transforms.keys() + def config_schema(self, type): + underlying_schema = self._underlying_provider.config_schema( + self._transforms[type]) + if underlying_schema is None: + return None + underlying_schema_types = {f.name: f.type for f in underlying_schema.fields} + return schema_pb2.Schema( + fields=[ + schema_pb2.Field(name=src, type=underlying_schema_types[dest]) + for src, + dest in self._mappings[type].items() + ]) + def requires_inputs(self, typ, args): return self._underlying_provider.requires_inputs(typ, args) @@ -723,8 +779,42 @@ def standard_providers(): with open(os.path.join(os.path.dirname(__file__), 'standard_providers.yaml')) as fin: standard_providers = yaml.load(fin, Loader=SafeLoader) + return merge_providers( create_builtin_provider(), create_mapping_providers(), io_providers(), parse_providers(standard_providers)) + + +def list_providers(): + def pretty_type(field_type): + if field_type.WhichOneof('type_info') == 'row_type': + return pretty_schema(field_type.row_type.schema) + else: + t = typing_from_runner_api(field_type) + optional_base = native_type_compatibility.extract_optional_type(t) + if optional_base: + t = optional_base + suffix = '?' + else: + suffix = '' + s = str(t) + if s.startswith('