diff --git a/sdks/python/apache_beam/transforms/batch_dofn_test.py b/sdks/python/apache_beam/transforms/batch_dofn_test.py index de35c29024a5..d2aceb371492 100644 --- a/sdks/python/apache_beam/transforms/batch_dofn_test.py +++ b/sdks/python/apache_beam/transforms/batch_dofn_test.py @@ -41,12 +41,12 @@ def process_batch(self, batch: List[int], *args, yield [element / 2 for element in batch] -class BatchDoFnNoReturnAnnotation(beam.DoFn): +class NoReturnAnnotation(beam.DoFn): def process_batch(self, batch: List[int], *args, **kwargs): yield [element * 2 for element in batch] -class BatchDoFnOverrideTypeInference(beam.DoFn): +class OverrideTypeInference(beam.DoFn): def process_batch(self, batch, *args, **kwargs): yield [element * 2 for element in batch] @@ -104,7 +104,7 @@ def get_test_class_name(cls, num, params_dict): "expected_output_batch_type": beam.typehints.List[float] }, { - "dofn": BatchDoFnNoReturnAnnotation(), + "dofn": NoReturnAnnotation(), "input_element_type": int, "expected_process_defined": False, "expected_process_batch_defined": True, @@ -112,7 +112,7 @@ def get_test_class_name(cls, num, params_dict): "expected_output_batch_type": beam.typehints.List[int] }, { - "dofn": BatchDoFnOverrideTypeInference(), + "dofn": OverrideTypeInference(), "input_element_type": int, "expected_process_defined": False, "expected_process_batch_defined": True, @@ -168,7 +168,7 @@ def test_can_yield_batches(self): self.assertEqual(self.dofn._can_yield_batches, expected) -class BatchDoFnNoInputAnnotation(beam.DoFn): +class NoInputAnnotation(beam.DoFn): def process_batch(self, batch, *args, **kwargs): yield [element * 2 for element in batch] @@ -198,6 +198,12 @@ def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]: yield batch[0] +class NoElementOutputAnnotation(beam.DoFn): + def process_batch(self, batch: List[int], *args, + **kwargs) -> Iterator[List[int]]: + yield [element * 2 for element in batch] + + class BatchDoFnTest(unittest.TestCase): def test_map_pardo(self): # verify batch dofn accessors work well with beam.Map generated DoFn @@ -213,12 +219,11 @@ def test_no_input_annotation_raises(self): p = beam.Pipeline() pc = p | beam.Create([1, 2, 3]) - with self.assertRaisesRegex(TypeError, - r'BatchDoFnNoInputAnnotation.process_batch'): - _ = pc | beam.ParDo(BatchDoFnNoInputAnnotation()) + with self.assertRaisesRegex(TypeError, r'NoInputAnnotation.process_batch'): + _ = pc | beam.ParDo(NoInputAnnotation()) def test_unsupported_dofn_param_raises(self): - class BatchDoFnBadParam(beam.DoFn): + class BadParam(beam.DoFn): @no_type_check def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam): yield batch * key @@ -226,9 +231,8 @@ def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam): p = beam.Pipeline() pc = p | beam.Create([1, 2, 3]) - with self.assertRaisesRegex(NotImplementedError, - r'BatchDoFnBadParam.*KeyParam'): - _ = pc | beam.ParDo(BatchDoFnBadParam()) + with self.assertRaisesRegex(NotImplementedError, r'BadParam.*KeyParam'): + _ = pc | beam.ParDo(BadParam()) def test_mismatched_batch_producer_raises(self): p = beam.Pipeline() @@ -256,6 +260,27 @@ def test_mismatched_element_producer_raises(self): r'(?ms)MismatchedElementProducingDoFn.*process:.*process_batch:'): _ = pc | beam.ParDo(MismatchedElementProducingDoFn()) + def test_cant_infer_batchconverter_input_raises(self): + p = beam.Pipeline() + pc = p | beam.Create(['a', 'b', 'c']) + + with self.assertRaisesRegex( + TypeError, + # Error should mention "input", and the name of the DoFn + r'input.*BatchDoFn.*'): + _ = pc | beam.ParDo(BatchDoFn()) + + def test_cant_infer_batchconverter_output_raises(self): + p = beam.Pipeline() + pc = p | beam.Create([1, 2, 3]) + + with self.assertRaisesRegex( + TypeError, + # Error should mention "output", the name of the DoFn, and suggest + # overriding DoFn.infer_output_type + r'output.*NoElementOutputAnnotation.*DoFn\.infer_output_type'): + _ = pc | beam.ParDo(NoElementOutputAnnotation()) + def test_element_to_batch_dofn_typehint(self): # Verify that element to batch DoFn sets the correct typehint on the output # PCollection. diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 50ff32e57a33..69c003ee5b8c 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1511,10 +1511,17 @@ def infer_batch_converters(self, input_element_type): "process_batch method on {self.fn!r} does not have " "an input type annoation") - # Generate a batch converter to convert between the input type and the - # (batch) input type of process_batch - self.fn.input_batch_converter = BatchConverter.from_typehints( - element_type=input_element_type, batch_type=input_batch_type) + try: + # Generate a batch converter to convert between the input type and the + # (batch) input type of process_batch + self.fn.input_batch_converter = BatchConverter.from_typehints( + element_type=input_element_type, batch_type=input_batch_type) + except TypeError as e: + raise TypeError( + "Failed to find a BatchConverter for the input types of DoFn " + f"{self.fn!r} (element_type={input_element_type!r}, " + f"batch_type={input_batch_type!r}).") from e + else: self.fn.input_batch_converter = None @@ -1530,8 +1537,16 @@ def infer_batch_converters(self, input_element_type): # Generate a batch converter to convert between the output type and the # (batch) output type of process_batch output_element_type = self.infer_output_type(input_element_type) - self.fn.output_batch_converter = BatchConverter.from_typehints( - element_type=output_element_type, batch_type=output_batch_type) + + try: + self.fn.output_batch_converter = BatchConverter.from_typehints( + element_type=output_element_type, batch_type=output_batch_type) + except TypeError as e: + raise TypeError( + "Failed to find a BatchConverter for the *output* types of DoFn " + f"{self.fn!r} (element_type={output_element_type!r}, " + f"batch_type={output_batch_type!r}). Maybe you need to override " + "DoFn.infer_output_type to set the output element type?") from e else: self.fn.output_batch_converter = None