Skip to content

Commit

Permalink
Fix flaky streaming dataflow tests (#30572)
Browse files Browse the repository at this point in the history
* remove waiting/sleeping arbitrarily in tests since it is leading to flakiness
  • Loading branch information
m-trieu authored Mar 14, 2024
1 parent 52c0d5a commit 14d25c3
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o
computationId -> Optional.ofNullable(computationMap.get(computationId)))),
clientId,
computationMap,
new WindmillStateCache(options.getWorkerCacheMb()),
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()),
createWorkUnitExecutor(options),
IntrinsicMapTaskExecutorFactory.defaultFactory(),
new DataflowWorkUnitClient(options, LOG),
Expand All @@ -502,7 +502,7 @@ static StreamingDataflowWorker forTesting(
Supplier<Instant> clock,
Function<String, ScheduledExecutorService> executorSupplier) {
BoundedQueueExecutor boundedQueueExecutor = createWorkUnitExecutor(options);
WindmillStateCache stateCache = new WindmillStateCache(options.getWorkerCacheMb());
WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
computationMap.putAll(
createComputationMapForTesting(mapTasks, boundedQueueExecutor, stateCache::forComputation));
return new StreamingDataflowWorker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Weigher;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.MapMaker;
import org.checkerframework.checker.nullness.qual.Nullable;

Expand All @@ -56,6 +55,7 @@
* thread at a time, so this is safe.
*/
public class WindmillStateCache implements StatusDataProvider {
private static final int STATE_CACHE_CONCURRENCY_LEVEL = 4;
// Convert Megabytes to bytes
private static final long MEGABYTES = 1024 * 1024;
// Estimate of overhead per StateId.
Expand All @@ -72,20 +72,28 @@ public class WindmillStateCache implements StatusDataProvider {
// Contains the current valid ForKey object. Entries in the cache are keyed by ForKey with pointer
// equality so entries may be invalidated by creating a new key object, rendering the previous
// entries inaccessible. They will be evicted through normal cache operation.
private final ConcurrentMap<WindmillComputationKey, ForKey> keyIndex =
new MapMaker().weakValues().concurrencyLevel(4).makeMap();
private final ConcurrentMap<WindmillComputationKey, ForKey> keyIndex;
private final long workerCacheBytes; // Copy workerCacheMb and convert to bytes.

public WindmillStateCache(long workerCacheMb) {
final Weigher<Weighted, Weighted> weigher = Weighers.weightedKeysAndValues();
workerCacheBytes = workerCacheMb * MEGABYTES;
stateCache =
private WindmillStateCache(
long workerCacheMb,
ConcurrentMap<WindmillComputationKey, ForKey> keyIndex,
Cache<StateId, StateCacheEntry> stateCache) {
this.workerCacheBytes = workerCacheMb * MEGABYTES;
this.stateCache = stateCache;
this.keyIndex = keyIndex;
}

public static WindmillStateCache ofSizeMbs(long workerCacheMb) {
return new WindmillStateCache(
workerCacheMb,
new MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap(),
CacheBuilder.newBuilder()
.maximumWeight(workerCacheBytes)
.maximumWeight(workerCacheMb * MEGABYTES)
.recordStats()
.weigher(weigher)
.concurrencyLevel(4)
.build();
.weigher(Weighers.weightedKeysAndValues())
.concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL)
.build());
}

private EntryStats calculateEntryStats() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public void setUp() {
"computationId",
new ReaderCache(Duration.standardMinutes(1), Executors.newCachedThreadPool()),
stateNameMap,
new WindmillStateCache(options.getWorkerCacheMb()).forComputation("comp"),
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation("comp"),
StreamingStepMetricsContainer.createRegistry(),
new DataflowExecutionStateTracker(
ExecutionStateSampler.newForTest(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,8 @@ public void testFailedWorkItemsAbort() throws Exception {
"computationId",
new ReaderCache(Duration.standardMinutes(1), Runnable::run),
/*stateNameMap=*/ ImmutableMap.of(),
new WindmillStateCache(options.getWorkerCacheMb()).forComputation("computationId"),
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb())
.forComputation("computationId"),
StreamingStepMetricsContainer.createRegistry(),
new DataflowExecutionStateTracker(
ExecutionStateSampler.newForTest(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ private static WindmillComputationKey computationKey(
@Before
public void setUp() {
options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
cache = new WindmillStateCache(400);
cache = WindmillStateCache.ofSizeMbs(400);
assertEquals(0, cache.getWeight());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ private static void assertTagMultimapUpdates(
public void setUp() {
MockitoAnnotations.initMocks(this);
options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
cache = new WindmillStateCache(options.getWorkerCacheMb());
cache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
resetUnderTest();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,89 +17,111 @@
*/
package org.apache.beam.runners.dataflow.worker.windmill.work.budget;

import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertFalse;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class GetWorkBudgetRefresherTest {
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private static final int WAIT_BUFFER = 10;
private final Runnable redistributeBudget = Mockito.mock(Runnable.class);
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);

private GetWorkBudgetRefresher createBudgetRefresher() {
return createBudgetRefresher(false);
private GetWorkBudgetRefresher createBudgetRefresher(Runnable redistributeBudget) {
return createBudgetRefresher(false, redistributeBudget);
}

private GetWorkBudgetRefresher createBudgetRefresher(Boolean isBudgetRefreshPaused) {
private GetWorkBudgetRefresher createBudgetRefresher(
boolean isBudgetRefreshPaused, Runnable redistributeBudget) {
return new GetWorkBudgetRefresher(() -> isBudgetRefreshPaused, redistributeBudget);
}

@Test
public void testStop_successfullyTerminates() throws InterruptedException {
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
CountDownLatch redistributeBudgetLatch = new CountDownLatch(1);
Runnable redistributeBudget = redistributeBudgetLatch::countDown;
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(redistributeBudget);
budgetRefresher.start();
budgetRefresher.stop();
budgetRefresher.requestBudgetRefresh();
Thread.sleep(WAIT_BUFFER);
verifyNoInteractions(redistributeBudget);
boolean redistributeBudgetRan =
redistributeBudgetLatch.await(WAIT_BUFFER, TimeUnit.MILLISECONDS);
// Make sure that redistributeBudgetLatch.countDown() is never called.
assertThat(redistributeBudgetLatch.getCount()).isEqualTo(1);
assertFalse(redistributeBudgetRan);
}

@Test
public void testRequestBudgetRefresh_triggersBudgetRefresh() throws InterruptedException {
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
CountDownLatch redistributeBudgetLatch = new CountDownLatch(1);
Runnable redistributeBudget = redistributeBudgetLatch::countDown;
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(redistributeBudget);
budgetRefresher.start();
budgetRefresher.requestBudgetRefresh();
// Wait a bit for redistribute budget to run.
Thread.sleep(WAIT_BUFFER);
verify(redistributeBudget, times(1)).run();
// Wait for redistribute budget to run.
redistributeBudgetLatch.await();
assertThat(redistributeBudgetLatch.getCount()).isEqualTo(0);
}

@Test
public void testScheduledBudgetRefresh() throws InterruptedException {
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
CountDownLatch redistributeBudgetLatch = new CountDownLatch(1);
Runnable redistributeBudget = redistributeBudgetLatch::countDown;
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(redistributeBudget);
budgetRefresher.start();
Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER);
verify(redistributeBudget, times(1)).run();
// Wait for scheduled redistribute budget to run.
redistributeBudgetLatch.await();
assertThat(redistributeBudgetLatch.getCount()).isEqualTo(0);
}

@Test
public void testTriggeredAndScheduledBudgetRefresh_concurrent() throws InterruptedException {
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher();
CountDownLatch redistributeBudgetLatch = new CountDownLatch(2);
Runnable redistributeBudget = redistributeBudgetLatch::countDown;
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(redistributeBudget);
budgetRefresher.start();
Thread budgetRefreshTriggerThread = new Thread(budgetRefresher::requestBudgetRefresh);
budgetRefreshTriggerThread.start();
Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER);
budgetRefreshTriggerThread.join();

// Wait a bit for redistribute budget to run.
Thread.sleep(WAIT_BUFFER);
verify(redistributeBudget, times(2)).run();
// Wait for triggered and scheduled redistribute budget to run.
redistributeBudgetLatch.await();
assertThat(redistributeBudgetLatch.getCount()).isEqualTo(0);
}

@Test
public void testTriggeredBudgetRefresh_doesNotRunWhenBudgetRefreshPaused()
throws InterruptedException {
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true);
CountDownLatch redistributeBudgetLatch = new CountDownLatch(1);
Runnable redistributeBudget = redistributeBudgetLatch::countDown;
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true, redistributeBudget);
budgetRefresher.start();
budgetRefresher.requestBudgetRefresh();
Thread.sleep(WAIT_BUFFER);
verifyNoInteractions(redistributeBudget);
boolean redistributeBudgetRan =
redistributeBudgetLatch.await(WAIT_BUFFER, TimeUnit.MILLISECONDS);
// Make sure that redistributeBudgetLatch.countDown() is never called.
assertThat(redistributeBudgetLatch.getCount()).isEqualTo(1);
assertFalse(redistributeBudgetRan);
}

@Test
public void testScheduledBudgetRefresh_doesNotRunWhenBudgetRefreshPaused()
throws InterruptedException {
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true);
CountDownLatch redistributeBudgetLatch = new CountDownLatch(1);
Runnable redistributeBudget = redistributeBudgetLatch::countDown;
GetWorkBudgetRefresher budgetRefresher = createBudgetRefresher(true, redistributeBudget);
budgetRefresher.start();
Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER);
verifyNoInteractions(redistributeBudget);
boolean redistributeBudgetRan =
redistributeBudgetLatch.await(
GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS + WAIT_BUFFER,
TimeUnit.MILLISECONDS);
// Make sure that redistributeBudgetLatch.countDown() is never called.
assertThat(redistributeBudgetLatch.getCount()).isEqualTo(1);
assertFalse(redistributeBudgetRan);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
package org.apache.beam.runners.dataflow.worker.windmill.work.refresh;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.after;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.api.services.dataflow.model.MapTask;
Expand Down Expand Up @@ -50,7 +52,6 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Test;
Expand Down Expand Up @@ -192,11 +193,11 @@ public void testActiveWorkRefresh() throws InterruptedException {
}

@Test
public void testInvalidateStuckCommits() {
public void testInvalidateStuckCommits() throws InterruptedException {
int stuckCommitDurationMillis = 100;
Table<ComputationState, Work, WindmillStateCache.ForComputation> computations =
HashBasedTable.create();
WindmillStateCache stateCache = new WindmillStateCache(100);
WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(100);
ByteString key = ByteString.EMPTY;
for (int i = 0; i < 5; i++) {
WindmillStateCache.ForComputation perComputationStateCache =
Expand All @@ -209,6 +210,19 @@ public void testInvalidateStuckCommits() {
}

TestClock fakeClock = new TestClock(Instant.now());
CountDownLatch invalidateStuckCommitRan = new CountDownLatch(computations.size());

// Count down the latch every time to avoid waiting/sleeping arbitrarily.
for (ComputationState computation : computations.rowKeySet()) {
doAnswer(
invocation -> {
invocation.callRealMethod();
invalidateStuckCommitRan.countDown();
return null;
})
.when(computation)
.invalidateStuckCommits(any(Instant.class));
}

ActiveWorkRefresher activeWorkRefresher =
createActiveWorkRefresher(
Expand All @@ -220,21 +234,20 @@ public void testInvalidateStuckCommits() {

activeWorkRefresher.start();
fakeClock.advance(Duration.millis(stuckCommitDurationMillis));
Uninterruptibles.sleepUninterruptibly(stuckCommitDurationMillis, TimeUnit.MILLISECONDS);
invalidateStuckCommitRan.await();
activeWorkRefresher.stop();

for (Table.Cell<ComputationState, Work, WindmillStateCache.ForComputation> cell :
computations.cellSet()) {
ComputationState computation = cell.getRowKey();
Work work = cell.getColumnKey();
WindmillStateCache.ForComputation perComputationStateCache = cell.getValue();
verify(perComputationStateCache, after((long) (stuckCommitDurationMillis * 1.5)).times(1))
verify(perComputationStateCache, times(1))
.invalidate(eq(key), eq(work.getWorkItem().getShardingKey()));
verify(computation, after((long) (stuckCommitDurationMillis * 1.5)).times(1))
verify(computation, times(1))
.completeWorkAndScheduleNextWorkForKey(
eq(ShardedKey.create(key, work.getWorkItem().getShardingKey())), eq(work.id()));
}

activeWorkRefresher.stop();
}

static class TestClock implements Clock {
Expand Down

0 comments on commit 14d25c3

Please sign in to comment.