diff --git a/CHANGES.md b/CHANGES.md
index 7fdca4e6914b..049ef1a1e227 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -64,6 +64,9 @@
 ## New Features / Improvements
 
 * The Flink runner now supports Flink 1.16.x ([#25046](https://github.com/apache/beam/issues/25046)).
+* Schema'd PTransforms can now be directly applied to Beam dataframes just like PCollections.
+  (Note that when doing multiple operations, it may be more efficient to explicitly chain the operations
+  like `df | (Transform1 | Transform2 | ...)` to avoid excessive conversions.)
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/dataframe/convert_test.py b/sdks/python/apache_beam/dataframe/convert_test.py
index 152c0f00d18b..b00ce0e51fa8 100644
--- a/sdks/python/apache_beam/dataframe/convert_test.py
+++ b/sdks/python/apache_beam/dataframe/convert_test.py
@@ -184,6 +184,26 @@ def test_convert_memoization_clears_cache(self):
     gc.enable()
     logging.disable(logging.NOTSET)
 
+  def test_auto_convert(self):
+    class MySchemaTransform(beam.PTransform):
+      def expand(self, pcoll):
+        return pcoll | beam.Map(
+            lambda x: beam.Row(
+                a=x.n**2 - x.m**2, b=2 * x.m * x.n, c=x.n**2 + x.m**2))
+
+    with beam.Pipeline() as p:
+      pc_mn = p | beam.Create([
+          (1, 2), (2, 3), (3, 10)
+      ]) | beam.MapTuple(lambda m, n: beam.Row(m=m, n=n))
+
+      df_mn = convert.to_dataframe(pc_mn)
+
+      # Apply a transform directly to a dataframe to get another dataframe.
+      df_abc = df_mn | MySchemaTransform()
+
+      pc_abc = convert.to_pcollection(df_abc) | beam.Map(tuple)
+      assert_that(pc_abc, equal_to([(3, 4, 5), (5, 12, 13), (91, 60, 109)]))
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py
index ac3dd36c98ad..3020eecbaeb5 100644
--- a/sdks/python/apache_beam/dataframe/frames.py
+++ b/sdks/python/apache_beam/dataframe/frames.py
@@ -49,10 +49,12 @@
 from pandas.api.types import is_list_like
 from pandas.core.groupby.generic import DataFrameGroupBy
 
+from apache_beam.dataframe import convert
 from apache_beam.dataframe import expressions
 from apache_beam.dataframe import frame_base
 from apache_beam.dataframe import io
 from apache_beam.dataframe import partitionings
+from apache_beam.transforms import PTransform
 
 __all__ = [
     'DeferredSeries',
@@ -5394,6 +5396,24 @@ def func(df, *args, **kwargs):
       frame_base._elementwise_method(inplace_name, inplace=True,
                                      base=pd.DataFrame))
 
+# Allow dataframe | SchemaTransform
+def _create_maybe_elementwise_or(base):
+  elementwise = frame_base._elementwise_method(
+      '__or__', restrictions={'level': None}, base=base)
+
+  def _maybe_elementwise_or(self, right):
+    if isinstance(right, PTransform):
+      return convert.to_dataframe(convert.to_pcollection(self) | right)
+    else:
+      return elementwise(self, right)
+
+  return _maybe_elementwise_or
+
+
+DeferredSeries.__or__ = _create_maybe_elementwise_or(pd.Series)  # type: ignore
+DeferredDataFrame.__or__ = _create_maybe_elementwise_or(pd.DataFrame)  # type: ignore
+
+
 for name in ['lt', 'le', 'gt', 'ge', 'eq', 'ne']:
   for p in '%s', '__%s__':
     # Note that non-underscore name is used for both as the __xxx__ methods are
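
For reviewers, a minimal end-to-end sketch of the new behavior, adapted from `test_auto_convert` above. The transform name `ToPythagoreanTriple` is illustrative only; this assumes `apache_beam` is installed with pandas support (e.g. `pip install 'apache-beam[dataframe]'`):

```python
# Sketch of the new `df | PTransform` support introduced by this change,
# adapted from test_auto_convert in convert_test.py above.
import apache_beam as beam
from apache_beam.dataframe import convert


class ToPythagoreanTriple(beam.PTransform):
  """Schema'd transform: Rows with fields (m, n) -> Rows with (a, b, c)."""
  def expand(self, pcoll):
    return pcoll | beam.Map(
        lambda x: beam.Row(
            a=x.n**2 - x.m**2, b=2 * x.m * x.n, c=x.n**2 + x.m**2))


with beam.Pipeline() as p:
  rows = (
      p
      | beam.Create([(1, 2), (2, 3), (3, 10)])
      | beam.MapTuple(lambda m, n: beam.Row(m=m, n=n)))

  df = convert.to_dataframe(rows)

  # New in this change: applying a PTransform directly to a deferred
  # dataframe round-trips through to_pcollection/to_dataframe under the
  # hood, yielding another deferred dataframe.
  triples = df | ToPythagoreanTriple()

  # Convert back explicitly to consume the result as a PCollection.
  _ = (
      convert.to_pcollection(triples)
      | beam.Map(lambda row: print(row.a, row.b, row.c)))
```

Per the test's `assert_that`, this should print the triples `3 4 5`, `5 12 13`, and `91 60 109` (in any order).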