From 081a0ca111c15e4693850cba5f05d53087ab8d84 Mon Sep 17 00:00:00 2001 From: Johnny Schmidt Date: Thu, 5 Sep 2024 15:04:33 -0700 Subject: [PATCH] Bulk Load CDK: Unit tests for memory manager (#45091) --- .../io/airbyte/cdk/state/MemoryManager.kt | 50 ++++++--- .../io/airbyte/cdk/state/MemoryManagerTest.kt | 102 ++++++++++++++++++ 2 files changed, 137 insertions(+), 15 deletions(-) create mode 100644 airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MemoryManagerTest.kt diff --git a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/MemoryManager.kt b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/MemoryManager.kt index 4e03a2ab9b23..d191223b08fd 100644 --- a/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/MemoryManager.kt +++ b/airbyte-cdk/bulk/core/load/src/main/kotlin/io/airbyte/cdk/state/MemoryManager.kt @@ -4,10 +4,12 @@ package io.airbyte.cdk.state +import io.micronaut.context.annotation.Secondary import jakarta.inject.Singleton import java.util.concurrent.atomic.AtomicLong -import java.util.concurrent.locks.ReentrantLock -import kotlin.concurrent.withLock +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock /** * Manages memory usage for the destination. @@ -17,31 +19,49 @@ import kotlin.concurrent.withLock * TODO: Some degree of logging/monitoring around how accurate we're actually being? */ @Singleton -class MemoryManager { - private val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory() +class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) { + private val totalMemoryBytes: Long = availableMemoryProvider.availableMemoryBytes private var usedMemoryBytes = AtomicLong(0L) - private val memoryLock = ReentrantLock() - private val memoryLockCondition = memoryLock.newCondition() + private val mutex = Mutex() + private val syncChannel = Channel(Channel.UNLIMITED) + val remainingMemoryBytes: Long + get() = totalMemoryBytes - usedMemoryBytes.get() + + /* Attempt to reserve memory. If enough memory is not available, waits until it is, then reserves. */ suspend fun reserveBlocking(memoryBytes: Long) { - memoryLock.withLock { - while (usedMemoryBytes.get() + memoryBytes > availableMemoryBytes) { - memoryLockCondition.await() + if (memoryBytes > totalMemoryBytes) { + throw IllegalArgumentException( + "Requested ${memoryBytes}b memory exceeds ${totalMemoryBytes}b total" + ) + } + + mutex.withLock { + while (usedMemoryBytes.get() + memoryBytes > totalMemoryBytes) { + syncChannel.receive() } usedMemoryBytes.addAndGet(memoryBytes) } } suspend fun reserveRatio(ratio: Double): Long { - val estimatedSize = (availableMemoryBytes.toDouble() * ratio).toLong() + val estimatedSize = (totalMemoryBytes.toDouble() * ratio).toLong() reserveBlocking(estimatedSize) return estimatedSize } - fun release(memoryBytes: Long) { - memoryLock.withLock { - usedMemoryBytes.addAndGet(-memoryBytes) - memoryLockCondition.signalAll() - } + suspend fun release(memoryBytes: Long) { + usedMemoryBytes.addAndGet(-memoryBytes) + syncChannel.send(Unit) } } + +interface AvailableMemoryProvider { + val availableMemoryBytes: Long +} + +@Singleton +@Secondary +class JavaRuntimeAvailableMemoryProvider : AvailableMemoryProvider { + override val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory() +} diff --git a/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MemoryManagerTest.kt b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MemoryManagerTest.kt new file mode 100644 index 000000000000..5bc28a27cda1 --- /dev/null +++ b/airbyte-cdk/bulk/core/load/src/test/kotlin/io/airbyte/cdk/state/MemoryManagerTest.kt @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.state + +import io.micronaut.context.annotation.Replaces +import io.micronaut.context.annotation.Requires +import io.micronaut.test.extensions.junit5.annotation.MicronautTest +import jakarta.inject.Singleton +import java.util.concurrent.atomic.AtomicBoolean +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +@MicronautTest +class MemoryManagerTest { + @Singleton + @Replaces(MemoryManager::class) + @Requires(env = ["test"]) + class MockAvailableMemoryProvider : AvailableMemoryProvider { + override val availableMemoryBytes: Long = 1000 + } + + @Test + fun testReserveBlocking() = runTest { + val memoryManager = MemoryManager(MockAvailableMemoryProvider()) + val reserved = AtomicBoolean(false) + + try { + withTimeout(5000) { memoryManager.reserveBlocking(900) } + } catch (e: Exception) { + Assertions.fail("Failed to reserve memory") + } + + Assertions.assertEquals(100, memoryManager.remainingMemoryBytes) + + val job = launch { + memoryManager.reserveBlocking(200) + reserved.set(true) + } + + memoryManager.reserveBlocking(0) + Assertions.assertFalse(reserved.get()) + + memoryManager.release(50) + memoryManager.reserveBlocking(0) + Assertions.assertEquals(150, memoryManager.remainingMemoryBytes) + Assertions.assertFalse(reserved.get()) + + memoryManager.release(25) + memoryManager.reserveBlocking(0) + Assertions.assertEquals(175, memoryManager.remainingMemoryBytes) + Assertions.assertFalse(reserved.get()) + + memoryManager.release(25) + try { + withTimeout(5000) { job.join() } + } catch (e: Exception) { + Assertions.fail("Failed to unblock reserving memory") + } + Assertions.assertEquals(0, memoryManager.remainingMemoryBytes) + Assertions.assertTrue(reserved.get()) + } + + @Test + fun testReserveBlockingMultithreaded() = runTest { + val memoryManager = MemoryManager(MockAvailableMemoryProvider()) + withContext(Dispatchers.IO) { + memoryManager.reserveBlocking(1000) + Assertions.assertEquals(0, memoryManager.remainingMemoryBytes) + val nIterations = 100000 + + val jobs = (0 until nIterations).map { launch { memoryManager.reserveBlocking(10) } } + + repeat(nIterations) { + memoryManager.release(10) + Assertions.assertTrue( + memoryManager.remainingMemoryBytes >= 0, + "Remaining memory is negative: ${memoryManager.remainingMemoryBytes}" + ) + } + jobs.forEach { it.join() } + Assertions.assertEquals(0, memoryManager.remainingMemoryBytes) + } + } + + @Test + fun testRequestingMoreThanAvailableThrows() = runTest { + val memoryManager = MemoryManager(MockAvailableMemoryProvider()) + try { + memoryManager.reserveBlocking(1001) + } catch (e: IllegalArgumentException) { + return@runTest + } + Assertions.fail("Requesting more memory than available should throw an exception") + } +}