diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index f615290cccf9..76a3c324d1ec 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -329,6 +329,7 @@ def encode(self, value): DATACLASS_TYPE = 101 NAMED_TUPLE_TYPE = 102 ENUM_TYPE = 103 +NESTED_STATE_TYPE = 104 # Types that can be encoded as iterables, but are not literally # lists, etc. due to being lazy. The actual type is not preserved @@ -455,6 +456,17 @@ def encode_special_deterministic(self, value, stream): self.encode_type(type(value), stream) # Enum values can be of any type. self.encode_to_stream(value.value, stream, True) + elif hasattr(value, "__getstate__"): + if not hasattr(value, "__setstate__"): + raise TypeError( + "Unable to deterministically encode '%s' of type '%s', " + "for the input of '%s'. The object defines __getstate__ but not " + "__setstate__." % + (value, type(value), self.requires_deterministic_step_label)) + stream.write_byte(NESTED_STATE_TYPE) + self.encode_type(type(value), stream) + state_value = value.__getstate__() + self.encode_to_stream(state_value, stream, True) else: raise TypeError( "Unable to deterministically encode '%s' of type '%s', " @@ -510,6 +522,12 @@ def decode_from_stream(self, stream, nested): elif t == ENUM_TYPE: cls = self.decode_type(stream) return cls(self.decode_from_stream(stream, True)) + elif t == NESTED_STATE_TYPE: + cls = self.decode_type(stream) + state = self.decode_from_stream(stream, True) + value = cls.__new__(cls) + value.__setstate__(state) + return value elif t == UNKNOWN_TYPE: return self.fallback_coder_impl.decode_from_stream(stream, nested) else: diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 17825a9bf2c7..209ee5596789 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -69,6 +69,22 @@ class MyEnum(enum.Enum): MyFlag = enum.Flag('MyFlag', 'F1 F2 F3') # pylint: disable=too-many-function-args +class DefinesGetState: + def __init__(self, value): + self.value = value + + def __getstate__(self): + return self.value + + def __eq__(self, other): + return type(other) is type(self) and other.value == self.value + + +class DefinesGetAndSetState(DefinesGetState): + def __setstate__(self, value): + self.value = value + + # Defined out of line for picklability. class CustomCoder(coders.Coder): def encode(self, x): @@ -236,6 +252,13 @@ def test_deterministic_coder(self): self.check_coder(deterministic_coder, list(MyIntFlag)) self.check_coder(deterministic_coder, list(MyFlag)) + self.check_coder( + deterministic_coder, + [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))]) + + with self.assertRaises(TypeError): + self.check_coder(deterministic_coder, DefinesGetState(1)) + def test_dill_coder(self): cell_value = (lambda x: lambda: x)(0).__closure__[0] self.check_coder(coders.DillCoder(), 'a', 1, cell_value)