Skip to content

Commit

Permalink
GCS autosharding flag (#29886)
Browse files Browse the repository at this point in the history
GCS autosharding flag

Co-authored-by: Naireen <[email protected]>
  • Loading branch information
Naireen and Naireen authored Jan 18, 2024
1 parent 52b4a9c commit 79b9de2
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,10 @@ message WriteFilesPayload {
bool runner_determined_sharding = 4;

map<string, SideInput> side_inputs = 5;

// This is different from runner-determined sharding: auto-sharding is done by the runner backend,
// whereas runner_determined_sharding is done by the runner translator.
bool auto_sharded = 6;
}

// Payload used by Google Cloud Pub/Sub read transform.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ public boolean isWindowedWrites() {
return transform.getWindowedWrites();
}

/** Returns the auto-sharding setting read from {@code transform}. */
@Override
public boolean isAutoSharded() {
  final boolean autoSharded = transform.getWithAutoSharding();
  return autoSharded;
}

@Override
public boolean isRunnerDeterminedSharding() {
return transform.getNumShardsProvider() == null
Expand Down Expand Up @@ -175,6 +180,16 @@ public static <T, DestinationT> boolean isWindowedWrites(
return getWriteFilesPayload(transform).getWindowedWrites();
}

/**
 * Returns the auto-sharded flag stored in the {@link WriteFilesPayload} obtained for the given
 * applied transform.
 *
 * @param transform the applied write-files transform whose payload is inspected
 * @return the value of the payload's auto-sharded field
 * @throws IOException if the payload cannot be obtained
 */
public static <T, DestinationT> boolean isAutoSharded(
    AppliedPTransform<
            PCollection<T>,
            WriteFilesResult<DestinationT>,
            ? extends PTransform<PCollection<T>, WriteFilesResult<DestinationT>>>
        transform)
    throws IOException {
  final WriteFilesPayload writeFilesPayload = getWriteFilesPayload(transform);
  return writeFilesPayload.getAutoSharded();
}

public static <T, DestinationT> boolean isRunnerDeterminedSharding(
AppliedPTransform<
PCollection<T>,
Expand Down Expand Up @@ -268,6 +283,11 @@ public boolean isWindowedWrites() {
return payload.getWindowedWrites();
}

/** Returns the auto-sharded flag read from {@code payload}. */
@Override
public boolean isAutoSharded() {
  final boolean autoSharded = payload.getAutoSharded();
  return autoSharded;
}

@Override
public boolean isRunnerDeterminedSharding() {
return payload.getRunnerDeterminedSharding();
Expand Down Expand Up @@ -309,6 +329,8 @@ private interface WriteFilesLike {

boolean isWindowedWrites();

boolean isAutoSharded();

boolean isRunnerDeterminedSharding();
}

Expand All @@ -319,6 +341,7 @@ public static WriteFilesPayload payloadForWriteFilesLike(
.setSink(writeFiles.translateSink(components))
.putAllSideInputs(writeFiles.translateSideInputs(components))
.setWindowedWrites(writeFiles.isWindowedWrites())
.setAutoSharded(writeFiles.isAutoSharded())
.setRunnerDeterminedSharding(writeFiles.isRunnerDeterminedSharding())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ public static Iterable<WriteFiles<Object, Void, Object>> data() {
WriteFiles.to(new DummySink()),
WriteFiles.to(new DummySink()).withWindowedWrites(),
WriteFiles.to(new DummySink()).withNumShards(17),
WriteFiles.to(new DummySink()).withWindowedWrites().withNumShards(42));
WriteFiles.to(new DummySink()).withWindowedWrites().withNumShards(42),
WriteFiles.to(new DummySink()).withAutoSharding());
}

@Parameter(0)
Expand Down Expand Up @@ -105,6 +106,9 @@ public void testExtractionDirectFromTransform() throws Exception {
equalTo(
writeFiles.getNumShardsProvider() == null && writeFiles.getComputeNumShards() == null));

assertThat(
WriteFilesTranslation.isAutoSharded(appliedPTransform),
equalTo(writeFiles.getWithAutoSharding()));
assertThat(
WriteFilesTranslation.isWindowedWrites(appliedPTransform),
equalTo(writeFiles.getWindowedWrites()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2446,6 +2446,13 @@ static class StreamingShardedWriteFactory<UserT, DestinationT, OutputT>
if (WriteFilesTranslation.isWindowedWrites(transform)) {
replacement = replacement.withWindowedWrites();
}

if (WriteFilesTranslation.isAutoSharded(transform)) {
replacement = replacement.withAutoSharding();
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform), replacement);
}

return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform),
replacement.withNumShards(numShards));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,39 @@ public void testStreamingWriteWithNoShardingReturnsNewTransformMaxWorkersUnset()
testStreamingWriteOverride(options, StreamingShardedWriteFactory.DEFAULT_NUM_SHARDS);
}

