Skip to content

Commit

Permalink
Use __getstate__ & __setstate__ for deterministic coding
Browse files Browse the repository at this point in the history
`__getstate__` and `__setstate__` together form one of several Python protocols
that can be used to implement pickling for a custom type. If the return
value from calling `obj.__getstate__()` on a user-provided object is
deterministic, then we can use it to deterministically code the object.

This provides a low-boilerplate and robust way to make a custom object
suitable for use as a key in Beam transforms.

Note that we don't need to worry about these protocols in the
non-deterministic case, because the fall-back coder already handles
these via pickle.
  • Loading branch information
shoyer committed Apr 29, 2021
1 parent 7418c84 commit b36b8e5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
18 changes: 18 additions & 0 deletions sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def encode(self, value):
DATACLASS_TYPE = 101
NAMED_TUPLE_TYPE = 102
ENUM_TYPE = 103
NESTED_STATE_TYPE = 104

# Types that can be encoded as iterables, but are not literally
# lists, etc. due to being lazy. The actual type is not preserved
Expand Down Expand Up @@ -455,6 +456,17 @@ def encode_special_deterministic(self, value, stream):
self.encode_type(type(value), stream)
# Enum values can be of any type.
self.encode_to_stream(value.value, stream, True)
elif hasattr(value, "__getstate__"):
if not hasattr(value, "__setstate__"):
raise TypeError(
"Unable to deterministically encode '%s' of type '%s', "
"for the input of '%s'. The object defines __getstate__ but not "
"__setstate__." %
(value, type(value), self.requires_deterministic_step_label))
stream.write_byte(NESTED_STATE_TYPE)
self.encode_type(type(value), stream)
state_value = value.__getstate__()
self.encode_to_stream(state_value, stream, True)
else:
raise TypeError(
"Unable to deterministically encode '%s' of type '%s', "
Expand Down Expand Up @@ -510,6 +522,12 @@ def decode_from_stream(self, stream, nested):
elif t == ENUM_TYPE:
cls = self.decode_type(stream)
return cls(self.decode_from_stream(stream, True))
elif t == NESTED_STATE_TYPE:
cls = self.decode_type(stream)
state = self.decode_from_stream(stream, True)
value = cls.__new__(cls)
value.__setstate__(state)
return value
elif t == UNKNOWN_TYPE:
return self.fallback_coder_impl.decode_from_stream(stream, nested)
else:
Expand Down
23 changes: 23 additions & 0 deletions sdks/python/apache_beam/coders/coders_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ class MyEnum(enum.Enum):
MyFlag = enum.Flag('MyFlag', 'F1 F2 F3') # pylint: disable=too-many-function-args


class DefinesGetState:
  """Test fixture whose state is exposed through ``__getstate__`` only.

  The class deliberately has no ``__setstate__``, so it models an object
  that can report its state but cannot be rebuilt from it.

  Equality is by exact type and wrapped value. ``__hash__`` is defined
  alongside ``__eq__`` so that equal instances hash equally; without it,
  defining ``__eq__`` alone would make instances unhashable.
  """
  def __init__(self, value):
    # ``value`` is the entire state of the object.
    self.value = value

  def __getstate__(self):
    """Return the wrapped value as the object's state."""
    return self.value

  def __eq__(self, other):
    """Return True only for the exact same type holding an equal value."""
    # Exact type match (not isinstance) keeps subclasses unequal to the base.
    return type(other) is type(self) and other.value == self.value

  def __hash__(self):
    """Hash consistently with ``__eq__``: by exact type and wrapped value.

    Raises:
      TypeError: if the wrapped value is itself unhashable.
    """
    return hash((type(self), self.value))

  def __repr__(self):
    """Return a readable form, e.g. ``DefinesGetState(1)``."""
    return '%s(%r)' % (type(self).__name__, self.value)


class DefinesGetAndSetState(DefinesGetState):
  """Variant of ``DefinesGetState`` that can also restore its state.

  Adding ``__setstate__`` completes the get/set state pair, so an instance
  can be rebuilt from the value returned by ``__getstate__``.
  """
  def __setstate__(self, value):
    """Restore the instance from a state produced by ``__getstate__``."""
    # The state is the wrapped value itself, so restoring is one assignment.
    self.value = value


# Defined out of line for picklability.
class CustomCoder(coders.Coder):
def encode(self, x):
Expand Down Expand Up @@ -236,6 +252,13 @@ def test_deterministic_coder(self):
self.check_coder(deterministic_coder, list(MyIntFlag))
self.check_coder(deterministic_coder, list(MyFlag))

self.check_coder(
deterministic_coder,
[DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))])

with self.assertRaises(TypeError):
self.check_coder(deterministic_coder, DefinesGetState(1))

def test_dill_coder(self):
cell_value = (lambda x: lambda: x)(0).__closure__[0]
self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
Expand Down

0 comments on commit b36b8e5

Please sign in to comment.