Skip to content

Commit

Permalink
Bulk Load CDK: Unit tests for memory manager (#45091)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Sep 5, 2024
1 parent 6730a3b commit 081a0ca
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

package io.airbyte.cdk.state

import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

/**
* Manages memory usage for the destination.
Expand All @@ -17,31 +19,49 @@ import kotlin.concurrent.withLock
* TODO: Some degree of logging/monitoring around how accurate we're actually being?
*/
@Singleton
class MemoryManager {
private val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory()
class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) {
private val totalMemoryBytes: Long = availableMemoryProvider.availableMemoryBytes
private var usedMemoryBytes = AtomicLong(0L)
private val memoryLock = ReentrantLock()
private val memoryLockCondition = memoryLock.newCondition()
private val mutex = Mutex()
private val syncChannel = Channel<Unit>(Channel.UNLIMITED)

val remainingMemoryBytes: Long
get() = totalMemoryBytes - usedMemoryBytes.get()

/* Attempt to reserve memory. If enough memory is not available, waits until it is, then reserves. */
suspend fun reserveBlocking(memoryBytes: Long) {
memoryLock.withLock {
while (usedMemoryBytes.get() + memoryBytes > availableMemoryBytes) {
memoryLockCondition.await()
if (memoryBytes > totalMemoryBytes) {
throw IllegalArgumentException(
"Requested ${memoryBytes}b memory exceeds ${totalMemoryBytes}b total"
)
}

mutex.withLock {
while (usedMemoryBytes.get() + memoryBytes > totalMemoryBytes) {
syncChannel.receive()
}
usedMemoryBytes.addAndGet(memoryBytes)
}
}

suspend fun reserveRatio(ratio: Double): Long {
val estimatedSize = (availableMemoryBytes.toDouble() * ratio).toLong()
val estimatedSize = (totalMemoryBytes.toDouble() * ratio).toLong()
reserveBlocking(estimatedSize)
return estimatedSize
}

fun release(memoryBytes: Long) {
memoryLock.withLock {
usedMemoryBytes.addAndGet(-memoryBytes)
memoryLockCondition.signalAll()
}
suspend fun release(memoryBytes: Long) {
usedMemoryBytes.addAndGet(-memoryBytes)
syncChannel.send(Unit)
}
}

interface AvailableMemoryProvider {
val availableMemoryBytes: Long
}

@Singleton
@Secondary
class JavaRuntimeAvailableMemoryProvider : AvailableMemoryProvider {
override val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.state

import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

@MicronautTest
class MemoryManagerTest {
@Singleton
@Replaces(MemoryManager::class)
@Requires(env = ["test"])
class MockAvailableMemoryProvider : AvailableMemoryProvider {
override val availableMemoryBytes: Long = 1000
}

@Test
fun testReserveBlocking() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
val reserved = AtomicBoolean(false)

try {
withTimeout(5000) { memoryManager.reserveBlocking(900) }
} catch (e: Exception) {
Assertions.fail<Unit>("Failed to reserve memory")
}

Assertions.assertEquals(100, memoryManager.remainingMemoryBytes)

val job = launch {
memoryManager.reserveBlocking(200)
reserved.set(true)
}

memoryManager.reserveBlocking(0)
Assertions.assertFalse(reserved.get())

memoryManager.release(50)
memoryManager.reserveBlocking(0)
Assertions.assertEquals(150, memoryManager.remainingMemoryBytes)
Assertions.assertFalse(reserved.get())

memoryManager.release(25)
memoryManager.reserveBlocking(0)
Assertions.assertEquals(175, memoryManager.remainingMemoryBytes)
Assertions.assertFalse(reserved.get())

memoryManager.release(25)
try {
withTimeout(5000) { job.join() }
} catch (e: Exception) {
Assertions.fail<Unit>("Failed to unblock reserving memory")
}
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
Assertions.assertTrue(reserved.get())
}

@Test
fun testReserveBlockingMultithreaded() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
withContext(Dispatchers.IO) {
memoryManager.reserveBlocking(1000)
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
val nIterations = 100000

val jobs = (0 until nIterations).map { launch { memoryManager.reserveBlocking(10) } }

repeat(nIterations) {
memoryManager.release(10)
Assertions.assertTrue(
memoryManager.remainingMemoryBytes >= 0,
"Remaining memory is negative: ${memoryManager.remainingMemoryBytes}"
)
}
jobs.forEach { it.join() }
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
}
}

@Test
fun testRequestingMoreThanAvailableThrows() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
try {
memoryManager.reserveBlocking(1001)
} catch (e: IllegalArgumentException) {
return@runTest
}
Assertions.fail<Unit>("Requesting more memory than available should throw an exception")
}
}

0 comments on commit 081a0ca

Please sign in to comment.