Skip to content

Commit

Permalink
SAMZA-2796: Introduce config knob for framework thread sub DAG execut…
Browse files Browse the repository at this point in the history
…ion (#1691)

Description
As part of SAMZA-2781, we use the framework thread pool for message hand-offs and sub-DAG execution. We want to add a config knob that lets users opt in to the feature, as opposed to enabling it by default.

Changes
Introduce config knob to use the framework executor

Tests
Added unit tests

Usage Instructions
Refer to the configuration documentation. To enable the framework thread pool for sub-DAG execution and message hand-off, set job.operator.framework.executor.enabled to true.
  • Loading branch information
mynameborat authored Nov 21, 2023
1 parent 65f31eb commit 66495b6
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 53 deletions.
10 changes: 10 additions & 0 deletions docs/learn/documentation/versioned/jobs/configuration-table.html
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ <h1>Samza Configuration Reference</h1>
</td>
</tr>

<tr>
<td class="property" id="job.operator.framework.executor.enabled">job.operator.framework.executor.enabled</td>
<td class="default">false</td>
<td class="description">
If enabled, the framework thread pool will be used for message hand-off and sub-DAG execution. Otherwise,
execution falls back to the caller thread or the Java fork-join pool, depending on the type of work
chained as part of the message hand-off.
</td>
</tr>

<tr>
<!-- change link to StandAlone design/tutorial doc. SAMZA-1299 -->
<th colspan="3" class="section" id="ZkBasedJobCoordination"><a href="../index.html">Zookeeper-based job configuration</a></th>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ public class JobConfig extends MapConfig {
public static final String JOB_ELASTICITY_FACTOR = "job.elasticity.factor";
public static final int DEFAULT_JOB_ELASTICITY_FACTOR = 1;

/**
 * Config key that opts a job in to using the framework thread pool for message hand-off and
 * sub-DAG execution.
 */
public static final String JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = "job.operator.framework.executor.enabled";

/**
 * Default for {@link #JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED}: the feature stays off unless it is
 * explicitly enabled.
 */
public static final boolean DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = false;

/**
 * Creates a job-scoped view over the supplied configuration.
 *
 * @param config configuration to wrap; it is handed unchanged to the {@link MapConfig} constructor
 */
public JobConfig(Config config) {
super(config);
}
Expand Down Expand Up @@ -528,4 +532,8 @@ public int getElasticityFactor() {
/**
 * Looks up the coordinator execute command.
 *
 * @return the value configured under {@code COORDINATOR_EXECUTE_COMMAND}, or
 *         {@code DEFAULT_COORDINATOR_EXECUTE_COMMAND} when the key is not set
 */
public String getCoordinatorExecuteCommand() {
  final String executeCommand = get(COORDINATOR_EXECUTE_COMMAND, DEFAULT_COORDINATOR_EXECUTE_COMMAND);
  return executeCommand;
}

/**
 * Tells whether the job has opted in to the framework thread pool for message hand-off and
 * sub-DAG execution.
 *
 * @return the boolean configured under {@code job.operator.framework.executor.enabled}, or
 *         {@code false} (the default) when the key is not set
 */
public boolean getOperatorFrameworkExecutorEnabled() {
  final boolean frameworkExecutorEnabled =
      getBoolean(JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED, DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED);
  return frameworkExecutorEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
Expand Down Expand Up @@ -95,6 +97,7 @@ public abstract class OperatorImpl<M, RM> {
private ControlMessageSender controlMessageSender;
private int elasticityFactor;
private ExecutorService operatorExecutor;
private boolean operatorExecutorEnabled;

/**
* Initialize this {@link OperatorImpl} and its user-defined functions.
Expand Down Expand Up @@ -136,7 +139,9 @@ public final void init(InternalTaskContext internalTaskContext) {
this.taskModel = taskContext.getTaskModel();
this.callbackScheduler = taskContext.getCallbackScheduler();
handleInit(context);
this.elasticityFactor = new JobConfig(config).getElasticityFactor();
JobConfig jobConfig = new JobConfig(config);
this.elasticityFactor = jobConfig.getElasticityFactor();
this.operatorExecutorEnabled = jobConfig.getOperatorFrameworkExecutorEnabled();
this.operatorExecutor = context.getTaskContext().getOperatorExecutor();

initialized = true;
Expand Down Expand Up @@ -192,21 +197,20 @@ public final CompletionStage<Void> onMessageAsync(M message, MessageCollector co
getOpImplId(), getOperatorSpec().getSourceLocation(), expectedType, actualType), e);
}

CompletionStage<Void> result = completableResultsFuture.thenComposeAsync(results -> {
CompletionStage<Void> result = composeFutureWithExecutor(completableResultsFuture, results -> {
long endNs = this.highResClock.nanoTime();
this.handleMessageNs.update(endNs - startNs);

return CompletableFuture.allOf(results.stream()
.flatMap(r -> this.registeredOperators.stream()
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.flatMap(r -> this.registeredOperators.stream().map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));
}, operatorExecutor);
});

WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn();
if (watermarkFn != null) {
// check whether there is new watermark emitted from the user function
Long outputWm = watermarkFn.getOutputWatermark();
return result.thenComposeAsync(ignored -> propagateWatermark(outputWm, collector, coordinator), operatorExecutor);
return composeFutureWithExecutor(result, ignored -> propagateWatermark(outputWm, collector, coordinator));
}

return result;
Expand Down Expand Up @@ -245,11 +249,9 @@ public final CompletionStage<Void> onTimer(MessageCollector collector, TaskCoord
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));

return resultFuture.thenComposeAsync(x ->
CompletableFuture.allOf(this.registeredOperators
.stream()
.map(op -> op.onTimer(collector, coordinator))
.toArray(CompletableFuture[]::new)), operatorExecutor);
return composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(this.registeredOperators.stream()
.map(op -> op.onTimer(collector, coordinator))
.toArray(CompletableFuture[]::new)));
}

/**
Expand Down Expand Up @@ -315,15 +317,14 @@ public final CompletionStage<Void> aggregateEndOfStream(EndOfStreamMessage eos,
}

// populate the end-of-stream through the dag
endOfStreamFuture = onEndOfStream(collector, coordinator)
.thenAcceptAsync(result -> {
if (eosStates.allEndOfStream()) {
// all inputs have been end-of-stream, shut down the task
LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
}, operatorExecutor);
endOfStreamFuture = acceptFutureWithExecutor(onEndOfStream(collector, coordinator), result -> {
if (eosStates.allEndOfStream()) {
// all inputs have been end-of-stream, shut down the task
LOG.info("All input streams have reached the end for task {}", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
});
}

return endOfStreamFuture;
Expand All @@ -347,10 +348,10 @@ private CompletionStage<Void> onEndOfStream(MessageCollector collector, TaskCoor
.map(op -> op.onMessageAsync(r, collector, coordinator)))
.toArray(CompletableFuture[]::new));

endOfStreamFuture = resultFuture.thenComposeAsync(x ->
CompletableFuture.allOf(this.registeredOperators.stream()
endOfStreamFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
this.registeredOperators.stream()
.map(op -> op.onEndOfStream(collector, coordinator))
.toArray(CompletableFuture[]::new)), operatorExecutor);
.toArray(CompletableFuture[]::new)));
}

return endOfStreamFuture;
Expand Down Expand Up @@ -406,15 +407,14 @@ public final CompletionStage<Void> aggregateDrainMessages(DrainMessage drainMess
controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector);
}

drainFuture = onDrainOfStream(collector, coordinator)
.thenAcceptAsync(result -> {
if (drainStates.areAllStreamsDrained()) {
// All input streams have been drained, shut down the task
LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
}, operatorExecutor);
drainFuture = acceptFutureWithExecutor(onDrainOfStream(collector, coordinator), result -> {
if (drainStates.areAllStreamsDrained()) {
// All input streams have been drained, shut down the task
LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
}
});
}

return drainFuture;
Expand All @@ -439,10 +439,10 @@ private CompletionStage<Void> onDrainOfStream(MessageCollector collector, TaskCo
.toArray(CompletableFuture[]::new));

// propagate DrainMessage to downstream operators
drainFuture = resultFuture.thenComposeAsync(x ->
CompletableFuture.allOf(this.registeredOperators.stream()
drainFuture = composeFutureWithExecutor(resultFuture, x -> CompletableFuture.allOf(
this.registeredOperators.stream()
.map(op -> op.onDrainOfStream(collector, coordinator))
.toArray(CompletableFuture[]::new)), operatorExecutor);
.toArray(CompletableFuture[]::new)));
}

return drainFuture;
Expand Down Expand Up @@ -474,8 +474,8 @@ public final CompletionStage<Void> aggregateWatermark(WatermarkMessage watermark
controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector);
}
// populate the watermark through the dag
watermarkFuture = onWatermark(watermark, collector, coordinator)
.thenAcceptAsync(ignored -> watermarkStates.updateAggregateMetric(ssp, watermark), operatorExecutor);
watermarkFuture = acceptFutureWithExecutor(onWatermark(watermark, collector, coordinator),
ignored -> watermarkStates.updateAggregateMetric(ssp, watermark));
}

return watermarkFuture;
Expand Down Expand Up @@ -530,8 +530,8 @@ private CompletionStage<Void> onWatermark(long watermark, MessageCollector colle
.toArray(CompletableFuture[]::new));
}

watermarkFuture = watermarkFuture.thenComposeAsync(res -> propagateWatermark(outputWm, collector, coordinator),
operatorExecutor);
watermarkFuture =
composeFutureWithExecutor(watermarkFuture, res -> propagateWatermark(outputWm, collector, coordinator));
}

return watermarkFuture;
Expand Down Expand Up @@ -679,6 +679,20 @@ final Collection<RM> handleMessage(M message, MessageCollector collector, TaskCo
.toCompletableFuture().join();
}

/**
 * Chains {@code fn} onto {@code futureToChain}, honoring the framework-executor opt-in.
 *
 * <p>When the operator executor is enabled, the continuation is submitted to {@code operatorExecutor}
 * through {@code thenComposeAsync}. Otherwise it is attached with plain {@code thenCompose}, so no
 * executor is supplied by this class.
 *
 * @param futureToChain the stage whose result feeds {@code fn}
 * @param fn the continuation that produces the next stage
 * @param <T> result type of {@code futureToChain}
 * @param <U> result type of the stage returned by {@code fn}
 * @return the composed stage
 */
@VisibleForTesting
final <T, U> CompletionStage<U> composeFutureWithExecutor(CompletionStage<T> futureToChain,
    Function<? super T, ? extends CompletionStage<U>> fn) {
  if (!operatorExecutorEnabled) {
    // Feature is off: do not route the continuation through the framework executor.
    return futureToChain.thenCompose(fn);
  }

  return futureToChain.thenComposeAsync(fn, operatorExecutor);
}

/**
 * Attaches {@code consumer} to {@code futureToChain}, honoring the framework-executor opt-in.
 *
 * <p>When the operator executor is enabled, the consumer is submitted to {@code operatorExecutor}
 * through {@code thenAcceptAsync}. Otherwise it is attached with plain {@code thenAccept}, so no
 * executor is supplied by this class.
 *
 * @param futureToChain the stage whose result is handed to {@code consumer}
 * @param consumer the action to run with the result of {@code futureToChain}
 * @param <T> result type of {@code futureToChain}
 * @return a stage that completes once {@code consumer} has run
 */
@VisibleForTesting
final <T> CompletionStage<Void> acceptFutureWithExecutor(CompletionStage<T> futureToChain,
    Consumer<? super T> consumer) {
  if (!operatorExecutorEnabled) {
    // Feature is off: do not route the consumer through the framework executor.
    return futureToChain.thenAccept(consumer);
  }

  return futureToChain.thenAcceptAsync(consumer, operatorExecutor);
}

private HighResolutionClock createHighResClock(Config config) {
MetricsConfig metricsConfig = new MetricsConfig(config);
// The timer metrics calculation here is only enabled for debugging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,13 @@ class TaskInstance(
val jobConfig = new JobConfig(jobContext.getConfig)
val taskExecutorFactory = ReflectionUtil.getObj(jobConfig.getTaskExecutorFactory, classOf[TaskExecutorFactory])

var operatorExecutor = Option.empty[java.util.concurrent.ExecutorService].orNull
if (jobConfig.getOperatorFrameworkExecutorEnabled) {
operatorExecutor = taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig)
}
new TaskContextImpl(taskModel, metrics.registry, kvStoreSupplier, tableManager,
new CallbackSchedulerImpl(epochTimeScheduler), offsetManager, jobModel, streamMetadataCache,
systemStreamPartitions, taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig))
systemStreamPartitions, operatorExecutor)
}
// need separate field for this instead of using it through Context, since Context throws an exception if it is null
private val applicationTaskContextOption = applicationTaskContextFactoryOption
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
*/
package org.apache.samza.operators.impl;

import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.samza.config.Config;
import org.apache.samza.config.MapConfig;
import org.apache.samza.context.ContainerContext;
import org.apache.samza.context.Context;
import org.apache.samza.context.InternalTaskContext;
import org.apache.samza.context.MockContext;
import org.apache.samza.context.JobContext;
import org.apache.samza.context.TaskContext;
import org.apache.samza.job.model.TaskModel;
import org.apache.samza.metrics.Counter;
import org.apache.samza.metrics.MetricsRegistryMap;
Expand All @@ -44,33 +52,111 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;


public class TestOperatorImpl {
private Context context;
private InternalTaskContext internalTaskContext;

private JobContext jobContext;

private TaskContext taskContext;

private ContainerContext containerContext;

@Before
public void setup() {
this.context = new MockContext();
this.context = mock(Context.class);
this.internalTaskContext = mock(InternalTaskContext.class);
this.jobContext = mock(JobContext.class);
this.taskContext = mock(TaskContext.class);
this.containerContext = mock(ContainerContext.class);
when(this.internalTaskContext.getContext()).thenReturn(this.context);
// might be necessary in the future
when(this.internalTaskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(mock(EndOfStreamStates.class));
when(this.internalTaskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
when(this.context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));
when(this.context.getTaskContext().getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
when(this.context.getJobContext()).thenReturn(jobContext);
when(this.context.getTaskContext()).thenReturn(taskContext);
when(this.taskContext.getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
when(this.taskContext.getTaskModel()).thenReturn(mock(TaskModel.class));
when(this.taskContext.getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
when(this.context.getContainerContext()).thenReturn(containerContext);
when(containerContext.getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());
}

/**
 * With the knob set to "true", composition must go through {@code thenComposeAsync} using the
 * executor exposed by the task context.
 */
@Test
public void testComposeFutureWithExecutorWithFrameworkExecutorEnabled() {
  OperatorImpl<Object, Object> operator = new TestOpImpl(mock(Object.class));

  // Opt in to the framework executor and expose a mock executor through the task context.
  ExecutorService frameworkExecutor = mock(ExecutorService.class);
  when(this.jobContext.getConfig()).thenReturn(
      new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true")));
  when(this.taskContext.getOperatorExecutor()).thenReturn(frameworkExecutor);

  CompletionStage<Object> upstream = mock(CompletionStage.class);
  Function<Object, CompletionStage<Object>> continuation = mock(Function.class);

  operator.init(this.internalTaskContext);
  operator.composeFutureWithExecutor(upstream, continuation);

  verify(upstream).thenComposeAsync(eq(continuation), eq(frameworkExecutor));
}

/**
 * With the knob set to "false", composition must use plain {@code thenCompose}, even though an
 * executor is available on the task context.
 */
@Test
public void testComposeFutureWithExecutorWithFrameworkExecutorDisabled() {
  OperatorImpl<Object, Object> operator = new TestOpImpl(mock(Object.class));

  // Explicitly opt out; the executor is still stubbed so only the flag decides the path taken.
  ExecutorService frameworkExecutor = mock(ExecutorService.class);
  when(this.jobContext.getConfig()).thenReturn(
      new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false")));
  when(this.taskContext.getOperatorExecutor()).thenReturn(frameworkExecutor);

  CompletionStage<Object> upstream = mock(CompletionStage.class);
  Function<Object, CompletionStage<Object>> continuation = mock(Function.class);

  operator.init(this.internalTaskContext);
  operator.composeFutureWithExecutor(upstream, continuation);

  verify(upstream).thenCompose(eq(continuation));
}

/**
 * With the knob set to "false", the consumer must be attached with plain {@code thenAccept}, even
 * though an executor is available on the task context.
 */
@Test
public void testAcceptFutureWithExecutorWithFrameworkExecutorDisabled() {
  OperatorImpl<Object, Object> operator = new TestOpImpl(mock(Object.class));

  // Explicitly opt out; the executor is still stubbed so only the flag decides the path taken.
  ExecutorService frameworkExecutor = mock(ExecutorService.class);
  when(this.jobContext.getConfig()).thenReturn(
      new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false")));
  when(this.taskContext.getOperatorExecutor()).thenReturn(frameworkExecutor);

  CompletionStage<Object> upstream = mock(CompletionStage.class);
  Consumer<Object> sink = mock(Consumer.class);

  operator.init(this.internalTaskContext);
  operator.acceptFutureWithExecutor(upstream, sink);

  verify(upstream).thenAccept(eq(sink));
}

/**
 * With the knob set to "true", the consumer must be attached through {@code thenAcceptAsync} using
 * the executor exposed by the task context.
 */
@Test
public void testAcceptFutureWithExecutorWithFrameworkExecutorEnabled() {
  OperatorImpl<Object, Object> operator = new TestOpImpl(mock(Object.class));

  // Opt in to the framework executor and expose a mock executor through the task context.
  ExecutorService frameworkExecutor = mock(ExecutorService.class);
  when(this.jobContext.getConfig()).thenReturn(
      new MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true")));
  when(this.taskContext.getOperatorExecutor()).thenReturn(frameworkExecutor);

  CompletionStage<Object> upstream = mock(CompletionStage.class);
  Consumer<Object> sink = mock(Consumer.class);

  operator.init(this.internalTaskContext);
  operator.acceptFutureWithExecutor(upstream, sink);

  verify(upstream).thenAcceptAsync(eq(sink), eq(frameworkExecutor));
}
@Test(expected = IllegalStateException.class)
public void testMultipleInitShouldThrow() {
OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
Expand Down

0 comments on commit 66495b6

Please sign in to comment.