Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a simple validation transform to yaml. #32956

Merged
merged 6 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions sdks/python/apache_beam/yaml/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType:
raise ValueError(f'Unable to convert {json_type} to a Beam schema.')


def beam_schema_to_json_schema(
    beam_schema: schema_pb2.Schema) -> Dict[str, Any]:
  """Returns a JSON schema describing rows that have the given Beam schema.

  The result is an ``object`` schema with one property per Beam field, each
  converted with beam_type_to_json_type, and with additional properties
  disallowed.

  Args:
    beam_schema: The Beam schema whose fields should be described.

  Returns:
    A dict with the keys 'type', 'properties' and 'additionalProperties'.
  """
  properties = {}
  for field in beam_schema.fields:
    properties[field.name] = beam_type_to_json_type(field.type)
  return {
      'type': 'object',
      'properties': properties,
      'additionalProperties': False,
  }


def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> Dict[str, Any]:
type_info = beam_type.WhichOneof("type_info")
if type_info == "atomic_type":
Expand Down Expand Up @@ -267,3 +279,52 @@ def json_formater(
convert = row_to_json(
schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
return lambda row: json.dumps(convert(row), sort_keys=True).encode('utf-8')


def _validate_compatible(weak_schema, strong_schema):
if not weak_schema:
return
Comment on lines +285 to +286
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to my other comment, but wouldn't this pass silently? I'm not sure how it would reach this state in the first place, but if the incoming PCollection does not have a schema (perhaps if preceded by a transform that does not output Row?), this would pass validation even if given a json schema to validate against.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, weak_schema could be {}. This comes up in practice if we don't know anything about this part of the input (e.g. it's Any), which is also the fallback for beam_type_to_json_type.

if weak_schema['type'] != strong_schema['type']:
raise ValueError(
'Incompatible types: %r vs %r' %
(weak_schema['type'] != strong_schema['type']))
if weak_schema['type'] == 'array':
_validate_compatible(weak_schema['items'], strong_schema['items'])
elif weak_schema == 'object':
for required in strong_schema.get('required', []):
if required not in weak_schema['properties']:
raise ValueError('Missing or unkown property %r' % required)
for name, spec in weak_schema.get('properties', {}):
if name in strong_schema['properties']:
try:
_validate_compatible(spec, strong_schema['properties'][name])
except Exception as exn:
raise ValueError('Incompatible schema for %r' % name) from exn
elif not strong_schema.get('additionalProperties'):
raise ValueError(
'Prohibited property: {property}; '
'perhaps additionalProperties: False is missing?')


def row_validator(beam_schema: schema_pb2.Schema,
                  json_schema: Dict[str, Any]) -> Callable[[Any], Any]:
  """Returns a callable that will fail on elements not respecting json_schema.

  Args:
    beam_schema: The Beam schema of the rows that will be validated.
    json_schema: The JSON schema to validate each row against. The degenerate
      schema ``{}`` means "anything goes" in JSON schema and is handled
      gracefully by returning a no-op validator.

  Returns:
    A callable taking a row; it converts the row to its JSON form and
    validates it against json_schema, raising on failure.

  Raises:
    ValueError: if beam_schema is statically incompatible with json_schema.
  """
  if not json_schema:
    return lambda x: None

  # Validate that this compiles, but avoid pickling the validator itself.
  # NOTE(review): constructing a validator may not check the schema itself;
  # the validator class's check_schema() does -- confirm which is intended.
  _ = jsonschema.validators.validator_for(json_schema)(json_schema)
  _validate_compatible(beam_schema_to_json_schema(beam_schema), json_schema)

  convert = row_to_json(
      schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))

  # Built lazily on first use so that only json_schema, not the validator
  # object, is captured when the returned closure is pickled.
  validator = None

  def validate(row):
    nonlocal validator
    if validator is None:
      validator = jsonschema.validators.validator_for(json_schema)(
          json_schema)
    validator.validate(convert(row))

  return validate
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.typehints.schemas import schema_from_element_type
from apache_beam.utils import python_callable
from apache_beam.yaml import json_utils
from apache_beam.yaml import options
Expand Down Expand Up @@ -435,7 +436,8 @@ def _map_errors_to_standard_format(input_type):
# TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.