/**
 * Checks that the replacement produced by {@link StreamingShardedWriteFactory} for a {@link
 * WriteFiles} configured with auto-sharding keeps auto-sharding enabled and leaves both the
 * num-shards provider and the compute-num-shards transform unset.
 */
@Test
public void testStreamingWriteWithShardingReturnsSameTransform() {
  TestPipeline p = TestPipeline.fromOptions(TestPipeline.testingPipelineOptions());

  StreamingShardedWriteFactory<Object, Void, Object> factory =
      new StreamingShardedWriteFactory<>(p.getOptions());
  WriteFiles<Object, Void, Object> original =
      WriteFiles.to(new TestSink(tmpFolder.toString())).withAutoSharding();
  PCollection<Object> objs = (PCollection) p.apply(Create.empty(VoidCoder.of()));

  // Hand-built application of the original transform, used only as input to the factory.
  AppliedPTransform<PCollection<Object>, WriteFilesResult<Void>, WriteFiles<Object, Void, Object>>
      originalApplication =
          AppliedPTransform.of(
              "writefiles",
              PValues.expandInput(objs),
              Collections.emptyMap(),
              original,
              ResourceHints.create(),
              p);

  WriteFiles<Object, Void, Object> replacement =
      (WriteFiles<Object, Void, Object>)
          factory.getReplacementTransform(originalApplication).getTransform();

  // Both transforms are applied to the same input; their results are not inspected.
  objs.apply(original);
  objs.apply(replacement);

  boolean numShardsProviderUnset = replacement.getNumShardsProvider() == null;
  boolean computeNumShardsUnset = replacement.getComputeNumShards() == null;
  assertTrue(numShardsProviderUnset);
  assertTrue(computeNumShardsUnset);
  assertTrue(replacement.getWithAutoSharding());
}

