diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt index e5ac1ddf92c0..be21089da20e 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationCatalog.kt @@ -23,6 +23,9 @@ data class DestinationCatalog( return byDescriptor[descriptor] ?: throw IllegalArgumentException("Stream not found: namespace=$namespace, name=$name") } + + fun asProtocolObject(): ConfiguredAirbyteCatalog = + ConfiguredAirbyteCatalog().withStreams(streams.map { it.asProtocolObject() }) } interface DestinationCatalogFactory { diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationStream.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationStream.kt index 9d79d2bd948d..77c143ab6a30 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationStream.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/command/DestinationStream.kt @@ -4,7 +4,11 @@ package io.airbyte.cdk.command +import com.fasterxml.jackson.databind.node.ObjectNode +import io.airbyte.protocol.models.v0.AirbyteStream import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import io.airbyte.protocol.models.v0.DestinationSyncMode +import io.airbyte.protocol.models.v0.StreamDescriptor import jakarta.inject.Singleton /** @@ -15,20 +19,51 @@ import jakarta.inject.Singleton * * TODO: Add dedicated schema type, converted from json-schema. 
*/ -class DestinationStream(val descriptor: Descriptor) { - data class Descriptor(val namespace: String, val name: String) - - override fun hashCode(): Int { - return descriptor.hashCode() - } - - override fun equals(other: Any?): Boolean { - return other is DestinationStream && descriptor == other.descriptor +data class DestinationStream( + val descriptor: Descriptor, + val importType: ImportType, + val schema: ObjectNode, + val generationId: Long, + val minimumGenerationId: Long, + val syncId: Long, +) { + data class Descriptor(val namespace: String, val name: String) { + fun asProtocolObject(): StreamDescriptor = + StreamDescriptor().withNamespace(namespace).withName(name) } - override fun toString(): String { - return "DestinationStream(descriptor=$descriptor)" - } + /** + * This is not fully round-trippable. Destinations don't care about most of the stuff in an + * AirbyteStream (e.g. we don't care about defaultCursorField, we only care about the _actual_ + * cursor field; we don't care about the source sync mode, we only care about the destination + * sync mode; etc.). 
+ */ + fun asProtocolObject(): ConfiguredAirbyteStream = + ConfiguredAirbyteStream() + .withStream( + AirbyteStream() + .withNamespace(descriptor.namespace) + .withName(descriptor.name) + .withJsonSchema(schema) + ) + .withGenerationId(generationId) + .withMinimumGenerationId(minimumGenerationId) + .withSyncId(syncId) + .apply { + when (importType) { + is Append -> { + destinationSyncMode = DestinationSyncMode.APPEND + } + is Dedupe -> { + destinationSyncMode = DestinationSyncMode.APPEND_DEDUP + cursorField = importType.cursor + primaryKey = importType.primaryKey + } + Overwrite -> { + destinationSyncMode = DestinationSyncMode.OVERWRITE + } + } + } } @Singleton @@ -39,7 +74,46 @@ class DestinationStreamFactory { DestinationStream.Descriptor( namespace = stream.stream.namespace, name = stream.stream.name - ) + ), + importType = + when (stream.destinationSyncMode) { + null -> throw IllegalArgumentException("Destination sync mode was null") + DestinationSyncMode.APPEND -> Append + DestinationSyncMode.OVERWRITE -> Overwrite + DestinationSyncMode.APPEND_DEDUP -> + Dedupe(primaryKey = stream.primaryKey, cursor = stream.cursorField) + }, + schema = stream.stream.jsonSchema as ObjectNode, + generationId = stream.generationId, + minimumGenerationId = stream.minimumGenerationId, + syncId = stream.syncId, ) } } + +sealed interface ImportType + +data object Append : ImportType + +data class Dedupe( + /** + * theoretically, the path to the fields in the PK. In practice, most destinations only support + * PK at the root level, i.e. `listOf(listOf(pkField1), listOf(pkField2), etc)`. + */ + val primaryKey: List>, + /** + * theoretically, the path to the cursor. In practice, most destinations only support cursors at + * the root level, i.e. `listOf(cursorField)`. + */ + val cursor: List, +) : ImportType +/** + * A legacy destination sync mode. 
Modern destinations depend on platform to set + * overwrite/record-retaining behavior via the generationId / minimumGenerationId parameters, and + * should treat this as equivalent to Append. + * + * [Overwrite] is approximately equivalent to an [Append] sync, with nonzero generationId equal to + * minimumGenerationId. + */ +// TODO should this even exist? +data object Overwrite : ImportType diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt index ae388a275bb6..d6c37cfafce0 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/DestinationMessage.kt @@ -7,9 +7,17 @@ package io.airbyte.cdk.message import com.fasterxml.jackson.databind.JsonNode import io.airbyte.cdk.command.DestinationCatalog import io.airbyte.cdk.command.DestinationStream +import io.airbyte.cdk.message.CheckpointMessage.Checkpoint +import io.airbyte.cdk.message.CheckpointMessage.Stats +import io.airbyte.protocol.models.v0.AirbyteGlobalState import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteRecordMessage +import io.airbyte.protocol.models.v0.AirbyteRecordMessageMeta +import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStateStats import io.airbyte.protocol.models.v0.AirbyteStreamState +import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage.AirbyteStreamStatus import io.airbyte.protocol.models.v0.AirbyteTraceMessage import jakarta.inject.Singleton @@ -18,60 +26,181 @@ import jakarta.inject.Singleton * Internal representation of destination messages. These are intended to be specialized for * usability. 
Data should be marshalled to these from frontline deserialized objects. */ -sealed class DestinationMessage +sealed interface DestinationMessage { + fun asProtocolMessage(): AirbyteMessage +} /** Records. */ -sealed class DestinationRecordMessage : DestinationMessage() { - abstract val stream: DestinationStream +sealed interface DestinationStreamAffinedMessage : DestinationMessage { + val stream: DestinationStream } data class DestinationRecord( override val stream: DestinationStream, val data: JsonNode? = null, val emittedAtMs: Long, - val serialized: String -) : DestinationRecordMessage() + val meta: Meta?, + val serialized: String, +) : DestinationStreamAffinedMessage { + data class Meta(val changes: List?) { + fun asProtocolObject(): AirbyteRecordMessageMeta = + AirbyteRecordMessageMeta().also { + if (changes != null) { + it.changes = changes.map { change -> change.asProtocolObject() } + } + } + } + + data class Change( + val field: String, + // Using the raw protocol enums here. + // By definition, we just want to pass these through directly. 
+ val change: AirbyteRecordMessageMetaChange.Change, + val reason: AirbyteRecordMessageMetaChange.Reason, + ) { + fun asProtocolObject(): AirbyteRecordMessageMetaChange = + AirbyteRecordMessageMetaChange().withField(field).withChange(change).withReason(reason) + } + + override fun asProtocolMessage(): AirbyteMessage = + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(stream.descriptor.name) + .withNamespace(stream.descriptor.namespace) + .withEmittedAt(emittedAtMs) + .withData(data) + .also { + if (meta != null) { + it.meta = meta.asProtocolObject() + } + } + ) +} + +private fun statusToProtocolMessage( + stream: DestinationStream, + emittedAtMs: Long, + status: AirbyteStreamStatus, +): AirbyteMessage = + AirbyteMessage() + .withType(AirbyteMessage.Type.TRACE) + .withTrace( + AirbyteTraceMessage() + .withType(AirbyteTraceMessage.Type.STREAM_STATUS) + .withEmittedAt(emittedAtMs.toDouble()) + .withStreamStatus( + AirbyteStreamStatusTraceMessage() + .withStreamDescriptor(stream.descriptor.asProtocolObject()) + .withStatus(status) + ) + ) data class DestinationStreamComplete( override val stream: DestinationStream, - val emittedAtMs: Long -) : DestinationRecordMessage() + val emittedAtMs: Long, +) : DestinationStreamAffinedMessage { + override fun asProtocolMessage(): AirbyteMessage = + statusToProtocolMessage(stream, emittedAtMs, AirbyteStreamStatus.COMPLETE) +} + +data class DestinationStreamIncomplete( + override val stream: DestinationStream, + val emittedAtMs: Long, +) : DestinationStreamAffinedMessage { + override fun asProtocolMessage(): AirbyteMessage = + statusToProtocolMessage(stream, emittedAtMs, AirbyteStreamStatus.INCOMPLETE) +} /** State. 
*/ -sealed class CheckpointMessage : DestinationMessage() { +sealed interface CheckpointMessage : DestinationMessage { data class Stats(val recordCount: Long) - data class StreamCheckpoint( + data class Checkpoint( val stream: DestinationStream, val state: JsonNode, - ) + ) { + fun asProtocolObject(): AirbyteStreamState = + AirbyteStreamState() + .withStreamDescriptor(stream.descriptor.asProtocolObject()) + .withStreamState(state) + } - abstract val sourceStats: Stats - abstract val destinationStats: Stats? + val sourceStats: Stats + val destinationStats: Stats? - abstract fun withDestinationStats(stats: Stats): CheckpointMessage + fun withDestinationStats(stats: Stats): CheckpointMessage } data class StreamCheckpoint( - val streamCheckpoint: StreamCheckpoint, + val checkpoint: Checkpoint, override val sourceStats: Stats, override val destinationStats: Stats? = null -) : CheckpointMessage() { +) : CheckpointMessage { override fun withDestinationStats(stats: Stats) = - StreamCheckpoint(streamCheckpoint, sourceStats, stats) + StreamCheckpoint(checkpoint, sourceStats, stats) + + override fun asProtocolMessage(): AirbyteMessage { + val stateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(checkpoint.asProtocolObject()) + .withSourceStats( + AirbyteStateStats().withRecordCount(sourceStats.recordCount.toDouble()) + ) + .also { + if (destinationStats != null) { + it.destinationStats = + AirbyteStateStats() + .withRecordCount(destinationStats.recordCount.toDouble()) + } + } + return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) + } } data class GlobalCheckpoint( val state: JsonNode, override val sourceStats: Stats, override val destinationStats: Stats? 
= null, - val streamCheckpoints: List = emptyList() -) : CheckpointMessage() { + val checkpoints: List = emptyList() +) : CheckpointMessage { override fun withDestinationStats(stats: Stats) = - GlobalCheckpoint(state, sourceStats, stats, streamCheckpoints) + GlobalCheckpoint(state, sourceStats, stats, checkpoints) + + override fun asProtocolMessage(): AirbyteMessage { + val stateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal( + AirbyteGlobalState() + .withSharedState(state) + .withStreamStates(checkpoints.map { it.asProtocolObject() }) + ) + .withSourceStats( + AirbyteStateStats().withRecordCount(sourceStats.recordCount.toDouble()) + ) + .also { + if (destinationStats != null) { + it.destinationStats = + AirbyteStateStats() + .withRecordCount(destinationStats.recordCount.toDouble()) + } + } + return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) + } } /** Catchall for anything unimplemented. */ -data object Undefined : DestinationMessage() +data object Undefined : DestinationMessage { + override fun asProtocolMessage(): AirbyteMessage { + // Arguably we could accept the raw message in the constructor? + // But that seems weird - when would we ever want to reemit that message? + throw NotImplementedError( + "Unrecognized messages cannot be safely converted back to a protocol object." 
+ ) + } +} @Singleton class DestinationMessageFactory(private val catalog: DestinationCatalog) { @@ -84,9 +213,20 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) { namespace = message.record.namespace, name = message.record.stream, ), - // TODO: Map to AirbyteType data = message.record.data, emittedAtMs = message.record.emittedAt, + meta = + message.record.meta?.let { meta -> + DestinationRecord.Meta( + meta.changes?.map { + DestinationRecord.Change( + field = it.field, + change = it.change, + reason = it.reason, + ) + } + ) + }, serialized = serialized ) AirbyteMessage.Type.TRACE -> { @@ -96,11 +236,14 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) { namespace = status.streamDescriptor.namespace, name = status.streamDescriptor.name, ) - if ( - message.trace.type == AirbyteTraceMessage.Type.STREAM_STATUS && - status.status == AirbyteStreamStatus.COMPLETE - ) { - DestinationStreamComplete(stream, message.trace.emittedAt.toLong()) + if (message.trace.type == AirbyteTraceMessage.Type.STREAM_STATUS) { + when (status.status) { + AirbyteStreamStatus.COMPLETE -> + DestinationStreamComplete(stream, message.trace.emittedAt.toLong()) + AirbyteStreamStatus.INCOMPLETE -> + DestinationStreamIncomplete(stream, message.trace.emittedAt.toLong()) + else -> Undefined + } } else { Undefined } @@ -109,20 +252,16 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) { when (message.state.type) { AirbyteStateMessage.AirbyteStateType.STREAM -> StreamCheckpoint( - streamCheckpoint = fromAirbyteStreamState(message.state.stream), + checkpoint = fromAirbyteStreamState(message.state.stream), sourceStats = - CheckpointMessage.Stats( - recordCount = message.state.sourceStats.recordCount.toLong() - ) + Stats(recordCount = message.state.sourceStats.recordCount.toLong()) ) AirbyteStateMessage.AirbyteStateType.GLOBAL -> GlobalCheckpoint( sourceStats = - CheckpointMessage.Stats( - recordCount = 
message.state.sourceStats.recordCount.toLong() - ), + Stats(recordCount = message.state.sourceStats.recordCount.toLong()), state = message.state.global.sharedState, - streamCheckpoints = + checkpoints = message.state.global.streamStates.map { fromAirbyteStreamState(it) } ) else -> // TODO: Do we still need to handle LEGACY? @@ -133,11 +272,9 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) { } } - private fun fromAirbyteStreamState( - streamState: AirbyteStreamState - ): CheckpointMessage.StreamCheckpoint { + private fun fromAirbyteStreamState(streamState: AirbyteStreamState): Checkpoint { val descriptor = streamState.streamDescriptor - return CheckpointMessage.StreamCheckpoint( + return Checkpoint( stream = catalog.getStream(namespace = descriptor.namespace, name = descriptor.name), state = streamState.streamState ) diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt index 5a325115c9d0..c331625d4141 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/message/MessageConverter.kt @@ -40,7 +40,7 @@ class DefaultMessageConverter : MessageConverter AirbyteStateMessage() .withSourceStats( @@ -56,23 +56,19 @@ class DefaultMessageConverter : MessageConverter { } /** - * Routes @[DestinationRecordMessage]s by stream to the appropriate channel and @ + * Routes @[DestinationStreamAffinedMessage]s by stream to the appropriate channel and @ * [CheckpointMessage]s to the state manager. * * TODO: Handle other message types. @@ -41,7 +41,7 @@ class DestinationMessageQueueWriter( override suspend fun publish(message: DestinationMessage, sizeBytes: Long) { when (message) { /* If the input message represents a record. 
*/ - is DestinationRecordMessage -> { + is DestinationStreamAffinedMessage -> { val manager = streamsManager.getManager(message.stream) when (message) { /* If a data record */ @@ -56,7 +56,8 @@ class DestinationMessageQueueWriter( } /* If an end-of-stream marker. */ - is DestinationStreamComplete -> { + is DestinationStreamComplete, + is DestinationStreamIncomplete -> { val wrapped = StreamCompleteWrapped(index = manager.countEndOfStream()) messageQueue.getChannel(message.stream).send(wrapped) } @@ -70,7 +71,7 @@ class DestinationMessageQueueWriter( * stats. */ is StreamCheckpoint -> { - val stream = message.streamCheckpoint.stream + val stream = message.checkpoint.stream val manager = streamsManager.getManager(stream) val (currentIndex, countSinceLast) = manager.markCheckpoint() val messageWithCount = diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/ProcessRecordsTask.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/ProcessRecordsTask.kt index 1e1d1f8417d6..ec78f9dee3e4 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/ProcessRecordsTask.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/task/ProcessRecordsTask.kt @@ -9,8 +9,9 @@ import io.airbyte.cdk.message.BatchEnvelope import io.airbyte.cdk.message.Deserializer import io.airbyte.cdk.message.DestinationMessage import io.airbyte.cdk.message.DestinationRecord -import io.airbyte.cdk.message.DestinationRecordMessage +import io.airbyte.cdk.message.DestinationStreamAffinedMessage import io.airbyte.cdk.message.DestinationStreamComplete +import io.airbyte.cdk.message.DestinationStreamIncomplete import io.airbyte.cdk.message.SpooledRawMessagesLocalFile import io.airbyte.cdk.state.StreamManager import io.airbyte.cdk.state.StreamsManager @@ -44,15 +45,17 @@ class ProcessRecordsTask( .bufferedReader(Charsets.UTF_8) .lineSequence() .map { - when (val record = deserializer.deserialize(it)) { - is DestinationRecordMessage -> record + when 
(val message = deserializer.deserialize(it)) { + is DestinationStreamAffinedMessage -> message else -> throw IllegalStateException( - "Expected record message, got ${record::class}" + "Expected record message, got ${message::class}" ) } } - .takeWhile { it !is DestinationStreamComplete } + .takeWhile { + it !is DestinationStreamComplete && it !is DestinationStreamIncomplete + } .map { it as DestinationRecord } .iterator() streamLoader.processRecords(records, fileEnvelope.batch.totalSizeBytes) diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/DestinationCatalogTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/DestinationCatalogTest.kt new file mode 100644 index 000000000000..d6db2d9601c9 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/DestinationCatalogTest.kt @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.command + +import io.airbyte.protocol.models.Jsons +import io.airbyte.protocol.models.v0.AirbyteStream +import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import io.airbyte.protocol.models.v0.DestinationSyncMode +import kotlin.test.assertEquals +import org.junit.jupiter.api.Test + +class DestinationCatalogTest { + @Test + fun roundTrip() { + val originalCatalog = + ConfiguredAirbyteCatalog() + .withStreams( + listOf( + ConfiguredAirbyteStream() + .withSyncId(12) + .withMinimumGenerationId(34) + .withGenerationId(56) + .withDestinationSyncMode(DestinationSyncMode.APPEND) + .withStream( + AirbyteStream() + .withJsonSchema(Jsons.deserialize("""{"type": "object"}""")) + .withNamespace("namespace1") + .withName("name1") + ), + ConfiguredAirbyteStream() + .withSyncId(12) + .withMinimumGenerationId(34) + .withGenerationId(56) + .withDestinationSyncMode(DestinationSyncMode.APPEND_DEDUP) + .withStream( + AirbyteStream() + 
.withJsonSchema(Jsons.deserialize("""{"type": "object"}""")) + .withNamespace("namespace2") + .withName("name2") + ) + .withPrimaryKey(listOf(listOf("id1"), listOf("id2"))) + .withCursorField(listOf("cursor")), + ConfiguredAirbyteStream() + .withSyncId(12) + .withMinimumGenerationId(34) + .withGenerationId(56) + .withDestinationSyncMode(DestinationSyncMode.OVERWRITE) + .withStream( + AirbyteStream() + .withJsonSchema(Jsons.deserialize("""{"type": "object"}""")) + .withNamespace("namespace3") + .withName("name3") + ), + ), + ) + + val streamFactory = DestinationStreamFactory() + val catalogFactory = DefaultDestinationCatalogFactory(originalCatalog, streamFactory) + val destinationCatalog = catalogFactory.make() + assertEquals(originalCatalog, destinationCatalog.asProtocolObject()) + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt index dc19a9e28452..93df547487a6 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/command/MockCatalogFactory.kt @@ -4,6 +4,8 @@ package io.airbyte.cdk.command +import com.fasterxml.jackson.databind.node.ObjectNode +import io.airbyte.protocol.models.Jsons import io.micronaut.context.annotation.Factory import io.micronaut.context.annotation.Replaces import io.micronaut.context.annotation.Requires @@ -15,8 +17,28 @@ import jakarta.inject.Singleton @Requires(env = ["test"]) class MockCatalogFactory : DestinationCatalogFactory { companion object { - val stream1 = DestinationStream(DestinationStream.Descriptor("test", "stream1")) - val stream2 = DestinationStream(DestinationStream.Descriptor("test", "stream2")) + val stream1 = + DestinationStream( + DestinationStream.Descriptor("test", "stream1"), + Append, + Jsons.deserialize( + """{"type": "object", "properties": {"id": {"type": 
"integer"}}}""" + ) as ObjectNode, + generationId = 42, + minimumGenerationId = 0, + syncId = 42, + ) + val stream2 = + DestinationStream( + DestinationStream.Descriptor("test", "stream2"), + Append, + Jsons.deserialize( + """{"type": "object", "properties": {"id": {"type": "integer"}}}""" + ) as ObjectNode, + generationId = 42, + minimumGenerationId = 0, + syncId = 42, + ) } @Singleton diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageTest.kt new file mode 100644 index 000000000000..16e82efde998 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/message/DestinationMessageTest.kt @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.message + +import com.fasterxml.jackson.databind.node.ObjectNode +import io.airbyte.cdk.command.Append +import io.airbyte.cdk.command.DestinationCatalog +import io.airbyte.cdk.command.DestinationStream +import io.airbyte.protocol.models.Jsons +import io.airbyte.protocol.models.v0.AirbyteGlobalState +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.AirbyteRecordMessage +import io.airbyte.protocol.models.v0.AirbyteRecordMessageMeta +import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange +import io.airbyte.protocol.models.v0.AirbyteStateMessage +import io.airbyte.protocol.models.v0.AirbyteStateStats +import io.airbyte.protocol.models.v0.AirbyteStreamState +import io.airbyte.protocol.models.v0.AirbyteStreamStatusTraceMessage +import io.airbyte.protocol.models.v0.AirbyteTraceMessage +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource + +class DestinationMessageTest { + 
private val factory = + DestinationMessageFactory( + DestinationCatalog( + listOf( + DestinationStream( + descriptor, + Append, + Jsons.deserialize("{}") as ObjectNode, + generationId = 42, + minimumGenerationId = 0, + syncId = 42, + ) + ) + ) + ) + + @ParameterizedTest + @MethodSource("roundTrippableMessages") + fun testRoundTrip(message: AirbyteMessage) { + val roundTripped = + factory.fromAirbyteMessage(message, Jsons.serialize(message)).asProtocolMessage() + assertEquals(message, roundTripped) + } + + // Checkpoint messages aren't round-trippable. + // We don't read in destinationStats (because we're the ones setting that field). + @Test + fun testStreamCheckpoint() { + val inputMessage = + AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor(descriptor.asProtocolObject()) + .withStreamState(blob1) + ) + // Note: only source stats, no destination stats + .withSourceStats(AirbyteStateStats().withRecordCount(2.0)) + ) + + val parsedMessage = + factory.fromAirbyteMessage(inputMessage, Jsons.serialize(inputMessage)) + as StreamCheckpoint + + assertEquals( + inputMessage.also { + it.state.destinationStats = AirbyteStateStats().withRecordCount(3.0) + }, + parsedMessage.withDestinationStats(CheckpointMessage.Stats(3)).asProtocolMessage(), + ) + } + + @Test + fun testGlobalCheckpoint() { + val inputMessage = + AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal( + AirbyteGlobalState() + .withSharedState(blob1) + .withStreamStates( + listOf( + AirbyteStreamState() + .withStreamDescriptor(descriptor.asProtocolObject()) + .withStreamState(blob2), + ), + ), + ) + // Note: only source stats, no destination stats + .withSourceStats(AirbyteStateStats().withRecordCount(2.0)) + ) + + val parsedMessage = + 
factory.fromAirbyteMessage( + inputMessage, + Jsons.serialize(inputMessage), + ) as GlobalCheckpoint + + assertEquals( + inputMessage.also { + it.state.destinationStats = AirbyteStateStats().withRecordCount(3.0) + }, + parsedMessage.withDestinationStats(CheckpointMessage.Stats(3)).asProtocolMessage(), + ) + } + + companion object { + private val descriptor = DestinationStream.Descriptor("namespace", "name") + private val blob1 = Jsons.deserialize("""{"foo": "bar"}""") + private val blob2 = Jsons.deserialize("""{"foo": "bar"}""") + + @JvmStatic + fun roundTrippableMessages(): List = + listOf( + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream("name") + .withNamespace("namespace") + .withEmittedAt(1234) + .withMeta( + AirbyteRecordMessageMeta() + .withChanges( + listOf( + AirbyteRecordMessageMetaChange() + .withField("foo") + .withReason( + AirbyteRecordMessageMetaChange.Reason + .DESTINATION_FIELD_SIZE_LIMITATION + ) + .withChange( + AirbyteRecordMessageMetaChange.Change.NULLED + ) + ) + ) + ) + .withData(blob1) + ), + AirbyteMessage() + .withType(AirbyteMessage.Type.TRACE) + .withTrace( + AirbyteTraceMessage() + .withType(AirbyteTraceMessage.Type.STREAM_STATUS) + .withEmittedAt(1234.0) + .withStreamStatus( + AirbyteStreamStatusTraceMessage() + // Intentionally no "reasons" here - destinations never + // inspect that + // field, so it's not round-trippable + .withStreamDescriptor(descriptor.asProtocolObject()) + .withStatus( + AirbyteStreamStatusTraceMessage.AirbyteStreamStatus + .COMPLETE + ) + ) + ), + AirbyteMessage() + .withType(AirbyteMessage.Type.TRACE) + .withTrace( + AirbyteTraceMessage() + .withType(AirbyteTraceMessage.Type.STREAM_STATUS) + .withEmittedAt(1234.0) + .withStreamStatus( + AirbyteStreamStatusTraceMessage() + // Intentionally no "reasons" here - destinations never + // inspect that + // field, so it's not round-trippable + .withStreamDescriptor(descriptor.asProtocolObject()) + 
.withStatus( + AirbyteStreamStatusTraceMessage.AirbyteStreamStatus + .INCOMPLETE + ) + ) + ), + ) + .map { Arguments.of(it) } + } +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt index 5ba013f4fab4..857a7b24bc5e 100644 --- a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/StreamsManagerTest.kt @@ -4,7 +4,9 @@ package io.airbyte.cdk.state +import com.fasterxml.jackson.databind.node.ObjectNode import com.google.common.collect.Range +import io.airbyte.cdk.command.Append import io.airbyte.cdk.command.DestinationCatalog import io.airbyte.cdk.command.DestinationStream import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream1 @@ -12,6 +14,7 @@ import io.airbyte.cdk.command.MockCatalogFactory.Companion.stream2 import io.airbyte.cdk.message.Batch import io.airbyte.cdk.message.BatchEnvelope import io.airbyte.cdk.message.SimpleBatch +import io.airbyte.protocol.models.Jsons import io.micronaut.test.extensions.junit5.annotation.MicronautTest import jakarta.inject.Inject import jakarta.inject.Named @@ -70,7 +73,16 @@ class StreamsManagerTest { val streamsManager = StreamsManagerFactory(catalog).make() Assertions.assertThrows(IllegalArgumentException::class.java) { streamsManager.getManager( - DestinationStream(DestinationStream.Descriptor("test", "non-existent")) + DestinationStream( + DestinationStream.Descriptor("test", "non-existent"), + Append, + Jsons.deserialize( + """{"type": "object", "properties": {"id": {"type": "integer"}}}""" + ) as ObjectNode, + generationId = 42, + minimumGenerationId = 0, + syncId = 42, + ) ) } }