-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bulk Load CDK: Unit tests for memory manager #45091
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,12 @@ | |
|
||
package io.airbyte.cdk.state | ||
|
||
import io.micronaut.context.annotation.Secondary | ||
import jakarta.inject.Singleton | ||
import java.util.concurrent.atomic.AtomicLong | ||
import java.util.concurrent.locks.ReentrantLock | ||
import kotlin.concurrent.withLock | ||
import kotlinx.coroutines.channels.Channel | ||
import kotlinx.coroutines.sync.Mutex | ||
import kotlinx.coroutines.sync.withLock | ||
|
||
/** | ||
* Manages memory usage for the destination. | ||
|
@@ -17,31 +19,49 @@ import kotlin.concurrent.withLock | |
* TODO: Some degree of logging/monitoring around how accurate we're actually being? | ||
*/ | ||
@Singleton | ||
class MemoryManager { | ||
private val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory() | ||
class MemoryManager(availableMemoryProvider: AvailableMemoryProvider) { | ||
private val totalMemoryBytes: Long = availableMemoryProvider.availableMemoryBytes | ||
private var usedMemoryBytes = AtomicLong(0L) | ||
private val memoryLock = ReentrantLock() | ||
private val memoryLockCondition = memoryLock.newCondition() | ||
private val mutex = Mutex() | ||
private val syncChannel = Channel<Unit>(Channel.UNLIMITED) | ||
|
||
val remainingMemoryBytes: Long | ||
get() = totalMemoryBytes - usedMemoryBytes.get() | ||
|
||
/* Attempt to reserve memory. If enough memory is not available, waits until it is, then reserves. */ | ||
suspend fun reserveBlocking(memoryBytes: Long) { | ||
memoryLock.withLock { | ||
while (usedMemoryBytes.get() + memoryBytes > availableMemoryBytes) { | ||
memoryLockCondition.await() | ||
if (memoryBytes > totalMemoryBytes) { | ||
throw IllegalArgumentException( | ||
"Requested ${memoryBytes}b memory exceeds ${totalMemoryBytes}b total" | ||
) | ||
} | ||
|
||
mutex.withLock { | ||
while (usedMemoryBytes.get() + memoryBytes > totalMemoryBytes) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, I'm pretty sure this is correct. |
||
syncChannel.receive() | ||
} | ||
usedMemoryBytes.addAndGet(memoryBytes) | ||
} | ||
} | ||
|
||
suspend fun reserveRatio(ratio: Double): Long { | ||
val estimatedSize = (availableMemoryBytes.toDouble() * ratio).toLong() | ||
val estimatedSize = (totalMemoryBytes.toDouble() * ratio).toLong() | ||
reserveBlocking(estimatedSize) | ||
return estimatedSize | ||
} | ||
|
||
fun release(memoryBytes: Long) { | ||
memoryLock.withLock { | ||
usedMemoryBytes.addAndGet(-memoryBytes) | ||
memoryLockCondition.signalAll() | ||
} | ||
suspend fun release(memoryBytes: Long) { | ||
usedMemoryBytes.addAndGet(-memoryBytes) | ||
syncChannel.send(Unit) | ||
} | ||
} | ||
|
||
interface AvailableMemoryProvider { | ||
val availableMemoryBytes: Long | ||
} | ||
|
||
@Singleton | ||
@Secondary | ||
class JavaRuntimeAvailableMemoryProvider : AvailableMemoryProvider { | ||
override val availableMemoryBytes: Long = Runtime.getRuntime().maxMemory() | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* | ||
* Copyright (c) 2024 Airbyte, Inc., all rights reserved. | ||
*/ | ||
|
||
package io.airbyte.cdk.state | ||
|
||
import io.micronaut.context.annotation.Replaces | ||
import io.micronaut.context.annotation.Requires | ||
import io.micronaut.test.extensions.junit5.annotation.MicronautTest | ||
import jakarta.inject.Singleton | ||
import java.util.concurrent.atomic.AtomicBoolean | ||
import kotlinx.coroutines.Dispatchers | ||
import kotlinx.coroutines.launch | ||
import kotlinx.coroutines.test.runTest | ||
import kotlinx.coroutines.withContext | ||
import kotlinx.coroutines.withTimeout | ||
import org.junit.jupiter.api.Assertions | ||
import org.junit.jupiter.api.Test | ||
|
||
@MicronautTest | ||
class MemoryManagerTest { | ||
@Singleton | ||
@Replaces(MemoryManager::class) | ||
@Requires(env = ["test"]) | ||
class MockAvailableMemoryProvider : AvailableMemoryProvider { | ||
override val availableMemoryBytes: Long = 1000 | ||
} | ||
|
||
@Test | ||
fun testReserveBlocking() = runTest { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need tests imvolving several threads There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, it took a lot but I did finally flush out the bug you were afraid of. See There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. took a lot of efforts, or a lot of jobs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A lot of jobs to get one error. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, I usually test those by having some king of test-triggered breakpoint in the production code. As much as I dislike having test-only code in production codepath, anything else is always going to be very flaky or involve a lot of jobs/threads/workers... |
||
val memoryManager = MemoryManager(MockAvailableMemoryProvider()) | ||
val reserved = AtomicBoolean(false) | ||
|
||
try { | ||
withTimeout(5000) { memoryManager.reserveBlocking(900) } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: can we use a lock (or a channel) instead of a timeout here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm testing that it doesn't block, so I'm not sure what else I can do. |
||
} catch (e: Exception) { | ||
Assertions.fail<Unit>("Failed to reserve memory") | ||
} | ||
|
||
Assertions.assertEquals(100, memoryManager.remainingMemoryBytes) | ||
|
||
val job = launch { | ||
memoryManager.reserveBlocking(200) | ||
reserved.set(true) | ||
} | ||
|
||
memoryManager.reserveBlocking(0) | ||
Assertions.assertFalse(reserved.get()) | ||
|
||
memoryManager.release(50) | ||
memoryManager.reserveBlocking(0) | ||
Assertions.assertEquals(150, memoryManager.remainingMemoryBytes) | ||
Assertions.assertFalse(reserved.get()) | ||
|
||
memoryManager.release(25) | ||
memoryManager.reserveBlocking(0) | ||
Assertions.assertEquals(175, memoryManager.remainingMemoryBytes) | ||
Assertions.assertFalse(reserved.get()) | ||
|
||
memoryManager.release(25) | ||
try { | ||
withTimeout(5000) { job.join() } | ||
} catch (e: Exception) { | ||
Assertions.fail<Unit>("Failed to unblock reserving memory") | ||
} | ||
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes) | ||
Assertions.assertTrue(reserved.get()) | ||
} | ||
|
||
@Test | ||
fun testReserveBlockingMultithreaded() = runTest { | ||
val memoryManager = MemoryManager(MockAvailableMemoryProvider()) | ||
withContext(Dispatchers.IO) { | ||
memoryManager.reserveBlocking(1000) | ||
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes) | ||
val nIterations = 100000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess a lot of jobs... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, 10,000 wasn't enough to fail consistently. 100,000 and I'd always hit it at least once. |
||
|
||
val jobs = (0 until nIterations).map { launch { memoryManager.reserveBlocking(10) } } | ||
|
||
repeat(nIterations) { | ||
memoryManager.release(10) | ||
Assertions.assertTrue( | ||
memoryManager.remainingMemoryBytes >= 0, | ||
"Remaining memory is negative: ${memoryManager.remainingMemoryBytes}" | ||
) | ||
} | ||
jobs.forEach { it.join() } | ||
Assertions.assertEquals(0, memoryManager.remainingMemoryBytes) | ||
} | ||
} | ||
|
||
@Test | ||
fun testRequestingMoreThanAvailableThrows() = runTest { | ||
val memoryManager = MemoryManager(MockAvailableMemoryProvider()) | ||
try { | ||
memoryManager.reserveBlocking(1001) | ||
} catch (e: IllegalArgumentException) { | ||
return@runTest | ||
} | ||
Assertions.fail<Unit>("Requesting more memory than available should throw an exception") | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add a
TODO: Be more fair. The current implementation will guarantee a first-come first-served order, which means if a function requests a large amount of memory that's not available yet, any subsequent call will block until the large memory request is satisfied, even though there may be enough available memory to serve the smaller requests