Bulk Load CDK: Unit tests for memory manager #45091

Merged

@@ -4,10 +4,12 @@

package io.airbyte.cdk.state

import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

/**
* Manages memory usage for the destination.
@@ -17,31 +19,49 @@ import kotlin.concurrent.withLock
* TODO: Some degree of logging/monitoring around how accurate we're actually being?
Contributor

nit: add a "TODO: Be more fair." The current implementation guarantees first-come, first-served order: if a caller requests a large amount of memory that is not yet available, every subsequent call blocks until that large request is satisfied, even when enough memory is available to serve the smaller requests.
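
A minimal sketch of the head-of-line blocking described in this comment, using plain lists rather than coroutines. The functions `admitFifo` and `admitFirstFit` are hypothetical names introduced only for illustration; neither is part of this PR or of the CDK. The first models the first-come, first-served behavior, the second models one possible "fairer" policy that serves any pending request that currently fits.

```kotlin
// Hypothetical illustration only: models which pending requests would be
// admitted given the memory currently available.

// Strict first-come, first-served: stop at the first request that does not fit.
fun admitFifo(pending: List<Long>, available: Long): List<Long> {
    var remaining = available
    val admitted = mutableListOf<Long>()
    for (request in pending) {
        if (request > remaining) break // head-of-line blocking
        remaining -= request
        admitted += request
    }
    return admitted
}

// First fit: skip requests that do not fit and keep serving later ones.
fun admitFirstFit(pending: List<Long>, available: Long): List<Long> {
    var remaining = available
    val admitted = mutableListOf<Long>()
    for (request in pending) {
        if (request <= remaining) {
            remaining -= request
            admitted += request
        }
    }
    return admitted
}

fun main() {
    val pending = listOf(900L, 50L, 30L)
    println(admitFifo(pending, 100L))     // prints []
    println(admitFirstFit(pending, 100L)) // prints [50, 30]
}
```

The trade-off is that a first-fit policy can starve the large request if small requests keep arriving, so a real implementation would likely need some aging or priority rule on top.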

*/
@Singleton
class MemoryManager {
private val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory()
class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) {
private val totalMemoryBytes: Long = availableMemoryProvider.availableMemoryBytes
private var usedMemoryBytes = AtomicLong(0L)
private val memoryLock = ReentrantLock()
private val memoryLockCondition = memoryLock.newCondition()
private val mutex = Mutex()
private val syncChannel = Channel<Unit>(Channel.UNLIMITED)

val remainingMemoryBytes: Long
get() = totalMemoryBytes - usedMemoryBytes.get()

/* Attempt to reserve memory. If enough memory is not available, waits until it is, then reserves. */
suspend fun reserveBlocking(memoryBytes: Long) {
memoryLock.withLock {
while (usedMemoryBytes.get() + memoryBytes > availableMemoryBytes) {
memoryLockCondition.await()
if (memoryBytes > totalMemoryBytes) {
throw IllegalArgumentException(
"Requested ${memoryBytes}b memory exceeds ${totalMemoryBytes}b total"
)
}

mutex.withLock {
while (usedMemoryBytes.get() + memoryBytes > totalMemoryBytes) {
Contributor

Yeah, I'm pretty sure this is correct.
As an interesting aside, we could also do without the AtomicLong; a volatile would suffice :)
Thank you for fixing it.
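
A rough sketch of the volatile idea, under an assumption that does not hold in the diff as shown: every write to the counter happens under the same lock. In the diff, `release` updates `usedMemoryBytes` outside the mutex, so an atomic type is still needed there. `LockedCounter` is a hypothetical class written for this note, and it uses a blocking `ReentrantLock` with a non-suspending `tryReserve` purely to stay self-contained.

```kotlin
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.thread
import kotlin.concurrent.withLock

// Hypothetical class, not the PR's MemoryManager.
class LockedCounter(private val total: Long) {
    private val lock = ReentrantLock()

    // @Volatile only makes the latest write visible to unlocked readers.
    // It does NOT make `used += x` atomic, so every write below is done
    // while holding `lock`.
    @Volatile private var used = 0L

    val remaining: Long
        get() = total - used

    fun tryReserve(bytes: Long): Boolean = lock.withLock {
        if (used + bytes > total) {
            false
        } else {
            used += bytes
            true
        }
    }

    fun release(bytes: Long) {
        lock.withLock { used -= bytes }
    }
}

fun main() {
    val counter = LockedCounter(1000)
    val threads = (1..8).map {
        thread {
            repeat(10_000) {
                if (counter.tryReserve(10)) counter.release(10)
            }
        }
    }
    threads.forEach { it.join() }
    println(counter.remaining) // prints 1000
}
```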

syncChannel.receive()
}
usedMemoryBytes.addAndGet(memoryBytes)
}
}

suspend fun reserveRatio(ratio: Double): Long {
val estimatedSize = (availableMemoryBytes.toDouble() * ratio).toLong()
val estimatedSize = (totalMemoryBytes.toDouble() * ratio).toLong()
reserveBlocking(estimatedSize)
return estimatedSize
}

fun release(memoryBytes: Long) {
memoryLock.withLock {
usedMemoryBytes.addAndGet(-memoryBytes)
memoryLockCondition.signalAll()
}
suspend fun release(memoryBytes: Long) {
usedMemoryBytes.addAndGet(-memoryBytes)
syncChannel.send(Unit)
}
}

interface AvailableMemoryProvider {
val availableMemoryBytes: Long
}

@Singleton
@Secondary
class JavaRuntimeAvailableMemoryProvider : AvailableMemoryProvider {
override val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory()
}
@@ -0,0 +1,102 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.state

import io.micronaut.context.annotation.Replaces
import io.micronaut.context.annotation.Requires
import io.micronaut.test.extensions.junit5.annotation.MicronautTest
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

@MicronautTest
class MemoryManagerTest {
@Singleton
@Replaces(MemoryManager::class)
@Requires(env = ["test"])
class MockAvailableMemoryProvider : AvailableMemoryProvider {
override val availableMemoryBytes: Long = 1000
}

@Test
fun testReserveBlocking() = runTest {
Contributor

We need tests involving several threads.

Contributor Author

Yep, it took a lot but I did finally flush out the bug you were afraid of. See testReserveBlockingMultithreaded()

Contributor

Took a lot of effort, or a lot of jobs?

Contributor Author

A lot of jobs to get one error.

Contributor

Yeah, I usually test those by having some kind of test-triggered breakpoint in the production code. As much as I dislike having test-only code in the production code path, anything else is always going to be very flaky or involve a lot of jobs/threads/workers...
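
One way the "test-triggered breakpoint" idea could look, as a sketch rather than anything from this PR. `RacyBudget` and its `afterCheck` hook are hypothetical: the class has a deliberate check-then-act race, and the test uses the hook to pause both threads between the check and the update, so the over-reservation happens on every run instead of once in 100,000 jobs.

```kotlin
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicLong
import kotlin.concurrent.thread

// Hypothetical class with a deliberate check-then-act race.
// `afterCheck` is the test-only hook; production callers leave the default no-op.
class RacyBudget(private val total: Long, private val afterCheck: () -> Unit = {}) {
    private val used = AtomicLong(0)

    val remaining: Long
        get() = total - used.get()

    fun tryReserve(bytes: Long): Boolean {
        if (used.get() + bytes > total) return false // check
        afterCheck()                                 // test-triggered pause point
        used.addAndGet(bytes)                        // act
        return true
    }
}

fun main() {
    // Both threads must pass the check before either is allowed to update.
    val bothChecked = CountDownLatch(2)
    val budget = RacyBudget(100) {
        bothChecked.countDown()
        bothChecked.await()
    }
    val threads = (1..2).map { thread { budget.tryReserve(60) } }
    threads.forEach { it.join() }
    println(budget.remaining) // prints -20: both reservations succeeded
}
```

The cost is the one named in the comment: the hook is test-only code living in the production code path.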

val memoryManager = MemoryManager(MockAvailableMemoryProvider())
val reserved = AtomicBoolean(false)

try {
withTimeout(5000) { memoryManager.reserveBlocking(900) }
Contributor

nit: can we use a lock (or a channel) instead of a timeout here?

Contributor Author

I'm testing that it doesn't block, so I'm not sure what else I can do.

} catch (e: Exception) {
Assertions.fail<Unit>("Failed to reserve memory")
}

Assertions.assertEquals(100, memoryManager.remainingMemoryBytes)

val job = launch {
memoryManager.reserveBlocking(200)
reserved.set(true)
}

memoryManager.reserveBlocking(0)
Assertions.assertFalse(reserved.get())

memoryManager.release(50)
memoryManager.reserveBlocking(0)
Assertions.assertEquals(150, memoryManager.remainingMemoryBytes)
Assertions.assertFalse(reserved.get())

memoryManager.release(25)
memoryManager.reserveBlocking(0)
Assertions.assertEquals(175, memoryManager.remainingMemoryBytes)
Assertions.assertFalse(reserved.get())

memoryManager.release(25)
try {
withTimeout(5000) { job.join() }
} catch (e: Exception) {
Assertions.fail<Unit>("Failed to unblock reserving memory")
}
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
Assertions.assertTrue(reserved.get())
}

@Test
fun testReserveBlockingMultithreaded() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
withContext(Dispatchers.IO) {
memoryManager.reserveBlocking(1000)
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
val nIterations = 100000
Contributor

I guess a lot of jobs...

Contributor Author

Yeah, 10,000 wasn't enough to fail consistently. 100,000 and I'd always hit it at least once.


val jobs = (0 until nIterations).map { launch { memoryManager.reserveBlocking(10) } }

repeat(nIterations) {
memoryManager.release(10)
Assertions.assertTrue(
memoryManager.remainingMemoryBytes >= 0,
"Remaining memory is negative: ${memoryManager.remainingMemoryBytes}"
)
}
jobs.forEach { it.join() }
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes)
}
}

@Test
fun testRequestingMoreThanAvailableThrows() = runTest {
val memoryManager = MemoryManager(MockAvailableMemoryProvider())
try {
memoryManager.reserveBlocking(1001)
} catch (e: IllegalArgumentException) {
return@runTest
}
Assertions.fail<Unit>("Requesting more memory than available should throw an exception")
}
}