From 68f1543b6bfe4ceaa752c7f16fc2cae7393211fd Mon Sep 17 00:00:00 2001 From: martin trieu Date: Mon, 21 Oct 2024 03:05:43 -0600 Subject: [PATCH] Simplify budget distribution logic and new worker metadata consumption (#32775) --- .../FanOutStreamingEngineWorkerHarness.java | 379 ++++++++-------- .../harness/GlobalDataStreamSender.java | 63 +++ ...tate.java => StreamingEngineBackends.java} | 30 +- .../harness/WindmillStreamSender.java | 25 +- .../worker/windmill/WindmillEndpoints.java | 28 +- .../windmill/WindmillServiceAddress.java | 22 +- .../windmill/client/WindmillStream.java | 7 +- .../client/grpc/GrpcDirectGetWorkStream.java | 286 ++++++++----- .../client/grpc/GrpcGetDataStream.java | 2 +- .../client/grpc/GrpcGetWorkStream.java | 10 +- .../grpc/GrpcWindmillStreamFactory.java | 6 +- .../grpc/stubs/WindmillChannelFactory.java | 17 +- .../budget/EvenGetWorkBudgetDistributor.java | 59 +-- .../budget/GetWorkBudgetDistributors.java | 6 +- .../work/budget/GetWorkBudgetSpender.java | 8 +- .../dataflow/worker/FakeWindmillServer.java | 10 +- ...anOutStreamingEngineWorkerHarnessTest.java | 111 ++--- .../harness/WindmillStreamSenderTest.java | 4 +- .../grpc/GrpcDirectGetWorkStreamTest.java | 405 ++++++++++++++++++ .../EvenGetWorkBudgetDistributorTest.java | 186 ++------ 20 files changed, 998 insertions(+), 666 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java rename runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/{StreamingEngineConnectionState.java => StreamingEngineBackends.java} (55%) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 3556b7ce2919..458cf57ca8e7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,20 +20,25 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; -import java.util.Collection; -import java.util.List; +import java.io.Closeable; +import java.util.HashSet; import java.util.Map.Entry; +import java.util.NoSuchElementException; import java.util.Optional; -import java.util.Queue; -import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import java.util.stream.Collectors; import javax.annotation.CheckReturnValue; +import 
javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -54,18 +59,14 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; -import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.MoreFutures; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,32 +81,39 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorkerHarness { private static final Logger LOG = LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class); - private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread"; - private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread"; + private static final String WORKER_METADATA_CONSUMER_THREAD_NAME = + "WindmillWorkerMetadataConsumerThread"; + private static final String STREAM_MANAGER_THREAD_NAME = "WindmillStreamManager-%d"; private final JobHeader jobHeader; private final GrpcWindmillStreamFactory streamFactory; private final WorkItemScheduler workItemScheduler; private final ChannelCachingStubFactory channelCachingStubFactory; private final GrpcDispatcherClient dispatcherClient; - private final AtomicBoolean isBudgetRefreshPaused; - private final GetWorkBudgetRefresher getWorkBudgetRefresher; - private final AtomicReference lastBudgetRefresh; + private final GetWorkBudgetDistributor getWorkBudgetDistributor; + private final GetWorkBudget totalGetWorkBudget; private final ThrottleTimer getWorkerMetadataThrottleTimer; - private final ExecutorService newWorkerMetadataPublisher; - private final ExecutorService newWorkerMetadataConsumer; - private final long clientId; - private final Supplier getWorkerMetadataStream; - private final Queue newWindmillEndpoints; private final Function workCommitterFactory; private final ThrottlingGetDataMetricTracker getDataMetricTracker; + private final ExecutorService windmillStreamManager; + private final ExecutorService workerMetadataConsumer; + private final Object metadataLock = 
new Object(); /** Writes are guarded by synchronization, reads are lock free. */ - private final AtomicReference connections; + private final AtomicReference backends; - private volatile boolean started; + @GuardedBy("this") + private long activeMetadataVersion; + + @GuardedBy("metadataLock") + private long pendingMetadataVersion; + + @GuardedBy("this") + private boolean started; + + @GuardedBy("this") + private @Nullable GetWorkerMetadataStream getWorkerMetadataStream; - @SuppressWarnings("FutureReturnValueIgnored") private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, @@ -114,7 +122,6 @@ private FanOutStreamingEngineWorkerHarness( ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { this.jobHeader = jobHeader; @@ -122,42 +129,21 @@ private FanOutStreamingEngineWorkerHarness( this.started = false; this.streamFactory = streamFactory; this.workItemScheduler = workItemScheduler; - this.connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); + this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY); this.channelCachingStubFactory = channelCachingStubFactory; this.dispatcherClient = dispatcherClient; - this.isBudgetRefreshPaused = new AtomicBoolean(false); this.getWorkerMetadataThrottleTimer = new ThrottleTimer(); - this.newWorkerMetadataPublisher = - singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD); - this.newWorkerMetadataConsumer = - singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD); - this.clientId = clientId; - this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH); - this.newWindmillEndpoints = Queues.synchronizedQueue(EvictingQueue.create(1)); - this.getWorkBudgetRefresher = - new GetWorkBudgetRefresher( - isBudgetRefreshPaused::get, - () -> { - getWorkBudgetDistributor.distributeBudget( - connections.get().windmillStreams().values(), totalGetWorkBudget); - lastBudgetRefresh.set(Instant.now()); - }); - this.getWorkerMetadataStream = - Suppliers.memoize( - () -> - streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStubBlocking(), - getWorkerMetadataThrottleTimer, - endpoints -> - // Run this on a separate thread than the grpc stream thread. 
- newWorkerMetadataPublisher.submit( - () -> newWindmillEndpoints.add(endpoints)))); + this.windmillStreamManager = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build()); + this.workerMetadataConsumer = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build()); + this.getWorkBudgetDistributor = getWorkBudgetDistributor; + this.totalGetWorkBudget = totalGetWorkBudget; + this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; - } - - private static ExecutorService singleThreadedExecutorServiceOf(String threadName) { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); + this.getWorkerMetadataStream = null; } /** @@ -183,7 +169,6 @@ public static FanOutStreamingEngineWorkerHarness create( channelCachingStubFactory, getWorkBudgetDistributor, dispatcherClient, - /* clientId= */ new Random().nextLong(), workCommitterFactory, getDataMetricTracker); } @@ -197,7 +182,6 @@ static FanOutStreamingEngineWorkerHarness forTesting( ChannelCachingStubFactory stubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkProvider = @@ -209,201 +193,218 @@ static FanOutStreamingEngineWorkerHarness forTesting( stubFactory, getWorkBudgetDistributor, dispatcherClient, - clientId, workCommitterFactory, getDataMetricTracker); fanOutStreamingEngineWorkProvider.start(); return fanOutStreamingEngineWorkProvider; } - @SuppressWarnings("ReturnValueIgnored") @Override public synchronized void start() { - Preconditions.checkState(!started, "StreamingEngineClient cannot start twice."); - // Starts the stream, this value is memoized. - getWorkerMetadataStream.get(); - startWorkerMetadataConsumer(); - getWorkBudgetRefresher.start(); + Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); + getWorkerMetadataStream = + streamFactory.createGetWorkerMetadataStream( + dispatcherClient.getWindmillMetadataServiceStubBlocking(), + getWorkerMetadataThrottleTimer, + this::consumeWorkerMetadata); started = true; } public ImmutableSet currentWindmillEndpoints() { - return connections.get().windmillConnections().keySet().stream() + return backends.get().windmillStreams().keySet().stream() .map(Endpoint::directEndpoint) .filter(Optional::isPresent) .map(Optional::get) - .filter( - windmillServiceAddress -> - windmillServiceAddress.getKind() != WindmillServiceAddress.Kind.IPV6) - .map( - windmillServiceAddress -> - windmillServiceAddress.getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS - ? windmillServiceAddress.gcpServiceAddress() - : windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress()) + .map(WindmillServiceAddress::getServiceAddress) .collect(toImmutableSet()); } /** - * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link - * GetDataStream} pointing to dispatcher. + * Fetches {@link GetDataStream} mapped to globalDataKey if or throws {@link + * NoSuchElementException} if one is not found. 
*/ private GetDataStream getGlobalDataStream(String globalDataKey) { - return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey)) - .map(Supplier::get) - .orElseGet( - () -> - streamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), new ThrottleTimer())); - } - - @SuppressWarnings("FutureReturnValueIgnored") - private void startWorkerMetadataConsumer() { - newWorkerMetadataConsumer.submit( - () -> { - while (true) { - Optional.ofNullable(newWindmillEndpoints.poll()) - .ifPresent(this::consumeWindmillWorkerEndpoints); - } - }); + return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey)) + .map(GlobalDataStreamSender::get) + .orElseThrow( + () -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey)); } @VisibleForTesting @Override public synchronized void shutdown() { - Preconditions.checkState(started, "StreamingEngineClient never started."); - getWorkerMetadataStream.get().halfClose(); - getWorkBudgetRefresher.stop(); - newWorkerMetadataPublisher.shutdownNow(); - newWorkerMetadataConsumer.shutdownNow(); + Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); + Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); + workerMetadataConsumer.shutdownNow(); + closeStreamsNotIn(WindmillEndpoints.none()); channelCachingStubFactory.shutdown(); + + try { + Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", e); + } + + windmillStreamManager.shutdown(); + boolean isStreamManagerShutdown = false; + try { + isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", e); + } + if (!isStreamManagerShutdown) { + windmillStreamManager.shutdownNow(); + } + } + + private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) { + synchronized (metadataLock) { + // Only process versions greater than what we currently have to prevent double processing of + // metadata. workerMetadataConsumer is single-threaded so we maintain ordering. + if (windmillEndpoints.version() > pendingMetadataVersion) { + pendingMetadataVersion = windmillEndpoints.version(); + workerMetadataConsumer.execute(() -> consumeWindmillWorkerEndpoints(windmillEndpoints)); + } + } } - /** - * {@link java.util.function.Consumer} used to update {@link #connections} on - * new backend worker metadata. - */ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) { - isBudgetRefreshPaused.set(true); - LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints); - ImmutableMap newWindmillConnections = - createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints()); - - StreamingEngineConnectionState newConnectionsState = - StreamingEngineConnectionState.builder() - .setWindmillConnections(newWindmillConnections) - .setWindmillStreams( - closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values())) + // Since this is run on a single threaded executor, multiple versions of the metadata maybe + // queued up while a previous version of the windmillEndpoints were being consumed. Only consume + // the endpoints if they are the most current version. 
+ synchronized (metadataLock) { + if (newWindmillEndpoints.version() < pendingMetadataVersion) { + return; + } + } + + LOG.debug( + "Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}", + newWindmillEndpoints, + activeMetadataVersion, + newWindmillEndpoints.version()); + closeStreamsNotIn(newWindmillEndpoints); + ImmutableMap newStreams = + createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join(); + StreamingEngineBackends newBackends = + StreamingEngineBackends.builder() + .setWindmillStreams(newStreams) .setGlobalDataStreams( createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints())) .build(); + backends.set(newBackends); + getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget); + activeMetadataVersion = newWindmillEndpoints.version(); + } + + /** Close the streams that are no longer valid asynchronously. */ + private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { + StreamingEngineBackends currentBackends = backends.get(); + currentBackends.windmillStreams().entrySet().stream() + .filter( + connectionAndStream -> + !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) + .forEach( + entry -> + windmillStreamManager.execute( + () -> closeStreamSender(entry.getKey(), entry.getValue()))); - LOG.info( - "Setting new connections: {}. Previous connections: {}.", - newConnectionsState, - connections.get()); - connections.set(newConnectionsState); - isBudgetRefreshPaused.set(false); - getWorkBudgetRefresher.requestBudgetRefresh(); + Set newGlobalDataEndpoints = + new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); + currentBackends.globalDataStreams().values().stream() + .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) + .forEach( + sender -> + windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); + } + + private void closeStreamSender(Endpoint endpoint, Closeable sender) { + LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender); + try { + sender.close(); + endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove); + LOG.debug("Successfully closed streams to {}", endpoint); + } catch (Exception e) { + LOG.error("Error closing streams to endpoint={}, sender={}", endpoint, sender); + } + } + + private synchronized CompletableFuture> + createAndStartNewStreams(ImmutableSet newWindmillEndpoints) { + ImmutableMap currentStreams = backends.get().windmillStreams(); + return MoreFutures.allAsList( + newWindmillEndpoints.stream() + .map(endpoint -> getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams)) + .collect(Collectors.toList())) + .thenApply( + backends -> backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight))) + .toCompletableFuture(); + } + + private CompletionStage> + getOrCreateWindmillStreamSenderFuture( + Endpoint endpoint, ImmutableMap currentStreams) { + return MoreFutures.supplyAsync( + () -> + Pair.of( + endpoint, + Optional.ofNullable(currentStreams.get(endpoint)) + .orElseGet(() -> createAndStartWindmillStreamSender(endpoint))), + windmillStreamManager); } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. 
*/ - public long getAndResetThrottleTimes() { - return connections.get().windmillStreams().values().stream() + public long getAndResetThrottleTime() { + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getAndResetThrottleTime) .reduce(0L, Long::sum) + getWorkerMetadataThrottleTimer.getAndResetThrottleTime(); } public long currentActiveCommitBytes() { - return connections.get().windmillStreams().values().stream() + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getCurrentActiveCommitBytes) .reduce(0L, Long::sum); } @VisibleForTesting - StreamingEngineConnectionState getCurrentConnections() { - return connections.get(); - } - - private synchronized ImmutableMap createNewWindmillConnections( - List newWindmillEndpoints) { - ImmutableMap currentConnections = - connections.get().windmillConnections(); - return newWindmillEndpoints.stream() - .collect( - toImmutableMap( - Function.identity(), - endpoint -> - // Reuse existing stubs if they exist. Optional.orElseGet only calls the - // supplier if the value is not present, preventing constructing expensive - // objects. - Optional.ofNullable(currentConnections.get(endpoint)) - .orElseGet( - () -> WindmillConnection.from(endpoint, this::createWindmillStub)))); + StreamingEngineBackends currentBackends() { + return backends.get(); } - private synchronized ImmutableMap - closeStaleStreamsAndCreateNewStreams(Collection newWindmillConnections) { - ImmutableMap currentStreams = - connections.get().windmillStreams(); - - // Close the streams that are no longer valid. - currentStreams.entrySet().stream() - .filter( - connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey())) - .forEach( - entry -> { - entry.getValue().closeAllStreams(); - entry.getKey().directEndpoint().ifPresent(channelCachingStubFactory::remove); - }); - - return newWindmillConnections.stream() - .collect( - toImmutableMap( - Function.identity(), - newConnection -> - Optional.ofNullable(currentStreams.get(newConnection)) - .orElseGet(() -> createAndStartWindmillStreamSenderFor(newConnection)))); - } - - private ImmutableMap> createNewGlobalDataStreams( + private ImmutableMap createNewGlobalDataStreams( ImmutableMap newGlobalDataEndpoints) { - ImmutableMap> currentGlobalDataStreams = - connections.get().globalDataStreams(); + ImmutableMap currentGlobalDataStreams = + backends.get().globalDataStreams(); return newGlobalDataEndpoints.entrySet().stream() .collect( toImmutableMap( Entry::getKey, keyedEndpoint -> - existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams))); + getOrCreateGlobalDataSteam(keyedEndpoint, currentGlobalDataStreams))); } - private Supplier existingOrNewGetDataStreamFor( + private GlobalDataStreamSender getOrCreateGlobalDataSteam( Entry keyedEndpoint, - ImmutableMap> currentGlobalDataStreams) { - return Preconditions.checkNotNull( - currentGlobalDataStreams.getOrDefault( - keyedEndpoint.getKey(), + ImmutableMap currentGlobalDataStreams) { + return Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey())) + .orElseGet( () -> - streamFactory.createGetDataStream( - newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer()))); - } - - private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) { - return Optional.ofNullable(connections.get().windmillConnections().get(endpoint)) - .map(WindmillConnection::stub) - .orElseGet(() -> createWindmillStub(endpoint)); + new GlobalDataStreamSender( + () -> + 
streamFactory.createGetDataStream( + createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), + keyedEndpoint.getValue())); } - private WindmillStreamSender createAndStartWindmillStreamSenderFor( - WindmillConnection connection) { - // Initially create each stream with no budget. The budget will be eventually assigned by the - // GetWorkBudgetDistributor. + private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( - connection, + WindmillConnection.from(endpoint, this::createWindmillStub), GetWorkRequest.newBuilder() - .setClientId(clientId) + .setClientId(jobHeader.getClientId()) .setJobId(jobHeader.getJobId()) .setProjectId(jobHeader.getProjectId()) .setWorkerId(jobHeader.getWorkerId()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java new file mode 100644 index 000000000000..ce5f3a7b6bfc --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import java.io.Closeable; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; + +@Internal +@ThreadSafe +// TODO (m-trieu): replace Supplier with Stream after github.com/apache/beam/pull/32774/ is +// merged +final class GlobalDataStreamSender implements Closeable, Supplier { + private final Endpoint endpoint; + private final Supplier delegate; + private volatile boolean started; + + GlobalDataStreamSender(Supplier delegate, Endpoint endpoint) { + // Ensures that the Supplier is thread-safe + this.delegate = Suppliers.memoize(delegate::get); + this.started = false; + this.endpoint = endpoint; + } + + @Override + public GetDataStream get() { + if (!started) { + started = true; + } + + return delegate.get(); + } + + @Override + public void close() { + if (started) { + delegate.get().shutdown(); + } + } + + Endpoint endpoint() { + return endpoint; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java similarity index 55% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java index 3c85ee6abe1f..14290b486830 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java @@ -18,47 +18,37 @@ package org.apache.beam.runners.dataflow.worker.streaming.harness; import com.google.auto.value.AutoValue; -import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; /** - * Represents the current state of connections to Streaming Engine. Connections are updated when - * backend workers assigned to the key ranges being processed by this user worker change during + * Represents the current state of connections to the Streaming Engine backend. Backends are updated + * when backend workers assigned to the key ranges being processed by this user worker change during * pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other * backend updates. 
*/ @AutoValue -abstract class StreamingEngineConnectionState { - static final StreamingEngineConnectionState EMPTY = builder().build(); +abstract class StreamingEngineBackends { + static final StreamingEngineBackends EMPTY = builder().build(); static Builder builder() { - return new AutoValue_StreamingEngineConnectionState.Builder() - .setWindmillConnections(ImmutableMap.of()) + return new AutoValue_StreamingEngineBackends.Builder() .setWindmillStreams(ImmutableMap.of()) .setGlobalDataStreams(ImmutableMap.of()); } - abstract ImmutableMap windmillConnections(); - - abstract ImmutableMap windmillStreams(); + abstract ImmutableMap windmillStreams(); /** Mapping of GlobalDataIds and the direct GetDataStreams used fetch them. */ - abstract ImmutableMap> globalDataStreams(); + abstract ImmutableMap globalDataStreams(); @AutoValue.Builder abstract static class Builder { - public abstract Builder setWindmillConnections( - ImmutableMap value); - - public abstract Builder setWindmillStreams( - ImmutableMap value); + public abstract Builder setWindmillStreams(ImmutableMap value); public abstract Builder setGlobalDataStreams( - ImmutableMap> value); + ImmutableMap value); - public abstract StreamingEngineConnectionState build(); + public abstract StreamingEngineBackends build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 45aa403ee71b..744c3d74445f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import java.io.Closeable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -49,7 +50,7 @@ * {@link GetWorkBudget} is set. * *

Once started, the underlying streams are "alive" until they are manually closed via {@link
- * #closeAllStreams()}.
+ * #close()}.
 *
 *

If closed, it means that the backend endpoint is no longer in the worker set. Once closed, * these instances are not reused. @@ -59,7 +60,7 @@ */ @Internal @ThreadSafe -final class WindmillStreamSender implements GetWorkBudgetSpender { +final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable { private final AtomicBoolean started; private final AtomicReference getWorkBudget; private final Supplier getWorkStream; @@ -103,9 +104,9 @@ private WindmillStreamSender( connection, withRequestBudget(getWorkRequest, getWorkBudget.get()), streamingEngineThrottleTimers.getWorkThrottleTimer(), - () -> FixedStreamHeartbeatSender.create(getDataStream.get()), - () -> getDataClientFactory.apply(getDataStream.get()), - workCommitter, + FixedStreamHeartbeatSender.create(getDataStream.get()), + getDataClientFactory.apply(getDataStream.get()), + workCommitter.get(), workItemScheduler)); } @@ -141,7 +142,8 @@ void startStreams() { started.set(true); } - void closeAllStreams() { + @Override + public void close() { // Supplier.get() starts the stream which is an expensive operation as it initiates the // streaming RPCs by possibly making calls over the network. Do not close the streams unless // they have already been started. @@ -154,18 +156,13 @@ void closeAllStreams() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta)); + public void setBudget(long items, long bytes) { + getWorkBudget.set(getWorkBudget.get().apply(items, bytes)); if (started.get()) { - getWorkStream.get().adjustBudget(itemsDelta, bytesDelta); + getWorkStream.get().setBudget(items, bytes); } } - @Override - public GetWorkBudget remainingBudget() { - return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get(); - } - long getAndResetThrottleTime() { return streamingEngineThrottleTimers.getAndResetThrottleTime(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index d7ed83def43e..eb269eef848f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.auto.value.AutoValue; import java.net.Inet6Address; @@ -27,8 +27,8 @@ import java.util.Map; import java.util.Optional; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -41,6 +41,14 @@ public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + public static WindmillEndpoints none() { + return WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + .setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); + } + public static WindmillEndpoints from( Windmill.WorkerMetadataResponse workerMetadataResponseProto) { ImmutableMap globalDataServers = @@ -53,14 +61,15 @@ public static WindmillEndpoints from( endpoint.getValue(), workerMetadataResponseProto.getExternalEndpoint()))); - ImmutableList windmillServers = + ImmutableSet windmillServers = workerMetadataResponseProto.getWorkEndpointsList().stream() .map( endpointProto -> Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint())) - .collect(toImmutableList()); + .collect(toImmutableSet()); return WindmillEndpoints.builder() + .setVersion(workerMetadataResponseProto.getMetadataVersion()) .setGlobalDataEndpoints(globalDataServers) .setWindmillEndpoints(windmillServers) .build(); @@ -123,6 +132,9 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); } + /** Version of the endpoints which increases with every modification. */ + public abstract long version(); + /** * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key * is a global data tag and the value is the endpoint where the data associated with the global @@ -138,7 +150,7 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( * Windmill servers. Returns a list of endpoints used to communicate with the corresponding * Windmill servers. */ - public abstract ImmutableList windmillEndpoints(); + public abstract ImmutableSet windmillEndpoints(); /** * Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with @@ -204,13 +216,15 @@ public abstract static class Builder { @AutoValue.Builder public abstract static class Builder { + public abstract Builder setVersion(long version); + public abstract Builder setGlobalDataEndpoints( ImmutableMap globalDataServers); public abstract Builder setWindmillEndpoints( - ImmutableList windmillServers); + ImmutableSet windmillServers); - abstract ImmutableList.Builder windmillEndpointsBuilder(); + abstract ImmutableSet.Builder windmillEndpointsBuilder(); public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) { windmillEndpointsBuilder().add(endpoint); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java index 90f93b072673..0b895652efe2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java @@ -19,38 +19,36 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; -import java.net.Inet6Address; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; /** Used to create channels to communicate with Streaming Engine via gRpc. 
*/ @AutoOneOf(WindmillServiceAddress.Kind.class) public abstract class WindmillServiceAddress { - public static WindmillServiceAddress create(Inet6Address ipv6Address) { - return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address); - } public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) { return AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress); } - public abstract Kind getKind(); + public static WindmillServiceAddress create( + AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { + return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( + authenticatedGcpServiceAddress); + } - public abstract Inet6Address ipv6(); + public abstract Kind getKind(); public abstract HostAndPort gcpServiceAddress(); public abstract AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress(); - public static WindmillServiceAddress create( - AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { - return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( - authenticatedGcpServiceAddress); + public final HostAndPort getServiceAddress() { + return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS + ? gcpServiceAddress() + : authenticatedGcpServiceAddress().gcpServiceAddress(); } public enum Kind { - IPV6, GCP_SERVICE_ADDRESS, - // TODO(m-trieu): Use for direct connections when ALTS is enabled. AUTHENTICATED_GCP_SERVICE_ADDRESS } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 31bd4e146a78..f26c56b14ec2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -56,10 +56,11 @@ public interface WindmillStream { @ThreadSafe interface GetWorkStream extends WindmillStream { /** Adjusts the {@link GetWorkBudget} for the stream. */ - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(GetWorkBudget newBudget); - /** Returns the remaining in-flight {@link GetWorkBudget}. */ - GetWorkBudget remainingBudget(); + default void setBudget(long newItems, long newBytes) { + setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); + } } /** Interface for streaming GetDataRequests to Windmill. 
*/ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 19de998b1da8..b27ebc8e9eee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -21,9 +21,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.concurrent.GuardedBy; +import net.jcip.annotations.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -44,8 +46,8 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of {@link GetWorkStream} that passes along a specific {@link @@ -55,9 +57,10 @@ * these direct streams are used to facilitate these RPC calls to specific backend workers. */ @Internal -public final class GrpcDirectGetWorkStream +final class GrpcDirectGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class); private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST = StreamingGetWorkRequest.newBuilder() .setRequestExtension( @@ -67,15 +70,14 @@ public final class GrpcDirectGetWorkStream .build()) .build(); - private final AtomicReference inFlightBudget; - private final AtomicReference nextBudgetAdjustment; - private final AtomicReference pendingResponseBudget; - private final GetWorkRequest request; + private final GetWorkBudgetTracker budgetTracker; + private final GetWorkRequest requestHeader; private final WorkItemScheduler workItemScheduler; private final ThrottleTimer getWorkThrottleTimer; - private final Supplier heartbeatSender; - private final Supplier workCommitter; - private final Supplier getDataClient; + private final HeartbeatSender heartbeatSender; + private final WorkCommitter workCommitter; + private final GetDataClient getDataClient; + private final AtomicReference lastRequest; /** * Map of stream IDs to their buffers. 
Used to aggregate streaming gRPC response chunks as they @@ -92,15 +94,15 @@ private GrpcDirectGetWorkStream( StreamObserver, StreamObserver> startGetWorkRpcFn, - GetWorkRequest request, + GetWorkRequest requestHeader, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { super( "GetWorkStream", @@ -110,19 +112,23 @@ private GrpcDirectGetWorkStream( streamRegistry, logEveryNStreamFailures, backendWorkerToken); - this.request = request; + this.requestHeader = requestHeader; this.getWorkThrottleTimer = getWorkThrottleTimer; this.workItemScheduler = workItemScheduler; this.workItemAssemblers = new ConcurrentHashMap<>(); - this.heartbeatSender = Suppliers.memoize(heartbeatSender::get); - this.workCommitter = Suppliers.memoize(workCommitter::get); - this.getDataClient = Suppliers.memoize(getDataClient::get); - this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); - this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); - this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + this.heartbeatSender = heartbeatSender; + this.workCommitter = workCommitter; + this.getDataClient = getDataClient; + this.lastRequest = new AtomicReference<>(); + this.budgetTracker = + new GetWorkBudgetTracker( + GetWorkBudget.builder() + .setItems(requestHeader.getMaxItems()) + .setBytes(requestHeader.getMaxBytes()) + .build()); } - public static GrpcDirectGetWorkStream create( + static GrpcDirectGetWorkStream create( String backendWorkerToken, Function< StreamObserver, @@ -134,9 +140,9 @@ public static GrpcDirectGetWorkStream create( Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { GrpcDirectGetWorkStream getWorkStream = new GrpcDirectGetWorkStream( @@ -165,46 +171,52 @@ private static Watermarks createWatermarks( .build(); } - private void sendRequestExtension(GetWorkBudget adjustment) { - inFlightBudget.getAndUpdate(budget -> budget.apply(adjustment)); - StreamingGetWorkRequest extension = - StreamingGetWorkRequest.newBuilder() - .setRequestExtension( - Windmill.StreamingGetWorkRequestExtension.newBuilder() - .setMaxItems(adjustment.items()) - .setMaxBytes(adjustment.bytes())) - .build(); - - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + /** + * @implNote Do not lock/synchronize here due to this running on grpc serial executor for message + * which can deadlock since we send on the stream beneath the synchronization. {@link + * AbstractWindmillStream#send(Object)} is synchronized so the sends are already guarded. 
+ */ + private void maybeSendRequestExtension(GetWorkBudget extension) { + if (extension.items() > 0 || extension.bytes() > 0) { + executeSafely( + () -> { + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(extension.items()) + .setMaxBytes(extension.bytes())) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(extension); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was closed. + } + }); + } } @Override protected synchronized void onNewStream() { workItemAssemblers.clear(); - // Add the current in-flight budget to the next adjustment. Only positive values are allowed - // here - // with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative values. - GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get()); - inFlightBudget.set(budgetAdjustment); - send( - StreamingGetWorkRequest.newBuilder() - .setRequest( - request - .toBuilder() - .setMaxBytes(budgetAdjustment.bytes()) - .setMaxItems(budgetAdjustment.items())) - .build()); - - // We just sent the budget, reset it. - nextBudgetAdjustment.set(GetWorkBudget.noBudget()); + if (!isShutdown()) { + budgetTracker.reset(); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequest( + requestHeader + .toBuilder() + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build()) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(initialGetWorkBudget); + send(request); + } } @Override @@ -216,8 +228,9 @@ protected boolean hasPendingRequests() { public void appendSpecificHtml(PrintWriter writer) { // Number of buffers is same as distinct workers that sent work on this stream. writer.format( - "GetWorkStream: %d buffers, %s inflight budget allowed.", - workItemAssemblers.size(), inFlightBudget.get()); + "GetWorkStream: %d buffers, " + "last sent request: %s; ", + workItemAssemblers.size(), lastRequest.get()); + writer.print(budgetTracker.debugString()); } @Override @@ -235,30 +248,22 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) { } private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { - // Record the fact that there are now fewer outstanding messages and bytes on the stream. 
- inFlightBudget.updateAndGet(budget -> budget.subtract(1, assembledWorkItem.bufferedSize())); WorkItem workItem = assembledWorkItem.workItem(); GetWorkResponseChunkAssembler.ComputationMetadata metadata = assembledWorkItem.computationMetadata(); - pendingResponseBudget.getAndUpdate(budget -> budget.apply(1, workItem.getSerializedSize())); - try { - workItemScheduler.scheduleWork( - workItem, - createWatermarks(workItem, Preconditions.checkNotNull(metadata)), - createProcessingContext(Preconditions.checkNotNull(metadata.computationId())), - assembledWorkItem.latencyAttributions()); - } finally { - pendingResponseBudget.getAndUpdate(budget -> budget.apply(-1, -workItem.getSerializedSize())); - } + workItemScheduler.scheduleWork( + workItem, + createWatermarks(workItem, metadata), + createProcessingContext(metadata.computationId()), + assembledWorkItem.latencyAttributions()); + budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize()); + GetWorkBudget extension = budgetTracker.computeBudgetExtension(); + maybeSendRequestExtension(extension); } private Work.ProcessingContext createProcessingContext(String computationId) { return Work.createProcessingContext( - computationId, - getDataClient.get(), - workCommitter.get()::commit, - heartbeatSender.get(), - backendWorkerToken()); + computationId, getDataClient, workCommitter::commit, heartbeatSender, backendWorkerToken()); } @Override @@ -267,25 +272,110 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - GetWorkBudget adjustment = - nextBudgetAdjustment - // Get the current value, and reset the nextBudgetAdjustment. This will be set again - // when adjustBudget is called. - .getAndUpdate(unused -> GetWorkBudget.noBudget()) - .apply(itemsDelta, bytesDelta); - sendRequestExtension(adjustment); + public void setBudget(GetWorkBudget newBudget) { + GetWorkBudget extension = budgetTracker.consumeAndComputeBudgetUpdate(newBudget); + maybeSendRequestExtension(extension); } - @Override - public GetWorkBudget remainingBudget() { - // Snapshot the current budgets. - GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); - GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); - GetWorkBudget currentInflightBudget = inFlightBudget.get(); - - return currentPendingResponseBudget - .apply(currentNextBudgetAdjustment) - .apply(currentInflightBudget); + private void executeSafely(Runnable runnable) { + try { + executor().execute(runnable); + } catch (RejectedExecutionException e) { + LOG.debug("{} has been shutdown.", getClass()); + } + } + + /** + * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request + * extensions. 
+ */ + @ThreadSafe + private static final class GetWorkBudgetTracker { + + @GuardedBy("GetWorkBudgetTracker.this") + private GetWorkBudget maxGetWorkBudget; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsReceived = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesReceived = 0; + + private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) { + this.maxGetWorkBudget = maxGetWorkBudget; + } + + private synchronized void reset() { + itemsRequested = 0; + bytesRequested = 0; + itemsReceived = 0; + bytesReceived = 0; + } + + private synchronized String debugString() { + return String.format( + "max budget: %s; " + + "in-flight budget: %s; " + + "total budget requested: %s; " + + "total budget received: %s.", + maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget()); + } + + /** Consumes the new budget and computes an extension based on the new budget. */ + private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) { + maxGetWorkBudget = newBudget; + return computeBudgetExtension(); + } + + private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) { + itemsRequested += budgetRequested.items(); + bytesRequested += budgetRequested.bytes(); + } + + private synchronized void recordBudgetReceived(long returnedBudget) { + itemsReceived++; + bytesReceived += returnedBudget; + } + + /** + * If the outstanding items or bytes limit has gotten too low, top both off with a + * GetWorkExtension. The goal is to keep the limits relatively close to their maximum values + * without sending too many extension requests. + */ + private synchronized GetWorkBudget computeBudgetExtension() { + // Expected items and bytes can go negative here, since WorkItems returned might be larger + // than the initially requested budget. + long inFlightItems = itemsRequested - itemsReceived; + long inFlightBytes = bytesRequested - bytesReceived; + + // Don't send negative budget extensions. + long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes); + long requestItems = Math.max(0, maxGetWorkBudget.items() - inFlightItems); + + return (inFlightItems > requestItems / 2 && inFlightBytes > requestBytes / 2) + ? 
GetWorkBudget.noBudget() + : GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build(); + } + + private synchronized GetWorkBudget inFlightBudget() { + return GetWorkBudget.builder() + .setItems(itemsRequested - itemsReceived) + .setBytes(bytesRequested - bytesReceived) + .build(); + } + + private synchronized GetWorkBudget totalRequestedBudget() { + return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build(); + } + + private synchronized GetWorkBudget totalReceivedBudget() { + return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 0e9a0c6316ee..c99e05a77074 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -59,7 +59,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcGetDataStream +final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 09ecbf3f3051..a368f3fec235 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -194,15 +194,7 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op } - - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setBytes(request.getMaxBytes() - inflightBytes.get()) - .setItems(request.getMaxItems() - inflightMessages.get()) - .build(); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 92f031db9972..9e6a02d135e2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -198,9 +198,9 @@ public GetWorkStream createDirectGetWorkStream( WindmillConnection connection, GetWorkRequest request, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + 
WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 9aec29a3ba4d..f0ea2f550a74 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -36,7 +36,6 @@ /** Utility class used to create different RPC Channels. */ public final class WindmillChannelFactory { public static final String LOCALHOST = "localhost"; - private static final int DEFAULT_GRPC_PORT = 443; private static final int MAX_REMOTE_TRACE_EVENTS = 100; private WindmillChannelFactory() {} @@ -55,8 +54,6 @@ public static Channel localhostChannel(int port) { public static ManagedChannel remoteChannel( WindmillServiceAddress windmillServiceAddress, int windmillServiceRpcChannelTimeoutSec) { switch (windmillServiceAddress.getKind()) { - case IPV6: - return remoteChannel(windmillServiceAddress.ipv6(), windmillServiceRpcChannelTimeoutSec); case GCP_SERVICE_ADDRESS: return remoteChannel( windmillServiceAddress.gcpServiceAddress(), windmillServiceRpcChannelTimeoutSec); @@ -67,7 +64,8 @@ public static ManagedChannel remoteChannel( windmillServiceRpcChannelTimeoutSec); default: throw new UnsupportedOperationException( - "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported WindmillServiceAddresses."); + "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS are supported" + + " WindmillServiceAddresses."); } } @@ -105,17 +103,6 @@ public static Channel remoteChannel( } } - public static ManagedChannel remoteChannel( - Inet6Address directEndpoint, int windmillServiceRpcChannelTimeoutSec) { - try { - return createRemoteChannel( - NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint, DEFAULT_GRPC_PORT)), - windmillServiceRpcChannelTimeoutSec); - } catch (SSLException sslException) { - throw new WindmillChannelCreationException(directEndpoint.toString(), sslException); - } - } - @SuppressWarnings("nullness") private static ManagedChannel createRemoteChannel( NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java index 403bb99efb4c..8a1ba2556cf2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -17,18 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; -import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide; import java.math.RoundingMode; -import java.util.Map; -import java.util.Map.Entry; -import java.util.function.Function; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,22 +29,11 @@ @Internal final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor { private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class); - private final Supplier activeWorkBudgetSupplier; - - EvenGetWorkBudgetDistributor(Supplier activeWorkBudgetSupplier) { - this.activeWorkBudgetSupplier = activeWorkBudgetSupplier; - } - - private static boolean isBelowFiftyPercentOfTarget( - GetWorkBudget remaining, GetWorkBudget target) { - return remaining.items() < roundToLong(target.items() * 0.5, RoundingMode.CEILING) - || remaining.bytes() < roundToLong(target.bytes() * 0.5, RoundingMode.CEILING); - } @Override public void distributeBudget( - ImmutableCollection budgetOwners, GetWorkBudget getWorkBudget) { - if (budgetOwners.isEmpty()) { + ImmutableCollection budgetSpenders, GetWorkBudget getWorkBudget) { + if (budgetSpenders.isEmpty()) { LOG.debug("Cannot distribute budget to no owners."); return; } @@ -61,38 +43,15 @@ public void distributeBudget( return; } - Map desiredBudgets = computeDesiredBudgets(budgetOwners, getWorkBudget); - - for (Entry streamAndDesiredBudget : desiredBudgets.entrySet()) { - GetWorkBudgetSpender getWorkBudgetSpender = streamAndDesiredBudget.getKey(); - GetWorkBudget desired = streamAndDesiredBudget.getValue(); - GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget(); - if (isBelowFiftyPercentOfTarget(remaining, desired)) { - GetWorkBudget adjustment = desired.subtract(remaining); - getWorkBudgetSpender.adjustBudget(adjustment); - } - } + GetWorkBudget budgetPerStream = computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget); + budgetSpenders.forEach(getWorkBudgetSpender -> getWorkBudgetSpender.setBudget(budgetPerStream)); } - private ImmutableMap computeDesiredBudgets( + private GetWorkBudget computeDesiredPerStreamBudget( ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { - GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get(); - LOG.info("Current active work budget: {}", activeWorkBudget); - // TODO: Fix possibly non-deterministic handing out of budgets. - // Rounding up here will drift upwards over the lifetime of the streams. 
- GetWorkBudget budgetPerStream = - GetWorkBudget.builder() - .setItems( - divide( - totalGetWorkBudget.items() - activeWorkBudget.items(), - streams.size(), - RoundingMode.CEILING)) - .setBytes( - divide( - totalGetWorkBudget.bytes() - activeWorkBudget.bytes(), - streams.size(), - RoundingMode.CEILING)) - .build(); - return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream)); + return GetWorkBudget.builder() + .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING)) + .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING)) + .build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java index 43c0d46139da..2013c9ff1cb7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java @@ -17,13 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; @Internal public final class GetWorkBudgetDistributors { - public static GetWorkBudgetDistributor distributeEvenly( - Supplier activeWorkBudgetSupplier) { - return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier); + public static GetWorkBudgetDistributor distributeEvenly() { + return new EvenGetWorkBudgetDistributor(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java index 254b2589062e..decf101a641b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java @@ -22,11 +22,9 @@ * org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget} */ public interface GetWorkBudgetSpender { - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(long items, long bytes); - default void adjustBudget(GetWorkBudget adjustment) { - adjustBudget(adjustment.items(), adjustment.bytes()); + default void setBudget(GetWorkBudget budget) { + setBudget(budget.items(), budget.bytes()); } - - GetWorkBudget remainingBudget(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index b3f7467cdbd3..90ffb3d3fbcf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -245,18 +245,10 @@ public void halfClose() { } @Override - public void 
adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op. } - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setItems(request.getMaxItems()) - .setBytes(request.getMaxBytes()) - .build(); - } - @Override public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { while (done.getCount() > 0) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index ed8815c48e76..0092fcc7bcd1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -30,9 +30,7 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashSet; -import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; @@ -46,7 +44,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; @@ -71,7 +68,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.junit.After; import org.junit.Before; @@ -92,7 +88,6 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString()) .build()); - private static final long CLIENT_ID = 1L; private static final String JOB_ID = "jobId"; private static final String PROJECT_ID = "projectId"; private static final String WORKER_ID = "workerId"; @@ -101,6 +96,7 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) + .setClientId(1L) .build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @@ -134,7 +130,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) - .setClientId(CLIENT_ID) + .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) .build(); @@ -174,7 +170,7 @@ public void cleanUp() { stubFactory.shutdown(); } - private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( + private 
FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( GetWorkBudget getWorkBudget, GetWorkBudgetDistributor getWorkBudgetDistributor, WorkItemScheduler workItemScheduler) { @@ -186,7 +182,6 @@ private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( stubFactory, getWorkBudgetDistributor, dispatcherClient, - CLIENT_ID, ignored -> mock(WorkCommitter.class), new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); } @@ -201,7 +196,7 @@ public void testStreamsStartCorrectly() throws InterruptedException { spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(items).setBytes(bytes).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -219,16 +214,14 @@ public void testStreamsStartCorrectly() throws InterruptedException { getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); - assertEquals(2, currentConnections.windmillConnections().size()); - assertEquals(2, currentConnections.windmillStreams().size()); + assertEquals(2, currentBackends.windmillStreams().size()); Set workerTokens = - currentConnections.windmillConnections().values().stream() - .map(WindmillConnection::backendWorkerToken) + currentBackends.windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertTrue(workerTokens.contains(workerToken)); @@ -252,27 +245,6 @@ public void testStreamsStartCorrectly() throws InterruptedException { verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); } - @Test - public void testScheduledBudgetRefresh() throws InterruptedException { - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(2)); - fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( - GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), - getWorkBudgetDistributor, - noOpProcessWorkItemFn()); - - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata( - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(1) - .addWorkEndpoints(metadataResponseEndpoint("workerToken")) - .putAllGlobalDataEndpoints(DEFAULT) - .build()); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any()); - } - @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { @@ -280,7 +252,7 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor(metadataCount)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -309,32 +281,28 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() WorkerMetadataResponse.Endpoint.newBuilder() 
.setBackendWorkerToken(workerToken3) .build()) - .putAllGlobalDataEndpoints(DEFAULT) .build(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); - assertEquals(1, currentConnections.windmillConnections().size()); - assertEquals(1, currentConnections.windmillStreams().size()); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); + assertEquals(1, currentBackends.windmillStreams().size()); Set workerTokens = - fanOutStreamingEngineWorkProvider.getCurrentConnections().windmillConnections().values() - .stream() - .map(WindmillConnection::backendWorkerToken) + fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertFalse(workerTokens.contains(workerToken)); assertFalse(workerTokens.contains(workerToken2)); + assertTrue(currentBackends.globalDataStreams().isEmpty()); } @Test public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - String workerToken3 = "workerToken3"; WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() @@ -354,42 +322,24 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); - WorkerMetadataResponse thirdWorkerMetadata = - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(3) - .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder() - .setBackendWorkerToken(workerToken3) - .build()) - .putAllGlobalDataEndpoints(DEFAULT) - .build(); - - List workerMetadataResponses = - Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, thirdWorkerMetadata); TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size())); + spy(new TestGetWorkBudgetDistributor(1)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); getWorkerMetadataReady.await(); - // Make sure we are injecting the metadata from smallest to largest. 
- workerMetadataResponses.stream() - .sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion)) - .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata); - - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size())) - .distributeBudget(any(), any()); - } + fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + getWorkBudgetDistributor.expectNumDistributions(1); + fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - private void waitForWorkerMetadataToBeConsumed( - TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException { - getWorkBudgetDistributor.waitForBudgetDistribution(); + verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } private static class GetWorkerMetadataTestStub @@ -434,21 +384,24 @@ private void injectWorkerMetadata(WorkerMetadataResponse response) { } private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { - private final CountDownLatch getWorkBudgetDistributorTriggered; + private CountDownLatch getWorkBudgetDistributorTriggered; private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } - @SuppressWarnings("ReturnValueIgnored") - private void waitForBudgetDistribution() throws InterruptedException { - getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + private boolean waitForBudgetDistribution() throws InterruptedException { + return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + } + + private void expectNumDistributions(int numBudgetDistributionsExpected) { + this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } @Override public void distributeBudget( ImmutableCollection streams, GetWorkBudget getWorkBudget) { - streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes())); + streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes())); getWorkBudgetDistributorTriggered.countDown(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index dc6cc5641055..32d1f5738086 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -193,7 +193,7 @@ public void testCloseAllStreams_doesNotCloseUnstartedStreams() { WindmillStreamSender windmillStreamSender = newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verifyNoInteractions(streamFactory); } @@ -230,7 +230,7 @@ public void testCloseAllStreams_closesAllStreams() { mockStreamFactory); windmillStreamSender.startStreams(); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verify(mockGetWorkStream).shutdown(); 
verify(mockGetDataStream).shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java new file mode 100644 index 000000000000..fd2b30238836 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import 
org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER = + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setClientId(1L) + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + private static void assertHeader( + Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget expectedInitialBudget) { + assertTrue(getWorkRequest.hasRequest()); + assertFalse(getWorkRequest.hasRequestExtension()); + assertThat(getWorkRequest.getRequest()) + .isEqualTo( + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(expectedInitialBudget.items()) + .setMaxBytes(expectedInitialBudget.bytes()) + .build()); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + 
Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget()); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Header and extension. + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(1); + assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget()); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + Set scheduledWorkItems = new HashSet<>(); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> + scheduledWorkItems.add(work)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + 
assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER); + stream.startThrottleTimer(); + assertTrue(throttleTimer.throttled()); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); + assertFalse(throttleTimer.throttled()); + } + + private static class GetWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetWorkRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) { + checkNotNull(responseObserver).onNext(responseChunk); + } + } + + private static class TestGetWorkRequestObserver + implements StreamObserver { + private final List requests = + Collections.synchronizedList(new ArrayList<>()); + private final CountDownLatch waitForRequests; + private @Nullable volatile StreamObserver + responseObserver; + + public TestGetWorkRequestObserver(CountDownLatch waitForRequests) { + this.waitForRequests = waitForRequests; + } + + @Override + public void onNext(Windmill.StreamingGetWorkRequest request) { + requests.add(request); + waitForRequests.countDown(); + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + + List sent() { + return requests; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java index 3cda4559c100..c76d5a584184 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -17,9 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -40,169 +38,79 @@ public class 
EvenGetWorkBudgetDistributorTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { - return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); - } + private static GetWorkBudgetSpender createGetWorkBudgetOwner() { + // Lambdas are final and cannot be spied. + return spy( + new GetWorkBudgetSpender() { - private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { - return createBudgetDistributor( - GetWorkBudget.builder() - .setItems(activeWorkItemsAndBytes) - .setBytes(activeWorkItemsAndBytes) - .build()); + @Override + public void setBudget(long items, long bytes) {} + }); } @Test public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { - createBudgetDistributor(1L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); } @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())); - createBudgetDistributor(1L) + GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner(); + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget(ImmutableList.of(getWorkBudgetSpender), GetWorkBudget.noBudget()); verifyNoInteractions(getWorkBudgetSpender); } @Test - public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(10L).setBytes(10L).build())); - createBudgetDistributor(0L) - .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(5L).setBytes(5L).build())); - createBudgetDistributor(10L) + public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { + int totalStreams = 10; + long totalItems = 10L; + long totalBytes = 100L; + List streams = new ArrayList<>(); + for (int i = 0; i < totalStreams; i++) { + streams.add(createGetWorkBudgetOwner()); + } + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(20L).setBytes(20L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - 
eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq( - totalGetWorkBudget.items() - - streamRemainingBudget.items() - - activeWorkItemsAndBytes), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); + ImmutableList.copyOf(streams), + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq( - totalGetWorkBudget.bytes() - - streamRemainingBudget.bytes() - - activeWorkItemsAndBytes)); + streams.forEach( + stream -> + verify(stream, times(1)) + .setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build()))); } @Test - public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { - long totalItemsAndBytes = 10L; + public void testDistributeBudget_distributesFairlyWhenNotEven() { + long totalItems = 10L; + long totalBytes = 19L; List streams = new ArrayList<>(); - for (int i = 0; i < totalItemsAndBytes; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < 3; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), - GetWorkBudget.builder() - .setItems(totalItemsAndBytes) - .setBytes(totalItemsAndBytes) - .build()); + 
GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build());
 
-    long itemsAndBytesPerStream = totalItemsAndBytes / streams.size();
     streams.forEach(
         stream ->
             verify(stream, times(1))
-                .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream)));
+                .setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build())));
   }
 
   @Test
-  public void testDistributeBudget_distributesFairlyWhenNotEven() {
+  public void testDistributeBudget_distributesBudgetEvenly() {
     long totalItemsAndBytes = 10L;
     List<GetWorkBudgetSpender> streams = new ArrayList<>();
-    for (int i = 0; i < 3; i++) {
-      streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())));
+    for (int i = 0; i < totalItemsAndBytes; i++) {
+      streams.add(createGetWorkBudgetOwner());
     }
-    createBudgetDistributor(0L)
+
+    GetWorkBudgetDistributors.distributeEvenly()
         .distributeBudget(
             ImmutableList.copyOf(streams),
             GetWorkBudget.builder()
@@ -210,24 +118,10 @@ public void testDistributeBudget_distributesFairlyWhenNotEven() {
                 .setBytes(totalItemsAndBytes)
                 .build());
 
-    long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0));
+    long itemsAndBytesPerStream = totalItemsAndBytes / streams.size();
     streams.forEach(
         stream ->
             verify(stream, times(1))
-                .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream)));
-  }
-
-  private GetWorkBudgetSpender createGetWorkBudgetOwnerWithRemainingBudgetOf(
-      GetWorkBudget getWorkBudget) {
-    return spy(
-        new GetWorkBudgetSpender() {
-          @Override
-          public void adjustBudget(long itemsDelta, long bytesDelta) {}
-
-          @Override
-          public GetWorkBudget remainingBudget() {
-            return getWorkBudget;
-          }
-        });
+                .setBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream)));
   }
 }
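
Reviewer note (not part of the patch): the simplified distributor no longer inspects each
stream's remaining budget; it splits the total GetWorkBudget across the active streams with
ceiling division and calls setBudget() on every stream. The expectations in
EvenGetWorkBudgetDistributorTest follow directly from that arithmetic (10 items and 19 bytes over
3 streams becomes 4 items and 7 bytes per stream). A minimal standalone sketch of the same split,
using non-vendored Guava for brevity:

    import java.math.RoundingMode;
    import com.google.common.math.LongMath;

    /** Sketch only: mirrors the per-stream split in EvenGetWorkBudgetDistributor. */
    public final class EvenSplitSketch {
      public static void main(String[] args) {
        long totalItems = 10L;
        long totalBytes = 19L;
        int numStreams = 3;
        // Ceiling division, as in computeDesiredPerStreamBudget.
        long itemsPerStream = LongMath.divide(totalItems, numStreams, RoundingMode.CEILING); // 4
        long bytesPerStream = LongMath.divide(totalBytes, numStreams, RoundingMode.CEILING); // 7
        System.out.println(itemsPerStream + " items / " + bytesPerStream + " bytes per stream");
      }
    }

Rounding up means the budgets handed out can sum to slightly more than the total (3 * 4 = 12
items here). Because setBudget() replaces a stream's target outright instead of accumulating
adjustBudget() deltas, that over-allocation no longer compounds over a stream's lifetime, which
is presumably why the old TODO about upward drift could be dropped along with adjustBudget().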
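
Related reviewer note: with remainingBudget() gone from GetWorkBudgetSpender, the direct GetWork
stream does its own bookkeeping of budget requested versus budget received (the
itemsRequested/itemsReceived and bytesRequested/bytesReceived counters visible earlier in this
diff), and a setBudget() call or a consumed work item only produces a
StreamingGetWorkRequestExtension for the positive difference between the new target and what is
still outstanding, as GrpcDirectGetWorkStreamTest asserts. The sketch below reproduces that
arithmetic; the class and method names are illustrative assumptions, not the stream's actual
internals.

    /** Sketch only: budget bookkeeping consistent with GrpcDirectGetWorkStreamTest. */
    final class GetWorkBudgetTrackerSketch {
      private long itemsRequested;
      private long bytesRequested;
      private long itemsReceived;
      private long bytesReceived;

      /** Records the budget sent on the initial GetWorkRequest header. */
      synchronized void recordHeader(long maxItems, long maxBytes) {
        itemsRequested += maxItems;
        bytesRequested += maxBytes;
      }

      /** Records a work item pulled off the stream. */
      synchronized void recordReceived(long serializedWorkItemBytes) {
        itemsReceived++;
        bytesReceived += serializedWorkItemBytes;
      }

      /**
       * Returns {extraItems, extraBytes} to put on the next request extension so the outstanding
       * budget (requested - received) reaches the target; {0, 0} means no extension is sent.
       */
      synchronized long[] extensionFor(long targetItems, long targetBytes) {
        long outstandingItems = itemsRequested - itemsReceived;
        long outstandingBytes = bytesRequested - bytesReceived;
        long extraItems = Math.max(0, targetItems - outstandingItems);
        long extraBytes = Math.max(0, targetBytes - outstandingBytes);
        if (extraItems == 0 && extraBytes == 0) {
          return new long[] {0, 0}; // Outstanding budget already covers the target.
        }
        itemsRequested += extraItems;
        bytesRequested += extraBytes;
        return new long[] {extraItems, extraBytes};
      }
    }

Plugging in the test scenarios: a 10 item / 10 byte header followed by setBudget(100, 100) yields
a 90 / 90 extension (newBudget.subtract(initialBudget)); a Long.MAX_VALUE header yields no
extension for any smaller target; and after one work item is received against a 1 item / 100 byte
header, re-evaluating the same target requests 1 item and exactly the work item's serialized size
in bytes back.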