Commit c08afea
Enable MapState and SetState for dataflow streaming engine pipelines with legacy runner by building on top of MultimapState. (#31453)

scwhittle authored Jul 4, 2024
1 parent a5eee58 commit c08afea
Showing 21 changed files with 573 additions and 197 deletions.
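
Concretely, this commit lets stateful DoFns like the following sketch run on Dataflow Runner v1 with Streaming Engine, where they previously failed submission-time validation (see the DataflowRunner.verifyDoFnSupported change below). The sketch assumes only standard Beam Java APIs; the class name and the option flags in the comment are illustrative, not taken from this diff.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;

public class MapStateOnStreamingEngine {
  public static void main(String[] args) {
    // e.g. --runner=DataflowRunner --streaming --enableStreamingEngine
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
    p.apply(Create.of(KV.of("user", 1)))
        .apply(
            ParDo.of(
                new DoFn<KV<String, Integer>, Void>() {
                  // Per-key map state; with this commit it is backed by
                  // MultimapState when running on Streaming Engine.
                  @StateId("latest")
                  private final StateSpec<MapState<String, Integer>> latest = StateSpecs.map();

                  @ProcessElement
                  public void process(
                      @Element KV<String, Integer> e,
                      @StateId("latest") MapState<String, Integer> latest) {
                    latest.put(e.getKey(), e.getValue());
                  }
                }));
    p.run();
  }
}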
1 change: 1 addition & 0 deletions CHANGES.md
@@ -68,6 +68,7 @@
 
 * Multiple RunInference instances can now share the same model instance by setting the model_identifier parameter (Python) ([#31665](https://github.com/apache/beam/issues/31665)).
 * Removed a 3rd party LGPL dependency from the Go SDK ([#31765](https://github.com/apache/beam/issues/31765)).
+* Support for MapState and SetState when using Dataflow Runner v1 with Streaming Engine (Java) ([#18200](https://github.com/apache/beam/issues/18200)).
 
 ## Breaking Changes
 
StateTags.java
@@ -257,6 +257,14 @@ public static <KeyT> StateTag<MapState<KeyT, Boolean>> convertToMapTagInternal(
         new StructuredId(setTag.getId()), StateSpecs.convertToMapSpecInternal(setTag.getSpec()));
   }
 
+  public static <KeyT, ValueT> StateTag<MultimapState<KeyT, ValueT>> convertToMultiMapTagInternal(
+      StateTag<MapState<KeyT, ValueT>> mapTag) {
+    StateSpec<MapState<KeyT, ValueT>> spec = mapTag.getSpec();
+    StateSpec<MultimapState<KeyT, ValueT>> multimapSpec =
+        StateSpecs.convertToMultimapSpecInternal(spec);
+    return new SimpleStateTag<>(new StructuredId(mapTag.getId()), multimapSpec);
+  }
+
   private static class StructuredId implements Serializable {
     private final StateKind kind;
     private final String rawId;
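
The conversion reuses the map tag's structured id, so the multimap spec addresses the same underlying state cell as the map spec it stands in for. A minimal illustration of the call; the "counts" id and the coders are invented for the example, while StateTags.map is the existing runners-core factory:

import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;

class TagConversionExample {
  static void demo() {
    StateTag<MapState<String, Long>> mapTag =
        StateTags.map("counts", StringUtf8Coder.of(), VarLongCoder.of());
    // Same id, multimap-shaped spec: reads and writes land on the same cell.
    StateTag<MultimapState<String, Long>> multimapTag =
        StateTags.convertToMultiMapTagInternal(mapTag);
  }
}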
DataflowRunner.java
@@ -2564,11 +2564,6 @@ static boolean useUnifiedWorker(DataflowPipelineOptions options) {
         || hasExperiment(options, "use_portable_job_submission");
   }
 
-  static boolean useStreamingEngine(DataflowPipelineOptions options) {
-    return hasExperiment(options, GcpOptions.STREAMING_ENGINE_EXPERIMENT)
-        || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT);
-  }
-
   static void verifyDoFnSupported(
       DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) {
     if (!streaming && DoFnSignatures.usesMultimapState(fn)) {
@@ -2583,8 +2578,6 @@ static void verifyDoFnSupported(
             "%s does not currently support @RequiresTimeSortedInput in streaming mode.",
             DataflowRunner.class.getSimpleName()));
     }
-
-    boolean streamingEngine = useStreamingEngine(options);
     boolean isUnifiedWorker = useUnifiedWorker(options);
 
     if (DoFnSignatures.usesMultimapState(fn) && isUnifiedWorker) {
@@ -2593,25 +2586,17 @@
           String.format(
               "%s does not currently support %s running using streaming on unified worker",
               DataflowRunner.class.getSimpleName(), MultimapState.class.getSimpleName()));
     }
-    if (DoFnSignatures.usesSetState(fn)) {
-      if (streaming && (isUnifiedWorker || streamingEngine)) {
-        throw new UnsupportedOperationException(
-            String.format(
-                "%s does not currently support %s when using %s",
-                DataflowRunner.class.getSimpleName(),
-                SetState.class.getSimpleName(),
-                isUnifiedWorker ? "streaming on unified worker" : "streaming engine"));
-      }
+    if (DoFnSignatures.usesSetState(fn) && streaming && isUnifiedWorker) {
+      throw new UnsupportedOperationException(
+          String.format(
+              "%s does not currently support %s when using streaming on unified worker",
+              DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName()));
     }
-    if (DoFnSignatures.usesMapState(fn)) {
-      if (streaming && (isUnifiedWorker || streamingEngine)) {
-        throw new UnsupportedOperationException(
-            String.format(
-                "%s does not currently support %s when using %s",
-                DataflowRunner.class.getSimpleName(),
-                MapState.class.getSimpleName(),
-                isUnifiedWorker ? "streaming on unified worker" : "streaming engine"));
-      }
+    if (DoFnSignatures.usesMapState(fn) && streaming && isUnifiedWorker) {
+      throw new UnsupportedOperationException(
+          String.format(
+              "%s does not currently support %s when using streaming on unified worker",
+              DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName()));
     }
     if (DoFnSignatures.usesBundleFinalizer(fn) && !isUnifiedWorker) {
       throw new UnsupportedOperationException(
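
With the streamingEngine branch gone, the preflight check now rejects MapState and SetState only for streaming on the unified worker; a Streaming Engine pipeline on the legacy worker passes validation. For example, a deduplicating DoFn along these lines (names are illustrative, not from the diff) is now accepted under --streaming with Streaming Engine enabled:

import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.values.KV;

// Emits each value only the first time it is seen for its key.
class DedupFn extends DoFn<KV<String, String>, String> {
  @StateId("seen")
  private final StateSpec<SetState<String>> seenSpec = StateSpecs.set();

  @ProcessElement
  public void process(
      @Element KV<String, String> element,
      @StateId("seen") SetState<String> seen,
      OutputReceiver<String> out) {
    // addIfAbsent returns a ReadableState<Boolean> that reads true
    // iff the value was actually inserted (i.e., it was absent before).
    if (seen.addIfAbsent(element.getValue()).read()) {
      out.output(element.getValue());
    }
  }
}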
DataflowRunnerTest.java
@@ -131,8 +131,6 @@
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
 import org.apache.beam.sdk.state.ValueState;
@@ -1880,63 +1878,6 @@ public void testSettingConflictingEnableAndDisableExperimentsThrowsException() throws Exception {
     }
   }
 
-  private void verifyMapStateUnsupported(PipelineOptions options) throws Exception {
-    Pipeline p = Pipeline.create(options);
-    p.apply(Create.of(KV.of(13, 42)))
-        .apply(
-            ParDo.of(
-                new DoFn<KV<Integer, Integer>, Void>() {
-
-                  @StateId("fizzle")
-                  private final StateSpec<MapState<Void, Void>> voidState = StateSpecs.map();
-
-                  @ProcessElement
-                  public void process() {}
-                }));
-
-    thrown.expectMessage("MapState");
-    thrown.expect(UnsupportedOperationException.class);
-    p.run();
-  }
-
-  @Test
-  public void testMapStateUnsupportedStreamingEngine() throws Exception {
-    PipelineOptions options = buildPipelineOptions();
-    ExperimentalOptions.addExperiment(
-        options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT);
-    options.as(DataflowPipelineOptions.class).setStreaming(true);
-
-    verifyMapStateUnsupported(options);
-  }
-
-  private void verifySetStateUnsupported(PipelineOptions options) throws Exception {
-    Pipeline p = Pipeline.create(options);
-    p.apply(Create.of(KV.of(13, 42)))
-        .apply(
-            ParDo.of(
-                new DoFn<KV<Integer, Integer>, Void>() {
-
-                  @StateId("fizzle")
-                  private final StateSpec<SetState<Void>> voidState = StateSpecs.set();
-
-                  @ProcessElement
-                  public void process() {}
-                }));
-
-    thrown.expectMessage("SetState");
-    thrown.expect(UnsupportedOperationException.class);
-    p.run();
-  }
-
-  @Test
-  public void testSetStateUnsupportedStreamingEngine() throws Exception {
-    PipelineOptions options = buildPipelineOptions();
-    ExperimentalOptions.addExperiment(
-        options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT);
-    options.as(DataflowPipelineOptions.class).setStreaming(true);
-    verifySetStateUnsupported(options);
-  }
-
   /** Records all the composite transforms visited within the Pipeline. */
   private static class CompositeTransformRecorder extends PipelineVisitor.Defaults {
 
StreamingDataflowWorker.java
@@ -324,7 +324,10 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions options) {
     BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options);
     AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(Integer.MAX_VALUE);
     WindmillStateCache windmillStateCache =
-        WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+        WindmillStateCache.builder()
+            .setSizeMb(options.getWorkerCacheMb())
+            .setSupportMapViaMultimap(options.isEnableStreamingEngine())
+            .build();
     Function<String, ScheduledExecutorService> executorSupplier =
         threadName ->
             Executors.newSingleThreadScheduledExecutor(
@@ -478,7 +481,11 @@ static StreamingDataflowWorker forTesting(
     ConcurrentMap<String, StageInfo> stageInfo = new ConcurrentHashMap<>();
     AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(maxWorkItemCommitBytesOverrides);
     BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options);
-    WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+    WindmillStateCache stateCache =
+        WindmillStateCache.builder()
+            .setSizeMb(options.getWorkerCacheMb())
+            .setSupportMapViaMultimap(options.isEnableStreamingEngine())
+            .build();
     ComputationConfig.Fetcher configFetcher =
         options.isEnableStreamingEngine()
             ? StreamingEngineComputationConfigFetcher.forTesting(
AbstractWindmillMap.java (new file)
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.state;
+
+import org.apache.beam.sdk.state.MapState;
+
+public abstract class AbstractWindmillMap<K, V> extends SimpleWindmillState
+    implements MapState<K, V> {}
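
AbstractWindmillMap gives bindMap (in CachingStateTable, below) a return type that covers both the existing WindmillMap and the new multimap-backed implementation. WindmillMapViaMultimap itself falls in the truncated part of this diff, but the idea is that a map is a multimap constrained to at most one value per key. A hypothetical sketch of that delegation follows; it is not the actual Beam class, which must implement the full MapState contract:

import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;

// Illustrative only: shows the map-over-multimap delegation pattern.
final class MapViaMultimapSketch<K, V> {
  private final MultimapState<K, V> multimap;

  MapViaMultimapSketch(MultimapState<K, V> multimap) {
    this.multimap = multimap;
  }

  void put(K key, V value) {
    // Enforce the map invariant: at most one value per key.
    multimap.remove(key);
    multimap.put(key, value);
  }

  V get(K key) {
    // Under the invariant above, the key's value list has zero or one element.
    return Iterables.getFirst(multimap.get(key).read(), null);
  }

  void remove(K key) {
    multimap.remove(key);
  }
}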
CachingStateTable.java
@@ -24,17 +24,9 @@
 import org.apache.beam.runners.core.StateTable;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.sdk.coders.BooleanCoder;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.state.BagState;
-import org.apache.beam.sdk.state.CombiningState;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.MultimapState;
-import org.apache.beam.sdk.state.OrderedListState;
-import org.apache.beam.sdk.state.SetState;
-import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.StateContext;
-import org.apache.beam.sdk.state.ValueState;
-import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.state.*;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.CombineWithContext;
 import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
@@ -50,6 +42,7 @@ final class CachingStateTable extends StateTable {
   private final Supplier<Closeable> scopedReadStateSupplier;
   private final @Nullable StateTable derivedStateTable;
   private final boolean isNewKey;
+  private final boolean mapStateViaMultimapState;
 
   private CachingStateTable(Builder builder) {
     this.stateFamily = builder.stateFamily;
@@ -59,6 +52,7 @@ private CachingStateTable(Builder builder) {
     this.isNewKey = builder.isNewKey;
     this.scopedReadStateSupplier = builder.scopedReadStateSupplier;
     this.derivedStateTable = builder.derivedStateTable;
+    this.mapStateViaMultimapState = builder.mapStateViaMultimapState;
 
     if (this.isSystemTable) {
       Preconditions.checkState(derivedStateTable == null);
@@ -103,30 +97,39 @@ public <T> BagState<T> bindBag(StateTag<BagState<T>> address, Coder<T> elemCoder) {
 
   @Override
   public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> elemCoder) {
+    StateTag<MapState<T, Boolean>> internalMapAddress = StateTags.convertToMapTagInternal(spec);
     WindmillSet<T> result =
-        new WindmillSet<>(namespace, spec, stateFamily, elemCoder, cache, isNewKey);
+        new WindmillSet<>(bindMap(internalMapAddress, elemCoder, BooleanCoder.of()));
     result.initializeForWorkItem(reader, scopedReadStateSupplier);
     return result;
   }
 
   @Override
-  public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+  public <KeyT, ValueT> AbstractWindmillMap<KeyT, ValueT> bindMap(
       StateTag<MapState<KeyT, ValueT>> spec, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) {
-    WindmillMap<KeyT, ValueT> result =
-        cache
-            .get(namespace, spec)
-            .map(mapState -> (WindmillMap<KeyT, ValueT>) mapState)
-            .orElseGet(
-                () ->
-                    new WindmillMap<>(
-                        namespace, spec, stateFamily, keyCoder, valueCoder, isNewKey));
-
+    AbstractWindmillMap<KeyT, ValueT> result;
+    if (mapStateViaMultimapState) {
+      StateTag<MultimapState<KeyT, ValueT>> internalMultimapAddress =
+          StateTags.convertToMultiMapTagInternal(spec);
+      result =
+          new WindmillMapViaMultimap<>(
+              bindMultimap(internalMultimapAddress, keyCoder, valueCoder));
+    } else {
+      result =
+          cache
+              .get(namespace, spec)
+              .map(mapState -> (AbstractWindmillMap<KeyT, ValueT>) mapState)
+              .orElseGet(
+                  () ->
+                      new WindmillMap<>(
+                          namespace, spec, stateFamily, keyCoder, valueCoder, isNewKey));
+    }
     result.initializeForWorkItem(reader, scopedReadStateSupplier);
     return result;
   }
 
   @Override
-  public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
+  public <KeyT, ValueT> WindmillMultimap<KeyT, ValueT> bindMultimap(
       StateTag<MultimapState<KeyT, ValueT>> spec,
       Coder<KeyT> keyCoder,
       Coder<ValueT> valueCoder) {
@@ -246,6 +249,7 @@ static class Builder {
     private final boolean isNewKey;
     private boolean isSystemTable;
     private @Nullable StateTable derivedStateTable;
+    private boolean mapStateViaMultimapState = false;
 
     private Builder(
         String stateFamily,
@@ -268,6 +272,11 @@ Builder withDerivedState(StateTable derivedStateTable) {
       return this;
     }
 
+    Builder withMapStateViaMultimapState() {
+      this.mapStateViaMultimapState = true;
+      return this;
+    }
+
     CachingStateTable build() {
       return new CachingStateTable(this);
     }
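
Note the layering this file now establishes: bindSet is implemented on top of bindMap (a set is a map to Boolean), and bindMap, when the flag is set, is implemented on top of bindMultimap. The caller that flips the flag is not visible in this diff; presumably the worker enables it for Streaming Engine state tables, along the lines of this sketch (the method and its wiring are assumed, not shown in the commit):

// Hypothetical call site, package-local to the windmill state code.
static CachingStateTable buildStateTable(
    CachingStateTable.Builder builder, boolean supportMapViaMultimap) {
  // Only the map/multimap switch needs configuring; bindSet inherits it
  // because sets are bound through bindMap.
  if (supportMapViaMultimap) {
    builder = builder.withMapStateViaMultimapState();
  }
  return builder.build();
}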
WindmillMap.java
@@ -21,10 +21,7 @@
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.AbstractMap;
-import java.util.Collections;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.function.Function;
@@ -40,6 +37,8 @@
 import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
@@ -51,7 +50,7 @@
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
 })
-public class WindmillMap<K, V> extends SimpleWindmillState implements MapState<K, V> {
+public class WindmillMap<K, V> extends AbstractWindmillMap<K, V> {
   private final StateNamespace namespace;
   private final StateTag<MapState<K, V>> address;
   private final ByteString stateKeyPrefix;
@@ -327,7 +326,7 @@ private class WindmillMapEntriesReadableState
     @Override
     public Iterable<Map.Entry<K, V>> read() {
       if (complete) {
-        return Iterables.unmodifiableIterable(cachedValues.entrySet());
+        return ImmutableMap.copyOf(cachedValues).entrySet();
       }
       Future<Iterable<Map.Entry<ByteString, V>>> persistedData = getFuture();
       try (Closeable scope = scopedReadState()) {
@@ -352,20 +351,22 @@ public Iterable<Map.Entry<K, V>> read() {
               cachedValues.putIfAbsent(e.getKey(), e.getValue());
             });
         complete = true;
-        return Iterables.unmodifiableIterable(cachedValues.entrySet());
+        return ImmutableMap.copyOf(cachedValues).entrySet();
       } else {
+        ImmutableMap<K, V> cachedCopy = ImmutableMap.copyOf(cachedValues);
+        ImmutableSet<K> removalCopy = ImmutableSet.copyOf(localRemovals);
         // This means that the result might be too large to cache, so don't add it to the
         // local cache. Instead merge the iterables, giving priority to any local additions
-        // (represented in cachedValued and localRemovals) that may not have been committed
+        // (represented in cachedCopy and removalCopy) that may not have been committed
         // yet.
         return Iterables.unmodifiableIterable(
             Iterables.concat(
-                cachedValues.entrySet(),
+                cachedCopy.entrySet(),
                 Iterables.filter(
                     transformedData,
                     e ->
-                        !cachedValues.containsKey(e.getKey())
-                            && !localRemovals.contains(e.getKey()))));
+                        !cachedCopy.containsKey(e.getKey())
+                            && !removalCopy.contains(e.getKey()))));
       }
 
     } catch (InterruptedException | ExecutionException | IOException e) {
@@ -428,7 +429,7 @@ public WindmillMapReadResultReadableState(K key, @Nullable V defaultValue) {
           negativeCache.add(key);
           return defaultValue;
         }
-        // TODO: Don't do this if it was already in cache.
         cachedValues.put(key, persistedValue);
         return persistedValue;
       } catch (InterruptedException | ExecutionException | IOException e) {
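
The recurring change in this file replaces the live view Iterables.unmodifiableIterable(cachedValues.entrySet()) with the point-in-time snapshot ImmutableMap.copyOf(cachedValues).entrySet(), with matching copies (cachedCopy and removalCopy) on the merge path. A snapshot lets callers mutate the map state while iterating its entries without tripping ConcurrentModificationException. A standalone illustration of the hazard, using a plain HashMap and Guava rather than Windmill state:

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;

public class SnapshotVsLiveView {
  public static void main(String[] args) {
    Map<String, Long> backing = new HashMap<>();
    backing.put("a", 1L);
    backing.put("b", 2L);

    // Live view: structural modification during iteration fails fast.
    try {
      for (Map.Entry<String, Long> e : backing.entrySet()) {
        backing.put("c", 3L); // new key -> modCount changes mid-iteration
      }
    } catch (java.util.ConcurrentModificationException expected) {
      System.out.println("live view threw: " + expected);
    }

    // Snapshot: a copy taken up front is unaffected by later writes.
    for (Map.Entry<String, Long> e : ImmutableMap.copyOf(backing).entrySet()) {
      backing.put("d", 4L); // safe: the snapshot no longer tracks `backing`
    }
    System.out.println("snapshot iteration completed");
  }
}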
(Diff for the remaining changed files is not shown.)