return beam.Map(
lambda x: beam.Row(element=x[0], msg=str(x[1][1]), stack=str(x[1][2]))
lambda x: beam.Row(
element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2]))
).with_output_types(
RowTypeConstraint.from_fields([("element", input_type), ("msg", str),
("stack", str)]))
Expand Down Expand Up @@ -475,6 +477,40 @@ def expand(pcoll, error_handling=None, **kwargs):
return expand


class _Validate(beam.PTransform):
  """Checks every element of a PCollection against a JSON schema.

  Conforming elements are emitted unchanged.

  Args:
    schema: The JSON schema each element must satisfy.
    error_handling: Whether and how to handle validation failures.
      If this is not set, invalid elements will fail the pipeline, otherwise
      invalid elements will be passed to the specified error output along
      with information about how the schema was invalidated.
  """
  def __init__(
      self,
      schema: Dict[str, Any],
      error_handling: Optional[Mapping[str, Any]] = None):
    self._schema = schema
    self._exception_handling_args = exception_handling_args(error_handling)

  @maybe_with_exception_handling
  def expand(self, pcoll):
    # The validator is built once, from the input's schema, at expansion time.
    input_schema = schema_from_element_type(pcoll.element_type)
    check_row = json_utils.row_validator(input_schema, self._schema)

    def invoke_validator(element):
      # Raises for invalid elements; valid ones pass through untouched.
      check_row(element)
      return element

    return pcoll | beam.Map(invoke_validator)

  def with_exception_handling(self, **kwargs):
    # Replaces whatever error handling was configured at construction.
    self._exception_handling_args = kwargs
    return self


class _Explode(beam.PTransform):
"""Explodes (aka unnest/flatten) one or more fields producing multiple rows.

Expand Down Expand Up @@ -797,6 +833,7 @@ def create_mapping_providers():
'Partition-python': _Partition,
'Partition-javascript': _Partition,
'Partition-generic': _Partition,
'Validate': _Validate,
}),
yaml_provider.SqlBackedProvider({
'Filter-sql': _SqlFilterTransform,
Expand Down
37 changes: 37 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,43 @@ def test_explode(self):
beam.Row(a=3, b='y', c=.125, range=2),
]))

def test_validate(self):
  """Valid rows pass through Validate; invalid rows go to the error output."""
  options = beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')
  with beam.Pipeline(options=options) as p:
    # 'bad1' violates the maximum on `small`; 'bad2' violates the minimum on
    # `nested.big`.
    rows = [
        beam.Row(key='good', small=[5], nested=beam.Row(big=100)),
        beam.Row(key='bad1', small=[500], nested=beam.Row(big=100)),
        beam.Row(key='bad2', small=[5], nested=beam.Row(big=1)),
    ]
    elements = p | beam.Create(rows)
    result = elements | YamlTransform(
        '''
        type: Validate
        config:
          schema:
            type: object
            properties:
              small:
                type: array
                items:
                  type: integer
                  maximum: 10
              nested:
                type: object
                properties:
                  big:
                    type: integer
                    minimum: 10
          error_handling:
            output: bad
        ''')

    good_keys = result['good'] | beam.Map(lambda x: x.key)
    assert_that(good_keys, equal_to(['good']))

    bad_keys = result['bad'] | beam.Map(lambda x: x.element.key)
    assert_that(bad_keys, equal_to(['bad1', 'bad2']), label='Errors')

def test_validate_explicit_types(self):
with self.assertRaisesRegex(TypeError, r'.*violates schema.*'):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
Expand Down
Loading