Skip to content

Commit

Permalink
State manager test, rational range behavior throughout (#44945)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Sep 3, 2024
1 parent 2073513 commit 095d30e
Show file tree
Hide file tree
Showing 6 changed files with 629 additions and 132 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.message

import io.airbyte.protocol.models.v0.AirbyteGlobalState
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.AirbyteStateMessage
import io.airbyte.protocol.models.v0.AirbyteStateStats
import io.airbyte.protocol.models.v0.AirbyteStreamState
import io.airbyte.protocol.models.v0.StreamDescriptor
import jakarta.inject.Singleton

/**
 * Converts a message of type [T] into an output message of type [U].
 *
 * For example, [DefaultMessageConverter] converts the internal [DestinationStateMessage] types to
 * the protocol state messages required by [io.airbyte.cdk.output.OutputConsumer].
 */
interface MessageConverter<T, U> {
/** Returns the [U] representation of [message]. */
fun from(message: T): U
}

/**
 * [MessageConverter] that wraps a [DestinationStateMessage] in a protocol [AirbyteMessage].
 *
 * Stream state becomes a `STREAM`-typed [AirbyteStateMessage] and global state a `GLOBAL`-typed
 * one. Both carry the source record count; the destination record count is mandatory for stream
 * state and optional for global state.
 */
@Singleton
class DefaultMessageConverter : MessageConverter<DestinationStateMessage, AirbyteMessage> {
    /**
     * Builds the protocol message carrying the state described by [message].
     *
     * @throws IllegalStateException if [message] is a [DestinationStreamState] without
     * destination stats.
     */
    override fun from(message: DestinationStateMessage): AirbyteMessage {
        val protocolState: AirbyteStateMessage =
            when (message) {
                is DestinationStreamState -> {
                    val sourceStats =
                        AirbyteStateStats()
                            .withRecordCount(message.sourceStats.recordCount.toDouble())
                    // Stream state is only valid once destination stats are attached.
                    val reportedStats =
                        checkNotNull(message.destinationStats) {
                            "Destination stats must be provided for DestinationStreamState"
                        }
                    val destinationStats =
                        AirbyteStateStats().withRecordCount(reportedStats.recordCount.toDouble())

                    AirbyteStateMessage()
                        .withSourceStats(sourceStats)
                        .withDestinationStats(destinationStats)
                        .withType(AirbyteStateMessage.AirbyteStateType.STREAM)
                        .withStream(toAirbyteStreamState(message.streamState))
                }
                is DestinationGlobalState -> {
                    val sourceStats =
                        AirbyteStateStats()
                            .withRecordCount(message.sourceStats.recordCount.toDouble())
                    // Unlike stream state, missing destination stats are passed through as null.
                    val destinationStats =
                        message.destinationStats?.let { reported ->
                            AirbyteStateStats().withRecordCount(reported.recordCount.toDouble())
                        }
                    val global =
                        AirbyteGlobalState()
                            .withSharedState(message.state)
                            .withStreamStates(message.streamStates.map(::toAirbyteStreamState))

                    AirbyteStateMessage()
                        .withSourceStats(sourceStats)
                        .withDestinationStats(destinationStats)
                        .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL)
                        .withGlobal(global)
                }
            }
        return AirbyteMessage().withState(protocolState)
    }

    /** Maps one internal per-stream state to its protocol form (namespace, name, state blob). */
    private fun toAirbyteStreamState(
        streamState: DestinationStateMessage.StreamState
    ): AirbyteStreamState {
        val source = streamState.stream.descriptor
        val descriptor = StreamDescriptor().withNamespace(source.namespace).withName(source.name)
        return AirbyteStreamState()
            .withStreamDescriptor(descriptor)
            .withStreamState(streamState.state)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DestinationMessageQueueWriter(
private val catalog: DestinationCatalog,
private val messageQueue: MessageQueue<DestinationStream, DestinationRecordWrapped>,
private val streamsManager: StreamsManager,
private val stateManager: StateManager
private val stateManager: StateManager<DestinationStream, DestinationStateMessage>
) : MessageQueueWriter<DestinationMessage> {
/**
* Deserialize and route the message to the appropriate channel.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,29 @@ package io.airbyte.cdk.state

import io.airbyte.cdk.command.DestinationCatalog
import io.airbyte.cdk.command.DestinationStream
import io.airbyte.cdk.message.AirbyteStateMessageFactory
import io.airbyte.cdk.message.DestinationStateMessage
import io.airbyte.cdk.output.OutputConsumer
import io.airbyte.cdk.message.MessageConverter
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.core.util.clhm.ConcurrentLinkedHashMap
import jakarta.inject.Singleton
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicReference
import java.util.function.Consumer

/**
* Interface for state management. Should accept stream and global state, as well as requests to
* flush all data-sufficient states.
*/
interface StateManager {
fun addStreamState(
stream: DestinationStream,
index: Long,
stateMessage: DestinationStateMessage
)
fun addGlobalState(
streamIndexes: List<Pair<DestinationStream, Long>>,
stateMessage: DestinationStateMessage
)
interface StateManager<K, T> {
fun addStreamState(key: K, index: Long, stateMessage: T)
fun addGlobalState(keyIndexes: List<Pair<K, Long>>, stateMessage: T)
fun flushStates()
}

/**
* Destination state manager.
* Message-type agnostic streams state manager.
*
* Accepts global and stream states, and enforces that stream and global state are not mixed.
* Determines ready states by querying the StreamsManager for the state of the record index range
Expand All @@ -44,50 +39,73 @@ interface StateManager {
* TODO: Ensure that state is flushed at the end, and require that all state be flushed before the
* destination can succeed.
*/
@Singleton
class DefaultStateManager(
private val catalog: DestinationCatalog,
private val streamsManager: StreamsManager,
private val stateMessageFactory: AirbyteStateMessageFactory,
private val outputConsumer: OutputConsumer
) : StateManager {
abstract class StreamsStateManager<T, U>() : StateManager<DestinationStream, T> {
private val log = KotlinLogging.logger {}

data class GlobalState(
abstract val catalog: DestinationCatalog
abstract val streamsManager: StreamsManager
abstract val outputFactory: MessageConverter<T, U>
abstract val outputConsumer: Consumer<U>

data class GlobalState<T>(
val streamIndexes: List<Pair<DestinationStream, Long>>,
val stateMessage: DestinationStateMessage
val stateMessage: T
)

private val stateIsGlobal: AtomicReference<Boolean?> = AtomicReference(null)
private val streamStates:
ConcurrentHashMap<DestinationStream, LinkedHashMap<Long, DestinationStateMessage>> =
ConcurrentHashMap<DestinationStream, ConcurrentLinkedHashMap<Long, T>> =
ConcurrentHashMap()
private val globalStates: ConcurrentLinkedQueue<GlobalState> = ConcurrentLinkedQueue()

override fun addStreamState(
stream: DestinationStream,
index: Long,
stateMessage: DestinationStateMessage
) {
if (stateIsGlobal.getAndSet(false) != false) {
private val globalStates: ConcurrentLinkedQueue<GlobalState<T>> = ConcurrentLinkedQueue()

override fun addStreamState(key: DestinationStream, index: Long, stateMessage: T) {
if (stateIsGlobal.updateAndGet { it == true } != false) {
throw IllegalStateException("Global state cannot be mixed with non-global state")
}

val streamStates = streamStates.getOrPut(stream) { LinkedHashMap() }
streamStates[index] = stateMessage
log.info { "Added state for stream: $stream at index: $index" }
streamStates.compute(key) { _, indexToMessage ->
val map =
if (indexToMessage == null) {
// If the map doesn't exist yet, build it.
ConcurrentLinkedHashMap.Builder<Long, T>().maximumWeightedCapacity(1000).build()
} else {
if (indexToMessage.isNotEmpty()) {
// Make sure the messages are coming in order
val oldestIndex = indexToMessage.ascendingKeySet().first()
if (oldestIndex > index) {
throw IllegalStateException(
"State message received out of order ($oldestIndex before $index)"
)
}
}
indexToMessage
}
// Actually add the message
map[index] = stateMessage
map
}

log.info { "Added state for stream: $key at index: $index" }
}

override fun addGlobalState(
streamIndexes: List<Pair<DestinationStream, Long>>,
stateMessage: DestinationStateMessage
) {
if (stateIsGlobal.getAndSet(true) != true) {
// TODO: Is it an error if we don't get all the streams every time?
override fun addGlobalState(keyIndexes: List<Pair<DestinationStream, Long>>, stateMessage: T) {
if (stateIsGlobal.updateAndGet { it != false } != true) {
throw IllegalStateException("Global state cannot be mixed with non-global state")
}

globalStates.add(GlobalState(streamIndexes, stateMessage))
log.info { "Added global state with stream indexes: $streamIndexes" }
val head = globalStates.peek()
if (head != null) {
val keyIndexesByStream = keyIndexes.associate { it.first to it.second }
head.streamIndexes.forEach {
if (keyIndexesByStream[it.first]!! < it.second) {
throw IllegalStateException("Global state message received out of order")
}
}
}

globalStates.add(GlobalState(keyIndexes, stateMessage))
log.info { "Added global state with stream indexes: $keyIndexes" }
}

override fun flushStates() {
Expand All @@ -105,19 +123,19 @@ class DefaultStateManager(
}

private fun flushGlobalStates() {
if (globalStates.isEmpty()) {
return
}

val head = globalStates.peek()
val allStreamsPersisted =
head.streamIndexes.all { (stream, index) ->
streamsManager.getManager(stream).areRecordsPersistedUntil(index)
while (!globalStates.isEmpty()) {
val head = globalStates.peek()
val allStreamsPersisted =
head.streamIndexes.all { (stream, index) ->
streamsManager.getManager(stream).areRecordsPersistedUntil(index)
}
if (allStreamsPersisted) {
globalStates.poll()
val outMessage = outputFactory.from(head.stateMessage)
outputConsumer.accept(outMessage)
} else {
break
}
if (allStreamsPersisted) {
globalStates.poll()
val outMessage = stateMessageFactory.fromDestinationStateMessage(head.stateMessage)
outputConsumer.accept(outMessage)
}
}

Expand All @@ -131,7 +149,7 @@ class DefaultStateManager(
streamStates.remove(index)
?: throw IllegalStateException("State not found for index: $index")
log.info { "Flushing state for stream: $stream at index: $index" }
val outMessage = stateMessageFactory.fromDestinationStateMessage(stateMessage)
val outMessage = outputFactory.from(stateMessage)
outputConsumer.accept(outMessage)
} else {
break
Expand All @@ -140,3 +158,11 @@ class DefaultStateManager(
}
}
}

/**
 * Injectable singleton [StreamsStateManager] that accepts [DestinationStateMessage]s and emits
 * protocol [AirbyteMessage]s.
 *
 * It adds no behavior of its own: each constructor parameter overrides the corresponding property
 * declared abstract on the base class.
 */
@Singleton
class DefaultStateManager(
override val catalog: DestinationCatalog,
override val streamsManager: StreamsManager,
override val outputFactory: MessageConverter<DestinationStateMessage, AirbyteMessage>,
override val outputConsumer: Consumer<AirbyteMessage>
) : StreamsStateManager<DestinationStateMessage, AirbyteMessage>()
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,23 @@ class DefaultStreamManager(
rangesState[batch.batch.state]
?: throw IllegalArgumentException("Invalid batch state: ${batch.batch.state}")

stateRanges.addAll(batch.ranges)
// Force the ranges to overlap at their endpoints, in order to work around
// the behavior of `.encloses`, which otherwise would not consider adjacent ranges as
// contiguous.
// This ensures that a state message received at, e.g., index 10 (after messages 0..9 have
// been received), will pass `{'[0..5]','[6..9]'}.encloses('[0..10)')`.
val expanded =
batch.ranges.asRanges().map { it.span(Range.singleton(it.upperEndpoint() + 1)) }

stateRanges.addAll(expanded)
log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" }
}

/** True if all records in [0, index) have reached the given state. */
private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean {

val completeRanges = rangesState[state]!!
return completeRanges.encloses(Range.closed(0L, index - 1))
return completeRanges.encloses(Range.closedOpen(0L, index))
}

/** True if all records have associated [Batch.State.COMPLETE] batches. */
Expand Down
Loading

0 comments on commit 095d30e

Please sign in to comment.