Skip to content

Commit

Permalink
Make BatchConverter inference errors more helpful (#23965)
Browse files Browse the repository at this point in the history
* Add better error messages for BatchConverter inference failure

* Remove extraneous 'BatchDoFn' from class names

* yapf
  • Loading branch information
TheNeuralBit authored Nov 4, 2022
1 parent e2463a4 commit 7da182a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 18 deletions.
49 changes: 37 additions & 12 deletions sdks/python/apache_beam/transforms/batch_dofn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def process_batch(self, batch: List[int], *args,
yield [element / 2 for element in batch]


class BatchDoFnNoReturnAnnotation(beam.DoFn):
class NoReturnAnnotation(beam.DoFn):
def process_batch(self, batch: List[int], *args, **kwargs):
yield [element * 2 for element in batch]


class BatchDoFnOverrideTypeInference(beam.DoFn):
class OverrideTypeInference(beam.DoFn):
def process_batch(self, batch, *args, **kwargs):
yield [element * 2 for element in batch]

Expand Down Expand Up @@ -104,15 +104,15 @@ def get_test_class_name(cls, num, params_dict):
"expected_output_batch_type": beam.typehints.List[float]
},
{
"dofn": BatchDoFnNoReturnAnnotation(),
"dofn": NoReturnAnnotation(),
"input_element_type": int,
"expected_process_defined": False,
"expected_process_batch_defined": True,
"expected_input_batch_type": beam.typehints.List[int],
"expected_output_batch_type": beam.typehints.List[int]
},
{
"dofn": BatchDoFnOverrideTypeInference(),
"dofn": OverrideTypeInference(),
"input_element_type": int,
"expected_process_defined": False,
"expected_process_batch_defined": True,
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_can_yield_batches(self):
self.assertEqual(self.dofn._can_yield_batches, expected)


class BatchDoFnNoInputAnnotation(beam.DoFn):
class NoInputAnnotation(beam.DoFn):
def process_batch(self, batch, *args, **kwargs):
yield [element * 2 for element in batch]

Expand Down Expand Up @@ -198,6 +198,12 @@ def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]:
yield batch[0]


class NoElementOutputAnnotation(beam.DoFn):
def process_batch(self, batch: List[int], *args,
**kwargs) -> Iterator[List[int]]:
yield [element * 2 for element in batch]


class BatchDoFnTest(unittest.TestCase):
def test_map_pardo(self):
# verify batch dofn accessors work well with beam.Map generated DoFn
Expand All @@ -213,22 +219,20 @@ def test_no_input_annotation_raises(self):
p = beam.Pipeline()
pc = p | beam.Create([1, 2, 3])

with self.assertRaisesRegex(TypeError,
r'BatchDoFnNoInputAnnotation.process_batch'):
_ = pc | beam.ParDo(BatchDoFnNoInputAnnotation())
with self.assertRaisesRegex(TypeError, r'NoInputAnnotation.process_batch'):
_ = pc | beam.ParDo(NoInputAnnotation())

def test_unsupported_dofn_param_raises(self):
class BatchDoFnBadParam(beam.DoFn):
class BadParam(beam.DoFn):
@no_type_check
def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam):
yield batch * key

p = beam.Pipeline()
pc = p | beam.Create([1, 2, 3])

with self.assertRaisesRegex(NotImplementedError,
r'BatchDoFnBadParam.*KeyParam'):
_ = pc | beam.ParDo(BatchDoFnBadParam())
with self.assertRaisesRegex(NotImplementedError, r'BadParam.*KeyParam'):
_ = pc | beam.ParDo(BadParam())

def test_mismatched_batch_producer_raises(self):
p = beam.Pipeline()
Expand Down Expand Up @@ -256,6 +260,27 @@ def test_mismatched_element_producer_raises(self):
r'(?ms)MismatchedElementProducingDoFn.*process:.*process_batch:'):
_ = pc | beam.ParDo(MismatchedElementProducingDoFn())

def test_cant_infer_batchconverter_input_raises(self):
p = beam.Pipeline()
pc = p | beam.Create(['a', 'b', 'c'])

with self.assertRaisesRegex(
TypeError,
# Error should mention "input", and the name of the DoFn
r'input.*BatchDoFn.*'):
_ = pc | beam.ParDo(BatchDoFn())

def test_cant_infer_batchconverter_output_raises(self):
p = beam.Pipeline()
pc = p | beam.Create([1, 2, 3])

with self.assertRaisesRegex(
TypeError,
# Error should mention "output", the name of the DoFn, and suggest
# overriding DoFn.infer_output_type
r'output.*NoElementOutputAnnotation.*DoFn\.infer_output_type'):
_ = pc | beam.ParDo(NoElementOutputAnnotation())

def test_element_to_batch_dofn_typehint(self):
# Verify that element to batch DoFn sets the correct typehint on the output
# PCollection.
Expand Down
27 changes: 21 additions & 6 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,10 +1511,17 @@ def infer_batch_converters(self, input_element_type):
"process_batch method on {self.fn!r} does not have "
"an input type annoation")

# Generate a batch converter to convert between the input type and the
# (batch) input type of process_batch
self.fn.input_batch_converter = BatchConverter.from_typehints(
element_type=input_element_type, batch_type=input_batch_type)
try:
# Generate a batch converter to convert between the input type and the
# (batch) input type of process_batch
self.fn.input_batch_converter = BatchConverter.from_typehints(
element_type=input_element_type, batch_type=input_batch_type)
except TypeError as e:
raise TypeError(
"Failed to find a BatchConverter for the input types of DoFn "
f"{self.fn!r} (element_type={input_element_type!r}, "
f"batch_type={input_batch_type!r}).") from e

else:
self.fn.input_batch_converter = None

Expand All @@ -1530,8 +1537,16 @@ def infer_batch_converters(self, input_element_type):
# Generate a batch converter to convert between the output type and the
# (batch) output type of process_batch
output_element_type = self.infer_output_type(input_element_type)
self.fn.output_batch_converter = BatchConverter.from_typehints(
element_type=output_element_type, batch_type=output_batch_type)

try:
self.fn.output_batch_converter = BatchConverter.from_typehints(
element_type=output_element_type, batch_type=output_batch_type)
except TypeError as e:
raise TypeError(
"Failed to find a BatchConverter for the *output* types of DoFn "
f"{self.fn!r} (element_type={output_element_type!r}, "
f"batch_type={output_batch_type!r}). Maybe you need to override "
"DoFn.infer_output_type to set the output element type?") from e
else:
self.fn.output_batch_converter = None

Expand Down

0 comments on commit 7da182a

Please sign in to comment.