private void verifyMergingStatefulParDoRejected(PipelineOptions options) throws Exception {
Pipeline p = Pipeline.create(options);

Expand Down
10 changes: 4 additions & 6 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,6 @@ public PipelineOptions getOptions() {
*
* <p>Replaces all nodes that match a {@link PTransformOverride} in this pipeline. Overrides are
* applied in the order they are present within the list.
*
* <p>After all nodes are replaced, ensures that no nodes in the updated graph match any of the
* overrides.
*/
@Internal
public void replaceAll(List<PTransformOverride> overrides) {
Expand Down Expand Up @@ -241,9 +238,10 @@ public CompositeBehavior enterCompositeTransform(Node node) {

@Override
public void leaveCompositeTransform(Node node) {
if (node.isRootNode()) {
checkState(
matched.isEmpty(), "Found nodes that matched overrides. Matches: %s", matched);
if (node.isRootNode() && !matched.isEmpty()) {
LOG.info(
"Found nodes that matched overrides. Matches: {}. The match usually should be empty unless there are runner specific replacement transforms.",
matched);
}
}

Expand Down
14 changes: 14 additions & 0 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ public static <InputT> Write<Void, InputT> write() {
.setDynamic(false)
.setCompression(Compression.UNCOMPRESSED)
.setIgnoreWindowing(false)
.setAutoSharding(false)
.setNoSpilling(false)
.build();
}
Expand All @@ -406,6 +407,7 @@ public static <DestT, InputT> Write<DestT, InputT> writeDynamic() {
.setDynamic(true)
.setCompression(Compression.UNCOMPRESSED)
.setIgnoreWindowing(false)
.setAutoSharding(false)
.setNoSpilling(false)
.build();
}
Expand Down Expand Up @@ -1037,6 +1039,8 @@ public static FileNaming relativeFileNaming(

abstract boolean getIgnoreWindowing();

abstract boolean getAutoSharding();

abstract boolean getNoSpilling();

abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
Expand Down Expand Up @@ -1085,6 +1089,8 @@ abstract Builder<DestinationT, UserT> setSharding(

abstract Builder<DestinationT, UserT> setIgnoreWindowing(boolean ignoreWindowing);

abstract Builder<DestinationT, UserT> setAutoSharding(boolean autosharding);

abstract Builder<DestinationT, UserT> setNoSpilling(boolean noSpilling);

abstract Builder<DestinationT, UserT> setBadRecordErrorHandler(
Expand Down Expand Up @@ -1311,6 +1317,10 @@ public Write<DestinationT, UserT> withIgnoreWindowing() {
return toBuilder().setIgnoreWindowing(true).build();
}

/** See {@link WriteFiles#withAutoSharding()}. */
public Write<DestinationT, UserT> withAutoSharding() {
  Builder<DestinationT, UserT> builder = toBuilder();
  return builder.setAutoSharding(true).build();
}

/** See {@link WriteFiles#withNoSpilling()}. */
public Write<DestinationT, UserT> withNoSpilling() {
return toBuilder().setNoSpilling(true).build();
Expand Down Expand Up @@ -1412,6 +1422,7 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
resolvedSpec.setNumShards(getNumShards());
resolvedSpec.setSharding(getSharding());
resolvedSpec.setIgnoreWindowing(getIgnoreWindowing());
resolvedSpec.setAutoSharding(getAutoSharding());
resolvedSpec.setNoSpilling(getNoSpilling());

Write<DestinationT, UserT> resolved = resolvedSpec.build();
Expand All @@ -1428,6 +1439,9 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
if (!getIgnoreWindowing()) {
writeFiles = writeFiles.withWindowedWrites();
}
if (getAutoSharding()) {
writeFiles = writeFiles.withAutoSharding();
}
if (getNoSpilling()) {
writeFiles = writeFiles.withNoSpilling();
}
Expand Down
18 changes: 18 additions & 0 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ public static <UserT> TypedWrite<UserT, Void> writeCustomType() {
.setWindowedWrites(false)
.setNoSpilling(false)
.setSkipIfEmpty(false)
.setAutoSharding(false)
.build();
}

Expand Down Expand Up @@ -702,6 +703,9 @@ public abstract static class TypedWrite<UserT, DestinationT>
/** Whether to write windowed output files. */
abstract boolean getWindowedWrites();

/** Whether to enable autosharding. */
abstract boolean getAutoSharding();

/** Whether to skip the spilling of data caused by having maxNumWritersPerBundle. */
abstract boolean getNoSpilling();

Expand Down Expand Up @@ -755,6 +759,8 @@ abstract Builder<UserT, DestinationT> setNumShards(

abstract Builder<UserT, DestinationT> setWindowedWrites(boolean windowedWrites);

/** Sets whether to enable autosharding. */
abstract Builder<UserT, DestinationT> setAutoSharding(boolean autoSharding);

abstract Builder<UserT, DestinationT> setNoSpilling(boolean noSpilling);

/** Sets the skip-if-empty flag. */
abstract Builder<UserT, DestinationT> setSkipIfEmpty(boolean skipIfEmpty);
Expand Down Expand Up @@ -999,6 +1005,10 @@ public TypedWrite<UserT, DestinationT> withWindowedWrites() {
return toBuilder().setWindowedWrites(true).build();
}

/** See {@link WriteFiles#withAutoSharding()}. */
public TypedWrite<UserT, DestinationT> withAutoSharding() {
  Builder<UserT, DestinationT> builder = toBuilder();
  return builder.setAutoSharding(true).build();
}

/** See {@link WriteFiles#withNoSpilling()}. */
public TypedWrite<UserT, DestinationT> withNoSpilling() {
return toBuilder().setNoSpilling(true).build();
Expand Down Expand Up @@ -1097,6 +1107,9 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
if (getWindowedWrites()) {
write = write.withWindowedWrites();
}
if (getAutoSharding()) {
write = write.withAutoSharding();
}
if (getNoSpilling()) {
write = write.withNoSpilling();
}
Expand Down Expand Up @@ -1268,6 +1281,11 @@ public Write withWindowedWrites() {
return new Write(inner.withWindowedWrites());
}

/**
 * Returns a new {@code Write} wrapping the result of calling {@code withAutoSharding()} on
 * {@code inner}. See {@link TypedWrite#withAutoSharding}.
 */
public Write withAutoSharding() {
return new Write(inner.withAutoSharding());
}

/** See {@link TypedWrite#withNoSpilling}. */
public Write withNoSpilling() {
return new Write(inner.withNoSpilling());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ public static <UserT, DestinationT, OutputT> WriteFiles<UserT, DestinationT, Out
.setComputeNumShards(null)
.setNumShardsProvider(null)
.setWindowedWrites(false)
.setWithAutoSharding(false)
.setMaxNumWritersPerBundle(DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE)
.setSideInputs(sink.getDynamicDestinations().getSideInputs())
.setSkipIfEmpty(false)
Expand All @@ -189,6 +190,8 @@ public static <UserT, DestinationT, OutputT> WriteFiles<UserT, DestinationT, Out

public abstract boolean getWindowedWrites();

public abstract boolean getWithAutoSharding();

abstract int getMaxNumWritersPerBundle();

abstract boolean getSkipIfEmpty();
Expand Down Expand Up @@ -216,6 +219,8 @@ abstract Builder<UserT, DestinationT, OutputT> setNumShardsProvider(

abstract Builder<UserT, DestinationT, OutputT> setWindowedWrites(boolean windowedWrites);

abstract Builder<UserT, DestinationT, OutputT> setWithAutoSharding(boolean withAutoSharding);

abstract Builder<UserT, DestinationT, OutputT> setMaxNumWritersPerBundle(
int maxNumWritersPerBundle);

Expand Down Expand Up @@ -308,6 +313,13 @@ public WriteFiles<UserT, DestinationT, OutputT> withRunnerDeterminedSharding() {
return toBuilder().setComputeNumShards(null).setNumShardsProvider(null).build();
}

/**
 * Returns a new {@link WriteFiles} with auto-sharding enabled.
 *
 * <p>Auto-sharding may only be requested when no explicit sharding has been configured, i.e. when
 * both {@code getComputeNumShards()} and {@code getNumShardsProvider()} are {@code null}.
 *
 * @throws IllegalArgumentException if either of those sharding settings is already set
 */
public WriteFiles<UserT, DestinationT, OutputT> withAutoSharding() {
  checkArgument(
      getComputeNumShards() == null && getNumShardsProvider() == null,
      "Explicit sharding (computeNumShards / numShardsProvider) must be null if autosharding is"
          + " specified.");
  return toBuilder().setWithAutoSharding(true).build();
}

/**
* Returns a new {@link WriteFiles} that will write to the current {@link FileBasedSink} using the
* specified sharding function to assign shard for inputs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,16 +418,15 @@ public CompositeBehavior enterCompositeTransform(Node node) {
}

/**
* Tests that {@link Pipeline#replaceAll(List)} throws when one of the PTransformOverride still
* Tests that {@link Pipeline#replaceAll(List)} succeeds when one of the PTransformOverride still
* matches.
*/
@Test
public void testReplaceAllIncomplete() {
pipeline.enableAbandonedNodeEnforcement(false);
pipeline.apply(GenerateSequence.from(0));

// The order is such that the output of the second will match the first, which is not permitted
thrown.expect(IllegalStateException.class);
// The order is such that the output of the second will match the first, which is permitted.
pipeline.replaceAll(
ImmutableList.of(
PTransformOverride.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,18 @@ public void testWithRunnerDeterminedShardingUnbounded() throws IOException {
true);
}

/**
 * Runs a sharded write of six string elements, windowed into 10-second fixed windows, through a
 * {@link WriteFiles} configured with windowed writes and auto-sharding.
 */
@Test
@Category({NeedsRunner.class, UsesUnboundedPCollections.class})
public void testWithShardingUnbounded() throws IOException {
runShardedWrite(
Arrays.asList("one", "two", "three", "four", "five", "six"),
Window.into(FixedWindows.of(Duration.standardSeconds(10))),
getBaseOutputFilename(),
WriteFiles.to(makeSimpleSink()).withWindowedWrites().withAutoSharding(),
// NOTE(review): meaning of this null argument is not visible here; confirm against
// runShardedWrite's signature.
null,
// NOTE(review): presumably marks the input as unbounded, per the test name; confirm.
true);
}

@Test
@Category({
NeedsRunner.class,
Expand Down

0 comments on commit 79b9de2

Please sign in to comment.