Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(airbyte-cdk): add global_state => per_partition transformation #45122

Merged
merged 6 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@

from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException


class CursorFactory:
Expand Down Expand Up @@ -45,6 +43,7 @@ class PerPartitionCursor(DeclarativeCursor):
_NO_CURSOR_STATE: Mapping[str, Any] = {}
_KEY = 0
_VALUE = 1
_state_to_migrate_from: Mapping[str, Any] = {}

def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter):
self._cursor_factory = cursor_factory
Expand All @@ -57,7 +56,8 @@ def stream_slices(self) -> Iterable[StreamSlice]:
for partition in slices:
cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition))
if not cursor:
cursor = self._create_cursor(self._NO_CURSOR_STATE)
partition_state = self._state_to_migrate_from if self._state_to_migrate_from else self._NO_CURSOR_STATE
cursor = self._create_cursor(partition_state)
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor

for cursor_slice in cursor.stream_slices():
Expand Down Expand Up @@ -97,15 +97,13 @@ def set_initial_state(self, stream_state: StreamState) -> None:
return

if "states" not in stream_state:
raise AirbyteTracedException(
internal_message=f"Could not sync parse the following state: {stream_state}",
message="The state for is format invalid. Validate that the migration steps included a reset and that it was performed "
"properly. Otherwise, please contact Airbyte support.",
failure_type=FailureType.config_error,
)
# We assume that `stream_state` is in a global format that can be applied to all partitions.
# Example: {"global_state_format_key": "global_state_format_value"}
self._state_to_migrate_from = stream_state

for state in stream_state["states"]:
self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"])
else:
for state in stream_state["states"]:
self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"])

# Set parent state for partition routers based on parent streams
self._partition_router.set_initial_state(stream_state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from unittest.mock import Mock

import pytest
from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, PerPartitionKeySerializer, StreamSlice
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.types import Record
from airbyte_cdk.utils import AirbyteTracedException

PARTITION = {
"partition_key string": "partition value",
Expand Down Expand Up @@ -519,10 +517,37 @@ def test_get_stream_state_includes_parent_state(mocked_cursor_factory, mocked_pa
assert stream_state == expected_state


def test_given_invalid_state_when_set_initial_state_then_raise_config_error(mocked_cursor_factory, mocked_partition_router) -> None:
cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)

with pytest.raises(AirbyteTracedException) as exception:
cursor.set_initial_state({"invalid_state": 1})
def test_per_partition_state_when_set_initial_global_state(mocked_cursor_factory, mocked_partition_router) -> None:
first_partition = {"first_partition_key": "first_partition_value"}
second_partition = {"second_partition_key": "second_partition_value"}
global_state = {"global_state_format_key": "global_state_format_value"}

assert exception.value.failure_type == FailureType.config_error
mocked_partition_router.stream_slices.return_value = [
StreamSlice(partition=first_partition, cursor_slice={}),
StreamSlice(partition=second_partition, cursor_slice={}),
]
mocked_cursor_factory.create.side_effect = [
MockedCursorBuilder().with_stream_state(global_state).build(),
MockedCursorBuilder().with_stream_state(global_state).build(),
]
cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
global_state = {"global_state_format_key": "global_state_format_value"}
cursor.set_initial_state(global_state)
assert cursor._state_to_migrate_from == global_state
list(cursor.stream_slices())
assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_count == 1
assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_args[0] == (
{"global_state_format_key": "global_state_format_value"},
)
assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_count == 1
assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_args[0] == (
{"global_state_format_key": "global_state_format_value"},
)
expected_state = [
{"cursor": {"global_state_format_key": "global_state_format_value"}, "partition": {"first_partition_key": "first_partition_value"}},
{
"cursor": {"global_state_format_key": "global_state_format_value"},
"partition": {"second_partition_key": "second_partition_value"},
},
]
assert cursor.get_stream_state()["states"] == expected_state
Loading