diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py index 1da12ececff6..48a4c29d0589 100644 --- a/sdks/python/apache_beam/dataframe/frame_base.py +++ b/sdks/python/apache_beam/dataframe/frame_base.py @@ -500,6 +500,8 @@ def wrap(func): removed_arg_names = removed_args if removed_args is not None else [] + # We would need to add position only arguments if they ever become a thing + # in Pandas (as of 2.1 currently they aren't). base_arg_spec = getfullargspec(unwrap(getattr(base_type, func.__name__))) base_arg_names = base_arg_spec.args # Some arguments are keyword only and we still want to check against those. @@ -514,6 +516,9 @@ def wrap(func): @functools.wraps(func) def wrapper(*args, **kwargs): + if len(args) > len(base_arg_names): + raise TypeError(f"{func.__name__} got too many positioned arguments.") + for name, value in zip(base_arg_names, args): if name in kwargs: raise TypeError( @@ -523,7 +528,7 @@ def wrapper(*args, **kwargs): # Still have to populate these for the Beam function signature. if removed_args: for name in removed_args: - if not name in kwargs: + if name not in kwargs: kwargs[name] = None return func(**kwargs) @@ -646,13 +651,18 @@ def wrap(func): return func base_argspec = getfullargspec(unwrap(getattr(base_type, func.__name__))) - if not base_argspec.defaults: + if not base_argspec.defaults and not base_argspec.kwonlydefaults: return func - arg_to_default = dict( - zip( - base_argspec.args[-len(base_argspec.defaults):], - base_argspec.defaults)) + arg_to_default = {} + if base_argspec.defaults: + arg_to_default.update( + zip( + base_argspec.args[-len(base_argspec.defaults):], + base_argspec.defaults)) + + if base_argspec.kwonlydefaults: + arg_to_default.update(base_argspec.kwonlydefaults) unwrapped_func = unwrap(func) # args that do not have defaults in func, but do have defaults in base diff --git a/sdks/python/apache_beam/dataframe/frame_base_test.py b/sdks/python/apache_beam/dataframe/frame_base_test.py index 2d16d02ba1ea..b3077320720f 100644 --- a/sdks/python/apache_beam/dataframe/frame_base_test.py +++ b/sdks/python/apache_beam/dataframe/frame_base_test.py @@ -72,7 +72,7 @@ def add_one(frame): def test_args_to_kwargs(self): class Base(object): - def func(self, a=1, b=2, c=3): + def func(self, a=1, b=2, c=3, *, kw_only=4): pass class Proxy(object): @@ -87,6 +87,9 @@ def func(self, **kwargs): self.assertEqual(proxy.func(2, 4, 6), {'a': 2, 'b': 4, 'c': 6}) self.assertEqual(proxy.func(2, c=6), {'a': 2, 'c': 6}) self.assertEqual(proxy.func(c=6, a=2), {'a': 2, 'c': 6}) + self.assertEqual(proxy.func(2, kw_only=20), {'a': 2, 'kw_only': 20}) + with self.assertRaises(TypeError): # got too many positioned arguments + proxy.func(2, 4, 6, 8) def test_args_to_kwargs_populates_defaults(self): class Base(object): @@ -129,6 +132,48 @@ def func_removed_args(self, a, c, **kwargs): proxy.func_removed_args() self.assertEqual(proxy.func_removed_args(12, d=100), {'a': 12, 'd': 100}) + def test_args_to_kwargs_populates_default_handles_kw_only(self): + class Base(object): + def func(self, a, b=2, c=3, *, kw_only=4): + pass + + class ProxyUsesKwOnly(object): + @frame_base.args_to_kwargs(Base) + @frame_base.populate_defaults(Base) + def func(self, a, kw_only, **kwargs): + return dict(kwargs, a=a, kw_only=kw_only) + + proxy = ProxyUsesKwOnly() + + # pylint: disable=too-many-function-args,no-value-for-parameter + with self.assertRaises(TypeError): # missing 1 required positional argument + proxy.func() + + self.assertEqual(proxy.func(100), {'a': 100, 'kw_only': 4}) + self.assertEqual( + proxy.func(2, 4, 6, kw_only=8), { + 'a': 2, 'b': 4, 'c': 6, 'kw_only': 8 + }) + with self.assertRaises(TypeError): + proxy.func(2, 4, 6, 8) # got too many positioned arguments + + class ProxyDoesntUseKwOnly(object): + @frame_base.args_to_kwargs(Base) + @frame_base.populate_defaults(Base) + def func(self, a, **kwargs): + return dict(kwargs, a=a) + + proxy = ProxyDoesntUseKwOnly() + + # pylint: disable=too-many-function-args,no-value-for-parameter + with self.assertRaises(TypeError): # missing 1 required positional argument + proxy.func() + self.assertEqual(proxy.func(100), {'a': 100}) + self.assertEqual( + proxy.func(2, 4, 6, kw_only=8), { + 'a': 2, 'b': 4, 'c': 6, 'kw_only': 8 + }) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py index e2390bda28be..a7a278fbf62a 100644 --- a/sdks/python/apache_beam/dataframe/frames.py +++ b/sdks/python/apache_beam/dataframe/frames.py @@ -907,7 +907,7 @@ def sort_index(self, axis, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( 'sort_index', - lambda df: df.sort_index(axis, **kwargs), + lambda df: df.sort_index(axis=axis, **kwargs), [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Arbitrary(), @@ -2687,7 +2687,7 @@ def set_axis(self, labels, axis, **kwargs): return frame_base.DeferredFrame.wrap( expressions.ComputedExpression( 'set_axis', - lambda df: df.set_axis(labels, axis, **kwargs), + lambda df: df.set_axis(labels, axis=axis, **kwargs), [self._expr], requires_partition_by=partitionings.Arbitrary(), preserves_partition_by=partitionings.Arbitrary()))