
Added support for SparkRunner streaming stateful processing (#33267)
twosom authored Dec 18, 2024
1 parent e68a79c commit f476417
Showing 19 changed files with 886 additions and 60 deletions.
@@ -3,5 +3,6 @@
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test",
"https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test"
}
@@ -2,5 +2,6 @@
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test"
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test"
}
@@ -3,5 +3,6 @@
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test",
"https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test"
}
3 changes: 2 additions & 1 deletion CHANGES.md
@@ -71,11 +71,12 @@

## New Features / Improvements

* Added support for stateful processing in Spark Runner for streaming pipelines. Timer functionality is not yet supported and will be implemented in a future release ([#33237](https://github.com/apache/beam/issues/33237)).
* Improved batch performance of SparkRunner's GroupByKey ([#20943](https://github.com/apache/beam/pull/20943)).
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)).
* This enables initial Java GroupIntoBatches support.
* Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)).
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Breaking Changes

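For context on the new CHANGES.md entry, here is a minimal sketch (not part of this commit; the class and state names are hypothetical) of the kind of stateful DoFn the SparkRunner can now translate for streaming pipelines:

import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.values.KV;

// Hypothetical stateful DoFn: keeps a per-key running count in ValueState.
// With this commit the SparkRunner can execute it in streaming (DStream) mode;
// declaring a timer would still be rejected (see rejectTimers further down).
class RunningCountFn extends DoFn<KV<String, Long>, KV<String, Long>> {

  @StateId("count")
  private final StateSpec<ValueState<Long>> countSpec = StateSpecs.value();

  @ProcessElement
  public void process(
      @Element KV<String, Long> element,
      @StateId("count") ValueState<Long> count,
      OutputReceiver<KV<String, Long>> out) {
    Long current = count.read();
    long updated = (current == null ? 0L : current) + 1;
    count.write(updated);
    out.output(KV.of(element.getKey(), updated));
  }
}

This is also why the gradle change below can drop the UsesStatefulParDo exclusion from the streaming ValidatesRunner suite.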
2 changes: 1 addition & 1 deletion runners/spark/spark_runner.gradle
@@ -345,7 +345,7 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test)
excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'

// State and Timers
excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap'
excludeCategories 'org.apache.beam.sdk.testing.UsesLoopingTimer'
@@ -21,7 +21,7 @@
import java.util.ArrayList;
import java.util.LinkedHashMap;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet.StateAndTimers;
import org.apache.beam.runners.spark.stateful.StateAndTimers;
import org.apache.beam.runners.spark.translation.ValueAndCoderKryoSerializer;
import org.apache.beam.runners.spark.translation.ValueAndCoderLazySerializable;
import org.apache.beam.runners.spark.util.ByteArray;
@@ -17,6 +17,9 @@
*/
package org.apache.beam.runners.spark.stateful;

import static org.apache.beam.runners.spark.translation.TranslationUtils.checkpointIfNeeded;
import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
@@ -35,7 +38,6 @@
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
import org.apache.beam.runners.core.triggers.TriggerStateMachines;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.translation.ReifyTimestampsAndWindowsFunction;
import org.apache.beam.runners.spark.translation.TranslationUtils;
@@ -60,10 +62,8 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.dstream.DStream;
@@ -100,27 +100,6 @@ public class SparkGroupAlsoByWindowViaWindowSet implements Serializable {
private static final Logger LOG =
LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class);

/** State and Timers wrapper. */
public static class StateAndTimers implements Serializable {
// Serializable state for internals (namespace to state tag to coded value).
private final Table<String, String, byte[]> state;
private final Collection<byte[]> serTimers;

private StateAndTimers(
final Table<String, String, byte[]> state, final Collection<byte[]> timers) {
this.state = state;
this.serTimers = timers;
}

Table<String, String, byte[]> getState() {
return state;
}

Collection<byte[]> getTimers() {
return serTimers;
}
}

private static class OutputWindowedValueHolder<K, V>
implements OutputWindowedValue<KV<K, Iterable<V>>> {
private final List<WindowedValue<KV<K, Iterable<V>>>> windowedValues = new ArrayList<>();
@@ -348,7 +327,7 @@ private Collection<TimerInternals.TimerData> filterTimersEligibleForProcessing(

// empty outputs are filtered later using DStream filtering
final StateAndTimers updated =
new StateAndTimers(
StateAndTimers.of(
stateInternals.getState(),
SparkTimerInternals.serializeTimers(
timerInternals.getTimers(), timerDataCoder));
@@ -466,21 +445,6 @@ private static <W extends BoundedWindow> TimerInternals.TimerDataCoderV2 timerDa
return TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder());
}

private static void checkpointIfNeeded(
final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream,
final SerializablePipelineOptions options) {

final Long checkpointDurationMillis = getBatchDuration(options);

if (checkpointDurationMillis > 0) {
firedStream.checkpoint(new Duration(checkpointDurationMillis));
}
}

private static Long getBatchDuration(final SerializablePipelineOptions options) {
return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
}

private static <K, InputT> JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> stripStateValues(
final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream,
final Coder<K> keyCoder,
@@ -63,7 +63,7 @@
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
class SparkStateInternals<K> implements StateInternals {
public class SparkStateInternals<K> implements StateInternals {

private final K key;
// Serializable state for internals (namespace to state tag to coded value).
@@ -79,11 +79,11 @@ private SparkStateInternals(K key, Table<String, String, byte[]> stateTable) {
this.stateTable = stateTable;
}

static <K> SparkStateInternals<K> forKey(K key) {
public static <K> SparkStateInternals<K> forKey(K key) {
return new SparkStateInternals<>(key);
}

static <K> SparkStateInternals<K> forKeyAndState(
public static <K> SparkStateInternals<K> forKeyAndState(
K key, Table<String, String, byte[]> stateTable) {
return new SparkStateInternals<>(key, stateTable);
}
@@ -412,17 +412,25 @@ public void put(MapKeyT key, MapValueT value) {
@Override
public ReadableState<MapValueT> computeIfAbsent(
MapKeyT key, Function<? super MapKeyT, ? extends MapValueT> mappingFunction) {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
Map<MapKeyT, MapValueT> sparkMapState = readAsMap();
MapValueT current = sparkMapState.get(key);
if (current == null) {
put(key, mappingFunction.apply(key));
}
return ReadableStates.immediate(current);
}

private Map<MapKeyT, MapValueT> readAsMap() {
Map<MapKeyT, MapValueT> mapState = readValue();
if (mapState == null) {
mapState = new HashMap<>();
}
return mapState;
}

@Override
public void remove(MapKeyT key) {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
Map<MapKeyT, MapValueT> sparkMapState = readAsMap();
sparkMapState.remove(key);
writeValue(sparkMapState);
}
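The readAsMap() helper added above treats unwritten map state as an empty map, avoiding a NullPointerException when computeIfAbsent or remove runs before any value has been written. A standalone sketch of the guard pattern (assuming, as in the diff, that reading unwritten state yields null):

import java.util.HashMap;
import java.util.Map;

// Standalone illustration of the null guard; not the actual Beam class.
// "stored" stands in for the serialized state cell, null until first write.
final class MapStateNullGuardDemo {
  private static Map<String, byte[]> stored;

  private static Map<String, byte[]> readAsMap() {
    Map<String, byte[]> mapState = stored;
    return mapState == null ? new HashMap<>() : mapState;
  }

  public static void main(String[] args) {
    // Before the fix, removing from unwritten state dereferenced null.
    Map<String, byte[]> state = readAsMap();
    state.remove("absent-key"); // safe: operates on an empty map
    System.out.println("entries=" + state.size()); // prints entries=0
  }
}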
@@ -107,7 +107,7 @@ public Collection<TimerData> getTimers() {
return timers;
}

void addTimers(Iterator<TimerData> timers) {
public void addTimers(Iterator<TimerData> timers) {
while (timers.hasNext()) {
TimerData timer = timers.next();
this.timers.add(timer);
@@ -163,7 +163,8 @@ public void setTimer(
Instant target,
Instant outputTimestamp,
TimeDomain timeDomain) {
throw new UnsupportedOperationException("Setting a timer by ID not yet supported.");
this.setTimer(
TimerData.of(timerId, timerFamilyId, namespace, target, outputTimestamp, timeDomain));
}

@Override
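With setTimer-by-ID now delegating to TimerData.of instead of throwing, callers can register named timers directly. A hedged usage sketch (the helper class, timer ID, and timestamps are illustrative):

import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.sdk.state.TimeDomain;
import org.joda.time.Instant;

// Illustrative helper: registers a processing-time timer by ID on any
// TimerInternals, e.g. a SparkTimerInternals instance after this commit.
final class TimerSketch {
  static void scheduleFlushTimer(TimerInternals timerInternals, Instant now) {
    timerInternals.setTimer(
        StateNamespaces.global(), // namespace (illustrative)
        "flush",                  // timerId (illustrative)
        "",                       // timerFamilyId
        now.plus(60_000L),        // target: one minute from now
        now,                      // output timestamp
        TimeDomain.PROCESSING_TIME);
  }
}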
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.stateful;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.Collection;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;

/** State and Timers wrapper. */
@AutoValue
public abstract class StateAndTimers implements Serializable {
public abstract Table<String, String, byte[]> getState();

public abstract Collection<byte[]> getTimers();

public static StateAndTimers of(
final Table<String, String, byte[]> state, final Collection<byte[]> timers) {
return new AutoValue_StateAndTimers.Builder().setState(state).setTimers(timers).build();
}

@AutoValue.Builder
abstract static class Builder {
abstract Builder setState(Table<String, String, byte[]> state);

abstract Builder setTimers(Collection<byte[]> timers);

abstract StateAndTimers build();
}
}
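A brief usage sketch of the extracted wrapper (the table entry is illustrative). Pulling the class out of SparkGroupAlsoByWindowViaWindowSet lets other components reuse it, and AutoValue supplies the equals/hashCode/toString that the hand-written inner class lacked:

import java.util.Collections;
import org.apache.beam.runners.spark.stateful.StateAndTimers;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;

// Illustrative: build the state/timers snapshot via the AutoValue factory.
final class StateAndTimersSketch {
  static StateAndTimers emptyTimersSnapshot() {
    // namespace -> state tag -> coded value (illustrative entry)
    Table<String, String, byte[]> state = HashBasedTable.create();
    state.put("/window/w1/", "countTag", new byte[] {1});
    return StateAndTimers.of(state, Collections.emptyList());
  }
}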
@@ -31,12 +31,12 @@
import org.joda.time.Instant;

/** DoFnRunner decorator which registers {@link MetricsContainerImpl}. */
class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
public class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
private final DoFnRunner<InputT, OutputT> delegate;
private final String stepName;
private final MetricsContainerStepMapAccumulator metricsAccum;

DoFnRunnerWithMetrics(
public DoFnRunnerWithMetrics(
String stepName,
DoFnRunner<InputT, OutputT> delegate,
MetricsContainerStepMapAccumulator metricsAccum) {
@@ -47,7 +47,7 @@
* Processes Spark's input data iterators using Beam's {@link
* org.apache.beam.runners.core.DoFnRunner}.
*/
interface SparkInputDataProcessor<FnInputT, FnOutputT, OutputT> {
public interface SparkInputDataProcessor<FnInputT, FnOutputT, OutputT> {

/**
* @return {@link OutputManager} to be used by {@link org.apache.beam.runners.core.DoFnRunner} for
@@ -23,14 +23,14 @@
import org.apache.beam.sdk.transforms.DoFn;

/** Holds current processing context for {@link SparkInputDataProcessor}. */
class SparkProcessContext<K, InputT, OutputT> {
public class SparkProcessContext<K, InputT, OutputT> {
private final String stepName;
private final DoFn<InputT, OutputT> doFn;
private final DoFnRunner<InputT, OutputT> doFnRunner;
private final Iterator<TimerInternals.TimerData> timerDataIterator;
private final K key;

SparkProcessContext(
public SparkProcessContext(
String stepName,
DoFn<InputT, OutputT> doFn,
DoFnRunner<InputT, OutputT> doFnRunner,
@@ -26,6 +26,8 @@
import org.apache.beam.runners.core.InMemoryStateInternals;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
@@ -54,8 +56,10 @@
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.dstream.DStream;
import scala.Tuple2;

/** A set of utilities to help translating Beam transformations into Spark transformations. */
@@ -258,6 +262,52 @@ public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) {
}
}

/**
* Retrieves the checkpoint duration in milliseconds from Spark pipeline options.
*
* @param options The serializable pipeline options containing Spark-specific settings
* @return The checkpoint duration in milliseconds as specified in SparkPipelineOptions
*/
public static Long getBatchDuration(final SerializablePipelineOptions options) {
return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
}

/**
* Rejects a {@link DoFn} that declares timers.
*
* @param doFn the {@link DoFn} to possibly reject.
*/
public static void rejectTimers(DoFn<?, ?> doFn) {
DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
if (signature.timerDeclarations().size() > 0
|| signature.timerFamilyDeclarations().size() > 0) {
throw new UnsupportedOperationException(
String.format(
"Found %s annotations on %s, but %s cannot yet be used with timers in the %s.",
DoFn.TimerId.class.getSimpleName(),
doFn.getClass().getName(),
DoFn.class.getSimpleName(),
SparkRunner.class.getSimpleName()));
}
}

/**
* Checkpoints the given DStream if checkpointing is enabled in the pipeline options.
*
* @param dStream The DStream to be checkpointed
* @param options The SerializablePipelineOptions containing configuration settings including
* batch duration
*/
public static void checkpointIfNeeded(
final DStream<?> dStream, final SerializablePipelineOptions options) {

final Long checkpointDurationMillis = getBatchDuration(options);

if (checkpointDurationMillis > 0) {
dStream.checkpoint(new Duration(checkpointDurationMillis));
}
}

/**
* Reject state and timers {@link DoFn}.
*
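Because timers remain unsupported, rejectTimers gives the streaming translation a guard against DoFns that declare them. A sketch of a DoFn that would be rejected (hypothetical, not from this commit):

import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.transforms.DoFn;

// Hypothetical DoFn declaring a timer: passing an instance to
// TranslationUtils.rejectTimers(...) throws UnsupportedOperationException,
// because its DoFnSignature contains a timer declaration.
class ExpiringFn extends DoFn<String, String> {

  @TimerId("expiry")
  private final TimerSpec expirySpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

  @ProcessElement
  public void process(@Element String element, OutputReceiver<String> out) {
    out.output(element);
  }

  @OnTimer("expiry")
  public void onExpiry() {}
}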