Skip to content

Commit

Permalink
Fix custom coders not being used in Reshuffle (non global window) (#3…
Browse files Browse the repository at this point in the history
…3363)

* Fix typehint in ReshufflePerKey on global window setting.

* Only update the type hint on global window setting. Need more work in non-global windows.

* Apply yapf

* Fix some failed tests.

* Revert change to setup.py

* Fix custom coders not being used in reshuffle in non-global windows

* Revert changes in setup.py. Reformat.

* Make WindowedValue a generic class. Support its conversion to the correct type constraint in Beam.

* Cython does not support Python generic class. Add a subclass as a workroundand keep it un-cythonized.

* Add comments

* Fix type error.

* Remove the base class of WindowedValue in TypedWindowedValue.

* Move TypedWindowedValue out from windowed_value.py

* Revise the comments

* Fix the module location when matching.

* Fix test failure where __name__ of a type alias not found in python 3.9

* Add a note about the window coder.

---------

Co-authored-by: Robert Bradshaw <[email protected]>
  • Loading branch information
shunping and robertwb authored Dec 14, 2024
1 parent f92dde1 commit 2aab9cd
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 21 deletions.
8 changes: 8 additions & 0 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,14 @@ def __hash__(self):
return hash(
(self.wrapped_value_coder, self.timestamp_coder, self.window_coder))

@classmethod
def from_type_hint(cls, typehint, registry):
# type: (Any, CoderRegistry) -> WindowedValueCoder
# Ideally this'd take two parameters so that one could hint at
# the window type as well instead of falling back to the
# pickle coders.
return cls(registry.get_coder(typehint.inner_type))


Coder.register_structured_urn(
common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder)
Expand Down
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/coders/typecoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def register_standard_coders(self, fallback_coder):
self._register_coder_internal(str, coders.StrUtf8Coder)
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
self._register_coder_internal(
typehints.WindowedTypeConstraint, coders.WindowedValueCoder)
# Default fallback coders applied in that order until the first matching
# coder found.
default_fallback_coders = [
Expand Down
6 changes: 3 additions & 3 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import shared
from apache_beam.utils import windowed_value
Expand Down Expand Up @@ -972,9 +973,8 @@ def restore_timestamps(element):
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]

# TODO(https://github.com/apache/beam/issues/33356): Support reshuffling
# unpicklable objects with a non-global window setting.
ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
ungrouped = pcoll | Map(reify_timestamps).with_input_types(
Tuple[K, V]).with_output_types(Tuple[K, TypedWindowedValue[V]])

# TODO(https://github.com/apache/beam/issues/19785) Using global window as
# one of the standard window. This is to mitigate the Dataflow Java Runner
Expand Down
51 changes: 33 additions & 18 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,32 +1010,33 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
equal_to(expected_data),
label="formatted_after_reshuffle")

def test_reshuffle_unpicklable_in_global_window(self):
global _Unpicklable
global _Unpicklable
global _UnpicklableCoder

class _Unpicklable(object):
def __init__(self, value):
self.value = value
class _Unpicklable(object):
def __init__(self, value):
self.value = value

def __getstate__(self):
raise NotImplementedError()
def __getstate__(self):
raise NotImplementedError()

def __setstate__(self, state):
raise NotImplementedError()
def __setstate__(self, state):
raise NotImplementedError()

class _UnpicklableCoder(beam.coders.Coder):
def encode(self, value):
return str(value.value).encode()
class _UnpicklableCoder(beam.coders.Coder):
def encode(self, value):
return str(value.value).encode()

def decode(self, encoded):
return _Unpicklable(int(encoded.decode()))
def decode(self, encoded):
return _Unpicklable(int(encoded.decode()))

def to_type_hint(self):
return _Unpicklable
def to_type_hint(self):
return _Unpicklable

def is_deterministic(self):
return True
def is_deterministic(self):
return True

def test_reshuffle_unpicklable_in_global_window(self):
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

with TestPipeline() as pipeline:
Expand All @@ -1049,6 +1050,20 @@ def is_deterministic(self):
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))

def test_reshuffle_unpicklable_in_non_global_window(self):
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

with TestPipeline() as pipeline:
data = [_Unpicklable(i) for i in range(5)]
expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40]
result = (
pipeline
| beam.Create(data)
| beam.WindowInto(window.SlidingWindows(size=3, period=1))
| beam.Reshuffle()
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))


class WithKeysTest(unittest.TestCase):
def setUp(self):
Expand Down
26 changes: 26 additions & 0 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
import sys
import types
import typing
from typing import Generic
from typing import TypeVar

from apache_beam.typehints import typehints

T = TypeVar('T')

_LOGGER = logging.getLogger(__name__)

# Describes an entry in the type map in convert_to_beam_type.
Expand Down Expand Up @@ -216,6 +220,18 @@ def convert_collections_to_typing(typ):
return typ


# During type inference of WindowedValue, we need to pass in the inner value
# type. This cannot be achieved immediately with WindowedValue class because it
# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T])
# could work in theory. However, the class is cythonized and it seems that
# cython does not handle generic classes well.
# The workaround here is to create a separate class solely for the type
# inference purpose. This class should never be used for creating instances.
class TypedWindowedValue(Generic[T]):
def __init__(self, *args, **kwargs):
raise NotImplementedError("This class is solely for type inference")


def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.
Expand Down Expand Up @@ -267,6 +283,12 @@ def convert_to_beam_type(typ):
# TODO(https://github.com/apache/beam/issues/20076): Currently unhandled.
_LOGGER.info('Converting NewType type hint to Any: "%s"', typ)
return typehints.Any
elif typ_module == 'apache_beam.typehints.native_type_compatibility' and \
getattr(typ, "__name__", typ.__origin__.__name__) == 'TypedWindowedValue':
# Need to pass through WindowedValue class so that it can be converted
# to the correct type constraint in Beam
# This is needed to fix https://github.com/apache/beam/issues/33356
pass
elif (typ_module != 'typing') and (typ_module != 'collections.abc'):
# Only translate types from the typing and collections.abc modules.
return typ
Expand Down Expand Up @@ -324,6 +346,10 @@ def convert_to_beam_type(typ):
match=_match_is_exactly_collection,
arity=1,
beam_type=typehints.Collection),
_TypeMapEntry(
match=_match_issubclass(TypedWindowedValue),
arity=1,
beam_type=typehints.WindowedValue),
]

# Find the first matching entry.
Expand Down
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,15 @@ def type_check(self, instance):
repr(self.inner_type),
instance.value.__class__.__name__))

def bind_type_variables(self, bindings):
bound_inner_type = bind_type_variables(self.inner_type, bindings)
if bound_inner_type == self.inner_type:
return self
return WindowedValue[bound_inner_type]

def __repr__(self):
return 'WindowedValue[%s]' % repr(self.inner_type)


class GeneratorHint(IteratorHint):
"""A Generator type hint.
Expand Down

0 comments on commit 2aab9cd

Please sign in to comment.