diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/AirbyteStateMessageFactory.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/AirbyteStateMessageFactory.kt deleted file mode 100644 index ed88daa07553..000000000000 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/AirbyteStateMessageFactory.kt +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2024 Airbyte, Inc., all rights reserved. - */ - -package io.airbyte.cdk.message - -import io.airbyte.protocol.models.v0.AirbyteGlobalState -import io.airbyte.protocol.models.v0.AirbyteStateMessage -import io.airbyte.protocol.models.v0.AirbyteStateStats -import io.airbyte.protocol.models.v0.AirbyteStreamState -import io.airbyte.protocol.models.v0.StreamDescriptor -import jakarta.inject.Singleton - -/** - * Converts the internal @[DestinationStateMessage] case class to the Protocol state messages - * required by @[io.airbyte.cdk.output.OutputConsumer] - */ -interface AirbyteStateMessageFactory { - fun fromDestinationStateMessage(message: DestinationStateMessage): AirbyteStateMessage -} - -@Singleton -class DefaultAirbyteStateMessageFactory : AirbyteStateMessageFactory { - override fun fromDestinationStateMessage( - message: DestinationStateMessage - ): AirbyteStateMessage { - return when (message) { - is DestinationStreamState -> - AirbyteStateMessage() - .withSourceStats( - AirbyteStateStats() - .withRecordCount(message.sourceStats.recordCount.toDouble()) - ) - .withDestinationStats( - message.destinationStats?.let { - AirbyteStateStats().withRecordCount(it.recordCount.toDouble()) - } - ?: throw IllegalStateException( - "Destination stats must be provided for DestinationStreamState" - ) - ) - .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(fromStreamState(message.streamState)) - is DestinationGlobalState -> - AirbyteStateMessage() - .withSourceStats( - AirbyteStateStats() - .withRecordCount(message.sourceStats.recordCount.toDouble()) - ) - .withDestinationStats( - message.destinationStats?.let { - AirbyteStateStats().withRecordCount(it.recordCount.toDouble()) - } - ) - .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) - .withGlobal( - AirbyteGlobalState() - .withSharedState(message.state) - .withStreamStates(message.streamStates.map { fromStreamState(it) }) - ) - } - } - - private fun fromStreamState( - streamState: DestinationStateMessage.StreamState - ): AirbyteStreamState { - return AirbyteStreamState() - .withStreamDescriptor( - StreamDescriptor() - .withNamespace(streamState.stream.descriptor.namespace) - .withName(streamState.stream.descriptor.name) - ) - .withStreamState(streamState.state) - } -} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt new file mode 100644 index 000000000000..97d40900b210 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.message + +import io.airbyte.protocol.models.v0.AirbyteGlobalState +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStateStats +import io.airbyte.protocol.models.v0.AirbyteStreamState +import io.airbyte.protocol.models.v0.StreamDescriptor +import jakarta.inject.Singleton + +/** + * Converts the internal @[DestinationStateMessage] case class to the Protocol state messages + * required by @[io.airbyte.cdk.output.OutputConsumer] + */ +interface MessageConverter { + fun from(message: T): U +} + +@Singleton +class DefaultMessageConverter : MessageConverter { + override fun from(message: DestinationStateMessage): AirbyteMessage { + val state = + when (message) { + is DestinationStreamState -> + AirbyteStateMessage() + .withSourceStats( + AirbyteStateStats() + .withRecordCount(message.sourceStats.recordCount.toDouble()) + ) + .withDestinationStats( + message.destinationStats?.let { + AirbyteStateStats().withRecordCount(it.recordCount.toDouble()) + } + ?: throw IllegalStateException( + "Destination stats must be provided for DestinationStreamState" + ) + ) + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(fromStreamState(message.streamState)) + is DestinationGlobalState -> + AirbyteStateMessage() + .withSourceStats( + AirbyteStateStats() + .withRecordCount(message.sourceStats.recordCount.toDouble()) + ) + .withDestinationStats( + message.destinationStats?.let { + AirbyteStateStats().withRecordCount(it.recordCount.toDouble()) + } + ) + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal( + AirbyteGlobalState() + .withSharedState(message.state) + .withStreamStates(message.streamStates.map { fromStreamState(it) }) + ) + } + return AirbyteMessage().withState(state) + } + + private fun fromStreamState( + streamState: DestinationStateMessage.StreamState + ): AirbyteStreamState { + return AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(streamState.stream.descriptor.namespace) + .withName(streamState.stream.descriptor.name) + ) + .withStreamState(streamState.state) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt index 898ce683ba2f..50c9f637ef88 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageQueueWriter.kt @@ -31,7 +31,7 @@ class DestinationMessageQueueWriter( private val catalog: DestinationCatalog, private val messageQueue: MessageQueue, private val streamsManager: StreamsManager, - private val stateManager: StateManager + private val stateManager: StateManager ) : MessageQueueWriter { /** * Deserialize and route the message to the appropriate channel. diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt index e6c47ecd6dd5..9c900b4e4379 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StateManager.kt @@ -6,34 +6,29 @@ package io.airbyte.cdk.state import io.airbyte.cdk.command.DestinationCatalog import io.airbyte.cdk.command.DestinationStream -import io.airbyte.cdk.message.AirbyteStateMessageFactory import io.airbyte.cdk.message.DestinationStateMessage -import io.airbyte.cdk.output.OutputConsumer +import io.airbyte.cdk.message.MessageConverter +import io.airbyte.protocol.models.v0.AirbyteMessage import io.github.oshai.kotlinlogging.KotlinLogging +import io.micronaut.core.util.clhm.ConcurrentLinkedHashMap import jakarta.inject.Singleton import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicReference +import java.util.function.Consumer /** * Interface for state management. Should accept stream and global state, as well as requests to * flush all data-sufficient states. */ -interface StateManager { - fun addStreamState( - stream: DestinationStream, - index: Long, - stateMessage: DestinationStateMessage - ) - fun addGlobalState( - streamIndexes: List>, - stateMessage: DestinationStateMessage - ) +interface StateManager { + fun addStreamState(key: K, index: Long, stateMessage: T) + fun addGlobalState(keyIndexes: List>, stateMessage: T) fun flushStates() } /** - * Destination state manager. + * Message-type agnostic streams state manager. * * Accepts global and stream states, and enforces that stream and global state are not mixed. * Determines ready states by querying the StreamsManager for the state of the record index range @@ -44,50 +39,73 @@ interface StateManager { * TODO: Ensure that state is flushed at the end, and require that all state be flushed before the * destination can succeed. */ -@Singleton -class DefaultStateManager( - private val catalog: DestinationCatalog, - private val streamsManager: StreamsManager, - private val stateMessageFactory: AirbyteStateMessageFactory, - private val outputConsumer: OutputConsumer -) : StateManager { +abstract class StreamsStateManager() : StateManager { private val log = KotlinLogging.logger {} - data class GlobalState( + abstract val catalog: DestinationCatalog + abstract val streamsManager: StreamsManager + abstract val outputFactory: MessageConverter + abstract val outputConsumer: Consumer + + data class GlobalState( val streamIndexes: List>, - val stateMessage: DestinationStateMessage + val stateMessage: T ) private val stateIsGlobal: AtomicReference = AtomicReference(null) private val streamStates: - ConcurrentHashMap> = + ConcurrentHashMap> = ConcurrentHashMap() - private val globalStates: ConcurrentLinkedQueue = ConcurrentLinkedQueue() - - override fun addStreamState( - stream: DestinationStream, - index: Long, - stateMessage: DestinationStateMessage - ) { - if (stateIsGlobal.getAndSet(false) != false) { + private val globalStates: ConcurrentLinkedQueue> = ConcurrentLinkedQueue() + + override fun addStreamState(key: DestinationStream, index: Long, stateMessage: T) { + if (stateIsGlobal.updateAndGet { it == true } != false) { throw IllegalStateException("Global state cannot be mixed with non-global state") } - val streamStates = streamStates.getOrPut(stream) { LinkedHashMap() } - streamStates[index] = stateMessage - log.info { "Added state for stream: $stream at index: $index" } + streamStates.compute(key) { _, indexToMessage -> + val map = + if (indexToMessage == null) { + // If the map doesn't exist yet, build it. + ConcurrentLinkedHashMap.Builder().maximumWeightedCapacity(1000).build() + } else { + if (indexToMessage.isNotEmpty()) { + // Make sure the messages are coming in order + val oldestIndex = indexToMessage.ascendingKeySet().first() + if (oldestIndex > index) { + throw IllegalStateException( + "State message received out of order ($oldestIndex before $index)" + ) + } + } + indexToMessage + } + // Actually add the message + map[index] = stateMessage + map + } + + log.info { "Added state for stream: $key at index: $index" } } - override fun addGlobalState( - streamIndexes: List>, - stateMessage: DestinationStateMessage - ) { - if (stateIsGlobal.getAndSet(true) != true) { + // TODO: Is it an error if we don't get all the streams every time? + override fun addGlobalState(keyIndexes: List>, stateMessage: T) { + if (stateIsGlobal.updateAndGet { it != false } != true) { throw IllegalStateException("Global state cannot be mixed with non-global state") } - globalStates.add(GlobalState(streamIndexes, stateMessage)) - log.info { "Added global state with stream indexes: $streamIndexes" } + val head = globalStates.peek() + if (head != null) { + val keyIndexesByStream = keyIndexes.associate { it.first to it.second } + head.streamIndexes.forEach { + if (keyIndexesByStream[it.first]!! < it.second) { + throw IllegalStateException("Global state message received out of order") + } + } + } + + globalStates.add(GlobalState(keyIndexes, stateMessage)) + log.info { "Added global state with stream indexes: $keyIndexes" } } override fun flushStates() { @@ -105,19 +123,19 @@ class DefaultStateManager( } private fun flushGlobalStates() { - if (globalStates.isEmpty()) { - return - } - - val head = globalStates.peek() - val allStreamsPersisted = - head.streamIndexes.all { (stream, index) -> - streamsManager.getManager(stream).areRecordsPersistedUntil(index) + while (!globalStates.isEmpty()) { + val head = globalStates.peek() + val allStreamsPersisted = + head.streamIndexes.all { (stream, index) -> + streamsManager.getManager(stream).areRecordsPersistedUntil(index) + } + if (allStreamsPersisted) { + globalStates.poll() + val outMessage = outputFactory.from(head.stateMessage) + outputConsumer.accept(outMessage) + } else { + break } - if (allStreamsPersisted) { - globalStates.poll() - val outMessage = stateMessageFactory.fromDestinationStateMessage(head.stateMessage) - outputConsumer.accept(outMessage) } } @@ -131,7 +149,7 @@ class DefaultStateManager( streamStates.remove(index) ?: throw IllegalStateException("State not found for index: $index") log.info { "Flushing state for stream: $stream at index: $index" } - val outMessage = stateMessageFactory.fromDestinationStateMessage(stateMessage) + val outMessage = outputFactory.from(stateMessage) outputConsumer.accept(outMessage) } else { break @@ -140,3 +158,11 @@ class DefaultStateManager( } } } + +@Singleton +class DefaultStateManager( + override val catalog: DestinationCatalog, + override val streamsManager: StreamsManager, + override val outputFactory: MessageConverter, + override val outputConsumer: Consumer +) : StreamsStateManager() diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt index b0faf0188d08..4b74df5bf3ba 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/StreamManager.kt @@ -100,7 +100,15 @@ class DefaultStreamManager( rangesState[batch.batch.state] ?: throw IllegalArgumentException("Invalid batch state: ${batch.batch.state}") - stateRanges.addAll(batch.ranges) + // Force the ranges to overlap at their endpoints, in order to work around + // the behavior of `.encloses`, which otherwise would not consider adjacent ranges as + // contiguous. + // This ensures that a state message received at eg, index 10 (after messages 0..9 have + // been received), will pass `{'[0..5]','[6..9]'}.encloses('[0..10)')`. + val expanded = + batch.ranges.asRanges().map { it.span(Range.singleton(it.upperEndpoint() + 1)) } + + stateRanges.addAll(expanded) log.info { "Updated ranges for $stream[${batch.batch.state}]: $stateRanges" } } @@ -108,7 +116,7 @@ class DefaultStreamManager( private fun isProcessingCompleteForState(index: Long, state: Batch.State): Boolean { val completeRanges = rangesState[state]!! - return completeRanges.encloses(Range.closed(0L, index - 1)) + return completeRanges.encloses(Range.closedOpen(0L, index)) } /** True if all records have associated [Batch.State.COMPLETE] batches. */ diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt new file mode 100644 index 000000000000..5c34cd8446e8 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StateManagerTest.kt @@ -0,0 +1,462 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import com.google.common.collect.Range +import com.google.common.collect.RangeSet +import com.google.common.collect.TreeRangeSet +import io.airbyte.cdk.command.DestinationCatalog +import io.airbyte.cdk.command.DestinationCatalogFactory +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.message.Batch +import io.airbyte.cdk.message.BatchEnvelope +import io.airbyte.cdk.message.MessageConverter +import io.micronaut.context.annotation.Factory +import io.micronaut.context.annotation.Prototype +import io.micronaut.context.annotation.Replaces +import io.micronaut.context.annotation.Requires +import io.micronaut.test.extensions.junit5.annotation.MicronautTest +import jakarta.inject.Inject +import jakarta.inject.Singleton +import java.util.function.Consumer +import java.util.stream.Stream +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource + +@MicronautTest +class StateManagerTest { + @Inject lateinit var stateManager: TestStateManager + + companion object { + val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1")) + val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2")) + } + + @Factory + @Replaces(factory = DestinationCatalogFactory::class) + class MockCatalogFactory { + @Singleton + @Requires(env = ["test"]) + fun make(): DestinationCatalog { + return DestinationCatalog(streams = listOf(stream1, stream2)) + } + } + + /** + * Test state messages. + * + * StateIn: What is passed to the manager. StateOut: What is sent from the manager to the output + * consumer. + */ + sealed class MockStateIn + data class MockStreamStateIn(val stream: DestinationStream, val payload: Int) : MockStateIn() + data class MockGlobalStateIn(val payload: Int) : MockStateIn() + + sealed class MockStateOut + data class MockStreamStateOut(val stream: DestinationStream, val payload: String) : + MockStateOut() + data class MockGlobalStateOut(val payload: String) : MockStateOut() + + @Singleton + class MockStateMessageFactory : MessageConverter { + override fun from(message: MockStateIn): MockStateOut { + return when (message) { + is MockStreamStateIn -> + MockStreamStateOut(message.stream, message.payload.toString()) + is MockGlobalStateIn -> MockGlobalStateOut(message.payload.toString()) + } + } + } + + @Prototype + class MockOutputConsumer : Consumer { + val collectedStreamOutput = mutableMapOf>() + val collectedGlobalOutput = mutableListOf() + override fun accept(t: MockStateOut) { + when (t) { + is MockStreamStateOut -> + collectedStreamOutput.getOrPut(t.stream) { mutableListOf() }.add(t.payload) + is MockGlobalStateOut -> collectedGlobalOutput.add(t.payload) + } + } + } + + /** + * The only thing we really need is `areRecordsPersistedUntil`. (Technically we're emulating the + * @[StreamManager] behavior here, since the state manager doesn't actually know what ranges are + * closed, but less than that would make the test unrealistic.) + */ + class MockStreamManager : StreamManager { + var persistedRanges: RangeSet = TreeRangeSet.create() + + override fun countRecordIn(sizeBytes: Long): Long { + throw NotImplementedError() + } + + override fun markCheckpoint(): Pair { + throw NotImplementedError() + } + + override fun updateBatchState(batch: BatchEnvelope) { + throw NotImplementedError() + } + + override fun isBatchProcessingComplete(): Boolean { + throw NotImplementedError() + } + + override fun areRecordsPersistedUntil(index: Long): Boolean { + return persistedRanges.encloses(Range.closedOpen(0, index)) + } + + override fun markClosed() { + throw NotImplementedError() + } + + override fun streamIsClosed(): Boolean { + throw NotImplementedError() + } + + override suspend fun awaitStreamClosed() { + throw NotImplementedError() + } + } + + @Prototype + class MockStreamsManager(catalog: DestinationCatalog) : StreamsManager { + private val mockManagers = catalog.streams.associateWith { MockStreamManager() } + + fun addPersistedRanges(stream: DestinationStream, ranges: List>) { + mockManagers[stream]!!.persistedRanges.addAll(ranges) + } + + override fun getManager(stream: DestinationStream): StreamManager { + return mockManagers[stream] + ?: throw IllegalArgumentException("Stream not found: $stream") + } + + override suspend fun awaitAllStreamsComplete() { + throw NotImplementedError() + } + } + + @Prototype + class TestStateManager( + override val catalog: DestinationCatalog, + override val streamsManager: MockStreamsManager, + override val outputFactory: MessageConverter, + override val outputConsumer: MockOutputConsumer + ) : StreamsStateManager() + + sealed class TestEvent + data class TestStreamMessage(val stream: DestinationStream, val index: Long, val message: Int) : + TestEvent() { + fun toMockStateIn() = MockStreamStateIn(stream, message) + } + data class TestGlobalMessage( + val streamIndexes: List>, + val message: Int + ) : TestEvent() { + fun toMockStateIn() = MockGlobalStateIn(message) + } + data class FlushPoint( + val persistedRanges: Map>> = mapOf() + ) : TestEvent() + + data class TestCase( + val name: String, + val events: List, + // Order matters, but only per stream + val expectedStreamOutput: Map> = mapOf(), + val expectedGlobalOutput: List = listOf(), + val expectedException: Class? = null + ) + + class StateManagerTestArgumentsProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream { + return listOf( + TestCase( + name = + "One stream, two stream messages, flush all if all ranges are persisted", + events = + listOf( + TestStreamMessage(stream1, 10L, 1), + TestStreamMessage(stream1, 20L, 2), + FlushPoint( + persistedRanges = + mapOf(stream1 to listOf(Range.closed(0L, 20L))) + ) + ), + expectedStreamOutput = mapOf(stream1 to listOf("1", "2")) + ), + TestCase( + name = "One stream, two messages, flush only the first", + events = + listOf( + TestStreamMessage(stream1, 10L, 1), + TestStreamMessage(stream1, 20L, 2), + FlushPoint( + persistedRanges = + mapOf(stream1 to listOf(Range.closed(0L, 10L))) + ) + ), + expectedStreamOutput = mapOf(stream1 to listOf("1")) + ), + TestCase( + name = "Two streams, two messages each, flush all", + events = + listOf( + TestStreamMessage(stream1, 10L, 11), + TestStreamMessage(stream2, 30L, 21), + TestStreamMessage(stream1, 20L, 12), + TestStreamMessage(stream2, 40L, 22), + FlushPoint( + persistedRanges = + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + stream2 to listOf(Range.closed(0L, 40L)) + ) + ) + ), + expectedStreamOutput = + mapOf(stream1 to listOf("11", "12"), stream2 to listOf("22", "21")) + ), + TestCase( + name = "One stream, only later range persisted", + events = + listOf( + TestStreamMessage(stream1, 10L, 1), + TestStreamMessage(stream1, 20L, 2), + FlushPoint( + persistedRanges = + mapOf(stream1 to listOf(Range.closed(10L, 20L))) + ) + ), + expectedStreamOutput = mapOf() + ), + TestCase( + name = "One stream, out of order (should fail)", + events = + listOf( + TestStreamMessage(stream1, 20L, 2), + TestStreamMessage(stream1, 10L, 1), + FlushPoint( + persistedRanges = + mapOf(stream1 to listOf(Range.closed(0L, 20L))) + ) + ), + expectedException = IllegalStateException::class.java + ), + TestCase( + name = "Global state, two messages, flush all", + events = + listOf( + TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), + TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), + FlushPoint( + persistedRanges = + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + stream2 to listOf(Range.closed(0L, 30L)) + ) + ) + ), + expectedGlobalOutput = listOf("1", "2") + ), + TestCase( + name = "Global state, two messages, range only covers the first", + events = + listOf( + TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), + TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), + FlushPoint( + persistedRanges = + mapOf( + stream1 to listOf(Range.closed(0L, 10L)), + stream2 to listOf(Range.closed(0L, 20L)) + ) + ) + ), + expectedGlobalOutput = listOf("1") + ), + TestCase( + name = + "Global state, two messages, where the range only covers *one stream*", + events = + listOf( + TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), + TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), + FlushPoint( + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + stream2 to listOf(Range.closed(0L, 20L)) + ) + ) + ), + expectedGlobalOutput = listOf("1") + ), + TestCase( + name = "Global state, out of order (should fail)", + events = + listOf( + TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), + TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), + FlushPoint( + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + stream2 to listOf(Range.closed(0L, 30L)) + ) + ), + ), + expectedException = IllegalStateException::class.java + ), + TestCase( + name = "Mixed: first stream state, then global (should fail)", + events = + listOf( + TestStreamMessage(stream1, 10L, 1), + TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), + FlushPoint( + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + stream2 to listOf(Range.closed(0L, 30L)) + ) + ) + ), + expectedException = IllegalStateException::class.java + ), + TestCase( + name = "Mixed: first global, then stream state (should fail)", + events = + listOf( + TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), + TestStreamMessage(stream1, 20L, 2), + FlushPoint( + persistedRanges = + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + stream2 to listOf(Range.closed(0L, 30L)) + ) + ) + ), + expectedException = IllegalStateException::class.java + ), + TestCase( + name = "No messages, just a flush", + events = listOf(FlushPoint()), + expectedStreamOutput = mapOf(), + expectedGlobalOutput = listOf() + ), + TestCase( + name = "Two stream messages, flush against empty ranges", + events = + listOf( + TestStreamMessage(stream1, 10L, 1), + TestStreamMessage(stream1, 20L, 2), + FlushPoint() + ), + expectedStreamOutput = mapOf() + ), + TestCase( + name = "Stream state, multiple flush points", + events = + listOf( + TestStreamMessage(stream1, 10L, 1), + FlushPoint(), + TestStreamMessage(stream1, 20L, 2), + FlushPoint(mapOf(stream1 to listOf(Range.closed(0L, 10L)))), + TestStreamMessage(stream1, 30L, 3), + FlushPoint(mapOf(stream1 to listOf(Range.closed(10L, 30L)))) + ), + expectedStreamOutput = mapOf(stream1 to listOf("1", "2", "3")) + ), + TestCase( + name = "Global state, multiple flush points, no output", + events = + listOf( + TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), + FlushPoint(), + TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), + FlushPoint( + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + ) + ), + TestGlobalMessage(listOf(stream1 to 30L, stream2 to 40L), 3), + FlushPoint(mapOf(stream2 to listOf(Range.closed(20L, 30L)))) + ), + expectedGlobalOutput = listOf() + ), + TestCase( + name = "Global state, multiple flush points, no output until end", + events = + listOf( + TestGlobalMessage(listOf(stream1 to 10L, stream2 to 20L), 1), + FlushPoint(), + TestGlobalMessage(listOf(stream1 to 20L, stream2 to 30L), 2), + FlushPoint( + mapOf( + stream1 to listOf(Range.closed(0L, 20L)), + ) + ), + TestGlobalMessage(listOf(stream1 to 30L, stream2 to 40L), 3), + FlushPoint( + mapOf( + stream1 to listOf(Range.closed(20L, 30L)), + stream2 to listOf(Range.closed(0L, 40L)) + ) + ) + ), + expectedGlobalOutput = listOf("1", "2", "3") + ) + ) + .stream() + .map { Arguments.of(it) } + } + } + + @ParameterizedTest + @ArgumentsSource(StateManagerTestArgumentsProvider::class) + fun testAddingAndFlushingState(testCase: TestCase) { + if (testCase.expectedException != null) { + Assertions.assertThrows(testCase.expectedException) { runTestCase(testCase) } + } else { + runTestCase(testCase) + Assertions.assertEquals( + testCase.expectedStreamOutput, + stateManager.outputConsumer.collectedStreamOutput, + testCase.name + ) + Assertions.assertEquals( + testCase.expectedGlobalOutput, + stateManager.outputConsumer.collectedGlobalOutput, + testCase.name + ) + } + } + + private fun runTestCase(testCase: TestCase) { + testCase.events.forEach { + when (it) { + is TestStreamMessage -> { + stateManager.addStreamState(it.stream, it.index, it.toMockStateIn()) + } + is TestGlobalMessage -> { + stateManager.addGlobalState(it.streamIndexes, it.toMockStateIn()) + } + is FlushPoint -> { + it.persistedRanges.forEach { (stream, ranges) -> + stateManager.streamsManager.addPersistedRanges(stream, ranges) + } + stateManager.flushStates() + } + } + } + } +}