From 23e59afce976d40f3e881094deee7e42c42a0e11 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Wed, 23 Nov 2022 16:52:50 -0800 Subject: [PATCH] Updates ExpansionService to support dynamically discovering and expanding SchemaTransforms (#23413) * Updates ExpansionService to support dynamically discovering and expanding SchemaTransforms * Fixing checker framework errors. * Address reviewer comments * Addressing reviewer comments * Addressing reviewer comments --- .../v1/beam_expansion_api.proto | 28 + .../pipeline/v1/external_transforms.proto | 17 + .../expansion/service/ExpansionService.java | 75 ++- ...pansionServiceSchemaTransformProvider.java | 144 ++++++ ...ionServiceSchemaTransformProviderTest.java | 486 ++++++++++++++++++ .../apache_beam/portability/common_urns.py | 1 + .../python/apache_beam/transforms/external.py | 158 ++++-- .../apache_beam/transforms/external_test.py | 30 ++ 8 files changed, 898 insertions(+), 41 deletions(-) create mode 100644 sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java create mode 100644 sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java diff --git a/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto b/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto index f3ab890005de..568f9c877410 100644 --- a/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto +++ b/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto @@ -30,6 +30,7 @@ option java_package = "org.apache.beam.model.expansion.v1"; option java_outer_classname = "ExpansionApi"; import "org/apache/beam/model/pipeline/v1/beam_runner_api.proto"; +import "org/apache/beam/model/pipeline/v1/schema.proto"; message ExpansionRequest { // Set of components needed to interpret the transform, or which @@ -72,7 +73,34 @@ message ExpansionResponse { string error = 10; } +message DiscoverSchemaTransformRequest { +} + +message SchemaTransformConfig { + // Config schema of the SchemaTransform + org.apache.beam.model.pipeline.v1.Schema config_schema = 1; + + // Names of input PCollections + repeated string input_pcollection_names = 2; + + // Names of output PCollections + repeated string output_pcollection_names = 3; +} + +message DiscoverSchemaTransformResponse { + // A mapping from SchemaTransform ID to schema transform config of discovered + // SchemaTransforms + map schema_transform_configs = 1; + + // If list of identifies are empty, this may contain an error. + string error = 2; +} + // Job Service for constructing pipelines service ExpansionService { rpc Expand (ExpansionRequest) returns (ExpansionResponse); + + //A RPC to discover already registered SchemaTransformProviders. + // See https://s.apache.org/easy-multi-language for more details. + rpc DiscoverSchemaTransform (DiscoverSchemaTransformRequest) returns (DiscoverSchemaTransformResponse); } diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index baff2c0436f5..18cd02e3942c 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -51,6 +51,11 @@ message ExpansionMethods { // Transform payload will be of type JavaClassLookupPayload. JAVA_CLASS_LOOKUP = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:expansion:payload:java_class_lookup:v1"]; + + // Expanding a SchemaTransform identified by the expansion service. + // Transform payload will be of type SchemaTransformPayload. + SCHEMA_TRANSFORM = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:expansion:payload:schematransform:v1"]; } } @@ -106,4 +111,16 @@ message BuilderMethod { bytes payload = 3; } +message SchemaTransformPayload { + // The identifier of the SchemaTransform (typically a URN). + string identifier = 1; + + // The configuration schema of the SchemaTransform. + Schema configuration_schema = 2; + // The configuration of the SchemaTransform. + // Should be decodable via beam:coder:row:v1. + // The schema of the Row should be compatible with the schema of the + // SchemaTransform. + bytes configuration_row = 3; +} diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index fed01d2576e6..221c40f79202 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -35,6 +35,9 @@ import java.util.Set; import java.util.stream.Collectors; import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformRequest; +import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformResponse; +import org.apache.beam.model.expansion.v1.ExpansionApi.SchemaTransformConfig; import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc; import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExternalConfigurationPayload; @@ -63,6 +66,7 @@ import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.transforms.ExternalTransformBuilder; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; @@ -436,6 +440,10 @@ private Map getRegisteredTransforms() { return registeredTransforms; } + private Iterable getRegisteredSchemaTransforms() { + return ExpansionServiceSchemaTransformProvider.of().getAllProviders(); + } + private Map loadRegisteredTransforms() { ImmutableMap.Builder registeredTransformsBuilder = ImmutableMap.builder(); @@ -500,6 +508,8 @@ private Map loadRegisteredTransforms() { pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); assert allowList != null; transformProvider = new JavaClassLookupTransformProvider(allowList); + } else if (getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM).equals(urn)) { + transformProvider = ExpansionServiceSchemaTransformProvider.of(); } else { transformProvider = getRegisteredTransforms().get(urn); if (transformProvider == null) { @@ -604,6 +614,42 @@ public void expand( } } + DiscoverSchemaTransformResponse discover(DiscoverSchemaTransformRequest request) { + ExpansionServiceSchemaTransformProvider transformProvider = + ExpansionServiceSchemaTransformProvider.of(); + DiscoverSchemaTransformResponse.Builder responseBuilder = + DiscoverSchemaTransformResponse.newBuilder(); + for (org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider provider : + transformProvider.getAllProviders()) { + SchemaTransformConfig.Builder schemaTransformConfigBuider = + SchemaTransformConfig.newBuilder(); + schemaTransformConfigBuider.setConfigSchema( + SchemaTranslation.schemaToProto(provider.configurationSchema(), true)); + schemaTransformConfigBuider.addAllInputPcollectionNames(provider.inputCollectionNames()); + schemaTransformConfigBuider.addAllOutputPcollectionNames(provider.outputCollectionNames()); + responseBuilder.putSchemaTransformConfigs( + provider.identifier(), schemaTransformConfigBuider.build()); + } + + return responseBuilder.build(); + } + + @Override + public void discoverSchemaTransform( + DiscoverSchemaTransformRequest request, + StreamObserver responseObserver) { + try { + responseObserver.onNext(discover(request)); + responseObserver.onCompleted(); + } catch (RuntimeException exn) { + responseObserver.onNext( + ExpansionApi.DiscoverSchemaTransformResponse.newBuilder() + .setError(Throwables.getStackTraceAsString(exn)) + .build()); + responseObserver.onCompleted(); + } + } + @Override public void close() throws Exception { // Nothing to do because the expansion service is stateless. @@ -618,9 +664,36 @@ public static void main(String[] args) throws Exception { @SuppressWarnings("nullness") ExpansionService service = new ExpansionService(Arrays.copyOfRange(args, 1, args.length)); + + StringBuilder registeredTransformsLog = new StringBuilder(); + boolean registeredTransformsFound = false; + registeredTransformsLog.append("\n"); + registeredTransformsLog.append("Registered transforms:"); + for (Map.Entry entry : service.getRegisteredTransforms().entrySet()) { - System.out.println("\t" + entry.getKey() + ": " + entry.getValue()); + registeredTransformsFound = true; + registeredTransformsLog.append("\n\t" + entry.getKey() + ": " + entry.getValue()); + } + + StringBuilder registeredSchemaTransformProvidersLog = new StringBuilder(); + boolean registeredSchemaTransformProvidersFound = false; + registeredSchemaTransformProvidersLog.append("\n"); + registeredSchemaTransformProvidersLog.append("Registered SchemaTransformProviders:"); + + for (SchemaTransformProvider provider : service.getRegisteredSchemaTransforms()) { + registeredSchemaTransformProvidersFound = true; + registeredSchemaTransformProvidersLog.append("\n\t" + provider.identifier()); + } + + if (registeredTransformsFound) { + System.out.println(registeredTransformsLog.toString()); + } + if (registeredSchemaTransformProvidersFound) { + System.out.println(registeredSchemaTransformProvidersLog.toString()); + } + if (!registeredTransformsFound && !registeredSchemaTransformProvidersFound) { + System.out.println("\nDid not find any registered transforms or SchemaTransforms.\n"); } Server server = diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java new file mode 100644 index 000000000000..4657e0524025 --- /dev/null +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.ServiceLoader; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; + +@SuppressWarnings({"rawtypes"}) +public class ExpansionServiceSchemaTransformProvider + implements TransformProvider { + + private Map + schemaTransformProviders = new HashMap<>(); + private static @Nullable ExpansionServiceSchemaTransformProvider transformProvider = null; + + private ExpansionServiceSchemaTransformProvider() { + try { + for (org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider schemaTransformProvider : + ServiceLoader.load( + org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider.class)) { + if (schemaTransformProviders.containsKey(schemaTransformProvider.identifier())) { + throw new IllegalArgumentException( + "Found multiple SchemaTransformProvider implementations with the same identifier " + + schemaTransformProvider.identifier()); + } + schemaTransformProviders.put(schemaTransformProvider.identifier(), schemaTransformProvider); + } + } catch (Exception e) { + throw new RuntimeException(e.getMessage()); + } + } + + public static ExpansionServiceSchemaTransformProvider of() { + if (transformProvider == null) { + transformProvider = new ExpansionServiceSchemaTransformProvider(); + } + + return transformProvider; + } + + @Override + public PCollectionRowTuple createInput(Pipeline p, Map> inputs) { + PCollectionRowTuple inputRowTuple = PCollectionRowTuple.empty(p); + for (Map.Entry> entry : inputs.entrySet()) { + inputRowTuple = inputRowTuple.and(entry.getKey(), (PCollection) entry.getValue()); + } + return inputRowTuple; + } + + @Override + public Map> extractOutputs(PCollectionRowTuple output) { + ImmutableMap.Builder> pCollectionMap = ImmutableMap.builder(); + for (String key : output.getAll().keySet()) { + pCollectionMap.put(key, output.get(key)); + } + return pCollectionMap.build(); + } + + @Override + public PTransform getTransform(FunctionSpec spec) { + SchemaTransformPayload payload; + try { + payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + String identifier = payload.getIdentifier(); + if (!schemaTransformProviders.containsKey(identifier)) { + throw new RuntimeException( + "Did not find a SchemaTransformProvider with the identifier " + identifier); + } + + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM), e); + } + + String identifier = payload.getIdentifier(); + org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider provider = + schemaTransformProviders.get(identifier); + if (provider == null) { + throw new IllegalArgumentException( + "Could not find a SchemaTransform with identifier " + identifier); + } + + Schema configSchemaFromRequest = + SchemaTranslation.schemaFromProto((payload.getConfigurationSchema())); + Schema configSchemaFromProvider = provider.configurationSchema(); + + if (!configSchemaFromRequest.assignableTo(configSchemaFromProvider)) { + throw new IllegalArgumentException( + String.format( + "Config schema provided with the expansion request %s is not compatible with the " + + "config of the Schema transform %s.", + configSchemaFromRequest, configSchemaFromProvider)); + } + + Row configRow; + try { + configRow = + RowCoder.of(provider.configurationSchema()) + .decode(payload.getConfigurationRow().newInput()); + } catch (IOException e) { + throw new RuntimeException("Error decoding payload", e); + } + + return provider.from(configRow).buildTransform(); + } + + Iterable getAllProviders() { + return schemaTransformProviders.values(); + } +} diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java new file mode 100644 index 000000000000..5b9b50b248a7 --- /dev/null +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java @@ -0,0 +1,486 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.junit.Assert.assertEquals; + +import com.google.auto.service.AutoService; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaCreate; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.InferableFunction; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.junit.Test; + +/** Tests for {@link ExpansionServiceSchemaTransformProvider}. */ +@SuppressWarnings({ + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) +}) +public class ExpansionServiceSchemaTransformProviderTest { + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static final Schema TEST_SCHEMATRANSFORM_CONFIG_SCHEMA = + Schema.of( + Field.of("str1", FieldType.STRING), + Field.of("str2", FieldType.STRING), + Field.of("int1", FieldType.INT32), + Field.of("int2", FieldType.INT32)); + + private ExpansionService expansionService = new ExpansionService(); + + @DefaultSchema(JavaFieldSchema.class) + public static class TestSchemaTransformConfiguration { + + public final String str1; + public final String str2; + public final Integer int1; + public final Integer int2; + + @SchemaCreate + public TestSchemaTransformConfiguration(String str1, String str2, Integer int1, Integer int2) { + this.str1 = str1; + this.str2 = str2; + this.int1 = int1; + this.int2 = int2; + } + } + + /** Registers a SchemaTransform. */ + @AutoService(SchemaTransformProvider.class) + public static class TestSchemaTransformProvider + extends TypedSchemaTransformProvider { + + @Override + protected Class configurationClass() { + return TestSchemaTransformConfiguration.class; + } + + @Override + protected SchemaTransform from(TestSchemaTransformConfiguration configuration) { + return new TestSchemaTransform( + configuration.str1, configuration.str2, configuration.int1, configuration.int2); + } + + @Override + public String identifier() { + return "dummy_id"; + } + + @Override + public List inputCollectionNames() { + return ImmutableList.of("input1"); + } + + @Override + public List outputCollectionNames() { + return ImmutableList.of("output1"); + } + } + + public static class TestSchemaTransform implements SchemaTransform { + + private String str1; + private String str2; + private Integer int1; + private Integer int2; + + public TestSchemaTransform(String str1, String str2, Integer int1, Integer int2) { + this.str1 = str1; + this.str2 = str2; + this.int1 = int1; + this.int2 = int2; + } + + @Override + public PTransform buildTransform() { + return new TestTransform(str1, str2, int1, int2); + } + } + + public static class TestDoFn extends DoFn { + + public String str1; + public String str2; + public int int1; + public int int2; + + public TestDoFn(String str1, String str2, Integer int1, Integer int2) { + this.str1 = str1; + this.str2 = str2; + this.int1 = int1; + this.int2 = int2; + } + + @ProcessElement + public void processElement(@Element String element, OutputReceiver receiver) { + receiver.output(element); + } + } + + public static class TestTransform extends PTransform { + + private String str1; + private String str2; + private Integer int1; + private Integer int2; + + public TestTransform(String str1, String str2, Integer int1, Integer int2) { + this.str1 = str1; + this.str2 = str2; + this.int1 = int1; + this.int2 = int2; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection outputPC = + input + .getAll() + .values() + .iterator() + .next() + .apply( + MapElements.via( + new InferableFunction() { + @Override + public String apply(Row input) throws Exception { + return input.getString("in_str"); + } + })) + .apply(ParDo.of(new TestDoFn(this.str1, this.str2, this.int1, this.int2))) + .apply( + MapElements.via( + new InferableFunction() { + @Override + public Row apply(String input) throws Exception { + return Row.withSchema(Schema.of(Field.of("out_str", FieldType.STRING))) + .withFieldValue("out_str", input) + .build(); + } + })) + .setRowSchema(Schema.of(Field.of("out_str", FieldType.STRING))); + return PCollectionRowTuple.of("output1", outputPC); + } + } + + /** Registers a SchemaTransform. */ + @AutoService(SchemaTransformProvider.class) + public static class TestSchemaTransformProviderMultiInputMultiOutput + extends TypedSchemaTransformProvider { + + @Override + protected Class configurationClass() { + return TestSchemaTransformConfiguration.class; + } + + @Override + protected SchemaTransform from(TestSchemaTransformConfiguration configuration) { + return new TestSchemaTransformMultiInputOutput( + configuration.str1, configuration.str2, configuration.int1, configuration.int2); + } + + @Override + public String identifier() { + return "dummy_id_multi_input_multi_output"; + } + + @Override + public List inputCollectionNames() { + return ImmutableList.of("input1", "input2"); + } + + @Override + public List outputCollectionNames() { + return ImmutableList.of("output1", "output2"); + } + } + + public static class TestSchemaTransformMultiInputOutput implements SchemaTransform { + + private String str1; + private String str2; + private Integer int1; + private Integer int2; + + public TestSchemaTransformMultiInputOutput( + String str1, String str2, Integer int1, Integer int2) { + this.str1 = str1; + this.str2 = str2; + this.int1 = int1; + this.int2 = int2; + } + + @Override + public PTransform buildTransform() { + return new TestTransformMultiInputMultiOutput(str1, str2, int1, int2); + } + } + + public static class TestTransformMultiInputMultiOutput + extends PTransform { + + private String str1; + private String str2; + private Integer int1; + private Integer int2; + + public TestTransformMultiInputMultiOutput( + String str1, String str2, Integer int1, Integer int2) { + this.str1 = str1; + this.str2 = str2; + this.int1 = int1; + this.int2 = int2; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection outputPC1 = + input + .get("input1") + .apply( + MapElements.via( + new InferableFunction() { + @Override + public String apply(Row input) throws Exception { + return input.getString("in_str"); + } + })) + .apply(ParDo.of(new TestDoFn(this.str1, this.str2, this.int1, this.int2))) + .apply( + MapElements.via( + new InferableFunction() { + @Override + public Row apply(String input) throws Exception { + return Row.withSchema(Schema.of(Field.of("out_str", FieldType.STRING))) + .withFieldValue("out_str", input) + .build(); + } + })) + .setRowSchema(Schema.of(Field.of("out_str", FieldType.STRING))); + PCollection outputPC2 = + input + .get("input2") + .apply( + MapElements.via( + new InferableFunction() { + @Override + public String apply(Row input) throws Exception { + return input.getString("in_str"); + } + })) + .apply(ParDo.of(new TestDoFn(this.str1, this.str2, this.int1, this.int2))) + .apply( + MapElements.via( + new InferableFunction() { + @Override + public Row apply(String input) throws Exception { + return Row.withSchema(Schema.of(Field.of("out_str", FieldType.STRING))) + .withFieldValue("out_str", input) + .build(); + } + })) + .setRowSchema(Schema.of(Field.of("out_str", FieldType.STRING))); + return PCollectionRowTuple.of("output1", outputPC1, "output2", outputPC2); + } + } + + @Test + public void testSchemaTransformDiscovery() { + ExpansionApi.DiscoverSchemaTransformRequest discoverRequest = + ExpansionApi.DiscoverSchemaTransformRequest.newBuilder().build(); + ExpansionApi.DiscoverSchemaTransformResponse response = + expansionService.discover(discoverRequest); + assertEquals(2, response.getSchemaTransformConfigsCount()); + } + + private void verifyLeafTransforms(ExpansionApi.ExpansionResponse response, int count) { + + int leafTransformCount = 0; + for (RunnerApi.PTransform transform : response.getComponents().getTransformsMap().values()) { + if (transform.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN)) { + RunnerApi.ParDoPayload parDoPayload; + try { + parDoPayload = RunnerApi.ParDoPayload.parseFrom(transform.getSpec().getPayload()); + DoFn doFn = ParDoTranslation.getDoFn(parDoPayload); + if (!(doFn instanceof TestDoFn)) { + continue; + } + TestDoFn testDoFn = (TestDoFn) doFn; + assertEquals("aaa", testDoFn.str1); + assertEquals("bbb", testDoFn.str2); + assertEquals(111, testDoFn.int1); + assertEquals(222, testDoFn.int2); + leafTransformCount++; + } catch (InvalidProtocolBufferException exc) { + throw new RuntimeException(exc); + } + } + } + assertEquals(count, leafTransformCount); + } + + @Test + public void testSchemaTransformExpansion() { + Pipeline p = Pipeline.create(); + p.apply(Impulse.create()); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + + String inputPcollId = + Iterables.getOnlyElement( + Iterables.getOnlyElement(pipelineProto.getComponents().getTransformsMap().values()) + .getOutputsMap() + .values()); + Row configRow = + Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") + .withFieldValue("int1", 111) + .withFieldValue("int2", 222) + .build(); + + ByteStringOutputStream outputStream = new ByteStringOutputStream(); + try { + SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + + ExternalTransforms.SchemaTransformPayload payload = + ExternalTransforms.SchemaTransformPayload.newBuilder() + .setIdentifier("dummy_id") + .setConfigurationRow(outputStream.toByteString()) + .setConfigurationSchema( + SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true)) + .build(); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) + .setPayload(payload.toByteString())) + .putInputs("input1", inputPcollId)) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + + assertEquals(3, expandedTransform.getSubtransformsCount()); + assertEquals(1, expandedTransform.getInputsCount()); + assertEquals(1, expandedTransform.getOutputsCount()); + verifyLeafTransforms(response, 1); + } + + @Test + public void testSchemaTransformExpansionMultiInputMultiOutput() { + Pipeline p = Pipeline.create(); + p.apply(Impulse.create()); + p.apply(Impulse.create()); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + + List inputPcollIds = new ArrayList<>(); + for (RunnerApi.PTransform transform : + pipelineProto.getComponents().getTransformsMap().values()) { + inputPcollIds.add(Iterables.getOnlyElement(transform.getOutputsMap().values())); + } + assertEquals(2, inputPcollIds.size()); + + Row configRow = + Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") + .withFieldValue("int1", 111) + .withFieldValue("int2", 222) + .build(); + + ByteStringOutputStream outputStream = new ByteStringOutputStream(); + try { + SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + + ExternalTransforms.SchemaTransformPayload payload = + ExternalTransforms.SchemaTransformPayload.newBuilder() + .setIdentifier("dummy_id_multi_input_multi_output") + .setConfigurationRow(outputStream.toByteString()) + .setConfigurationSchema( + SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true)) + .build(); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) + .setPayload(payload.toByteString())) + .putInputs("input1", inputPcollIds.get(0)) + .putInputs("input2", inputPcollIds.get(1))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + + assertEquals(6, expandedTransform.getSubtransformsCount()); + assertEquals(2, expandedTransform.getInputsCount()); + assertEquals(2, expandedTransform.getOutputsCount()); + verifyLeafTransforms(response, 2); + } +} diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py index 3b47f1ab1e40..3799af5d2e1b 100644 --- a/sdks/python/apache_beam/portability/common_urns.py +++ b/sdks/python/apache_beam/portability/common_urns.py @@ -78,6 +78,7 @@ displayData = StandardDisplayData.DisplayData java_class_lookup = ExpansionMethods.Enum.JAVA_CLASS_LOOKUP +schematransform_based_expand = ExpansionMethods.Enum.SCHEMA_TRANSFORM decimal = LogicalTypes.Enum.DECIMAL micros_instant = LogicalTypes.Enum.MICROS_INSTANT diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 58b5182593ea..7a51379a0e39 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -28,6 +28,7 @@ import logging import threading from collections import OrderedDict +from collections import namedtuple from typing import Dict import grpc @@ -104,6 +105,28 @@ def payload(self): """ return self.build().SerializeToString() + def _get_schema_proto_and_payload(self, **kwargs): + named_fields = [] + fields_to_values = OrderedDict() + + for key, value in kwargs.items(): + if not key: + raise ValueError('Parameter name cannot be empty') + if value is None: + raise ValueError( + 'Received value None for key %s. None values are currently not ' + 'supported' % key) + named_fields.append( + (key, convert_to_typing_type(instance_to_type(value)))) + fields_to_values[key] = value + + schema_proto = named_fields_to_schema(named_fields) + row = named_tuple_from_schema(schema_proto)(**fields_to_values) + schema = named_tuple_to_schema(type(row)) + + payload = RowCoder(schema).encode(row) + return (schema_proto, payload) + class SchemaBasedPayloadBuilder(PayloadBuilder): """ @@ -156,6 +179,20 @@ def _get_named_tuple_instance(self): return self._tuple_instance +class SchemaTransformPayloadBuilder(PayloadBuilder): + def __init__(self, identifier, **kwargs): + self._identifier = identifier + self._kwargs = kwargs + + def build(self): + schema_proto, payload = self._get_schema_proto_and_payload(**self._kwargs) + payload = external_transforms_pb2.SchemaTransformPayload( + identifier=self._identifier, + configuration_schema=schema_proto, + configuration_row=payload) + return payload + + class JavaClassLookupPayloadBuilder(PayloadBuilder): """ Builds a payload for directly instantiating a Java transform using a @@ -177,45 +214,26 @@ def __init__(self, class_name): self._constructor_param_kwargs = None self._builder_methods_and_params = OrderedDict() - def _get_schema_proto_and_payload(self, *args, **kwargs): - named_fields = [] - fields_to_values = OrderedDict() + def _args_to_named_fields(self, args): next_field_id = 0 + named_fields = OrderedDict() for value in args: if value is None: raise ValueError( 'Received value None. None values are currently not supported') - named_fields.append( - ((JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT % next_field_id), - convert_to_typing_type(instance_to_type(value)))) - fields_to_values[( + named_fields[( JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT % next_field_id)] = value next_field_id += 1 - for key, value in kwargs.items(): - if not key: - raise ValueError('Parameter name cannot be empty') - if value is None: - raise ValueError( - 'Received value None for key %s. None values are currently not ' - 'supported' % key) - named_fields.append( - (key, convert_to_typing_type(instance_to_type(value)))) - fields_to_values[key] = value - - schema_proto = named_fields_to_schema(named_fields) - row = named_tuple_from_schema(schema_proto)(**fields_to_values) - schema = named_tuple_to_schema(type(row)) - - payload = RowCoder(schema).encode(row) - return (schema_proto, payload) + return named_fields def build(self): - constructor_param_args = self._constructor_param_args or [] - constructor_param_kwargs = self._constructor_param_kwargs or {} + all_constructor_param_kwargs = self._args_to_named_fields( + self._constructor_param_args) + if self._constructor_param_kwargs: + all_constructor_param_kwargs.update(self._constructor_param_kwargs) constructor_schema, constructor_payload = ( - self._get_schema_proto_and_payload( - *constructor_param_args, **constructor_param_kwargs)) + self._get_schema_proto_and_payload(**all_constructor_param_kwargs)) payload = external_transforms_pb2.JavaClassLookupPayload( class_name=self._class_name, constructor_schema=constructor_schema, @@ -225,9 +243,12 @@ def build(self): for builder_method_name, params in self._builder_methods_and_params.items(): builder_method_args, builder_method_kwargs = params + all_builder_method_kwargs = self._args_to_named_fields( + builder_method_args) + if builder_method_kwargs: + all_builder_method_kwargs.update(builder_method_kwargs) builder_method_schema, builder_method_payload = ( - self._get_schema_proto_and_payload( - *builder_method_args, **builder_method_kwargs)) + self._get_schema_proto_and_payload(**all_builder_method_kwargs)) builder_method = external_transforms_pb2.BuilderMethod( name=builder_method_name, schema=builder_method_schema, @@ -289,6 +310,64 @@ def _has_constructor(self): self._constructor_param_kwargs) +# Information regarding a SchemaTransform available in an external SDK. +SchemaTransformsConfig = namedtuple( + 'SchemaTransformsConfig', + ['identifier', 'configuration_schema', 'inputs', 'outputs']) + + +class SchemaAwareExternalTransform(ptransform.PTransform): + """A proxy transform for SchemaTransforms implemented in external SDKs. + + This allows Python pipelines to directly use existing SchemaTransforms + available to the expansion service without adding additional code in external + SDKs. + + :param identifier: unique identifier of the SchemaTransform. + :param expansion_service: an expansion service to use. This should already be + available and the Schema-aware transforms to be used must already be + deployed. + :param classpath: (Optional) A list paths to additional jars to place on the + expansion service classpath. + :kwargs: field name to value mapping for configuring the schema transform. + keys map to the field names of the schema of the SchemaTransform + (in-order). + """ + def __init__(self, identifier, expansion_service, classpath=None, **kwargs): + self._expansion_service = expansion_service + self._payload_builder = SchemaTransformPayloadBuilder(identifier, **kwargs) + self._classpath = classpath + + def expand(self, pcolls): + # Expand the transform using the expansion service. + return pcolls | ExternalTransform( + common_urns.schematransform_based_expand.urn, + self._payload_builder, + self._expansion_service) + + @staticmethod + def discover(expansion_service): + """Discover all SchemaTransforms available to the given expansion service. + + :return: a list of SchemaTransformsConfigs that represent the discovered + SchemaTransforms. + """ + + with ExternalTransform.service(expansion_service) as service: + discover_response = service.DiscoverSchemaTransform( + beam_expansion_api_pb2.DiscoverSchemaTransformRequest()) + + for identifier in discover_response.schema_transform_configs: + proto_config = discover_response.schema_transform_configs[identifier] + schema = named_tuple_from_schema(proto_config.config_schema) + + yield SchemaTransformsConfig( + identifier=identifier, + configuration_schema=schema, + inputs=proto_config.input_pcollection_names, + outputs=proto_config.output_pcollection_names) + + class JavaExternalTransform(ptransform.PTransform): """A proxy for Java-implemented external transforms. @@ -520,7 +599,7 @@ def expand(self, pvalueish): transform=transform_proto, output_coder_requests=output_coders) - with self._service() as service: + with ExternalTransform.service(self._expansion_service) as service: response = service.Expand(request) if response.error: raise RuntimeError(response.error) @@ -549,9 +628,10 @@ def fix_output(pcoll, tag): return self._output_to_pvalueish(self._outputs) + @staticmethod @contextlib.contextmanager - def _service(self): - if isinstance(self._expansion_service, str): + def service(expansion_service): + if isinstance(expansion_service, str): channel_options = [("grpc.max_receive_message_length", -1), ("grpc.max_send_message_length", -1)] if hasattr(grpc, 'local_channel_credentials'): @@ -560,7 +640,7 @@ def _service(self): # TODO: update this to support secure non-local channels. channel_factory_fn = functools.partial( grpc.secure_channel, - self._expansion_service, + expansion_service, grpc.local_channel_credentials(), options=channel_options) else: @@ -568,15 +648,13 @@ def _service(self): # by older versions of grpc which may be pulled in due to other project # dependencies. channel_factory_fn = functools.partial( - grpc.insecure_channel, - self._expansion_service, - options=channel_options) + grpc.insecure_channel, expansion_service, options=channel_options) with channel_factory_fn() as channel: yield ExpansionAndArtifactRetrievalStub(channel) - elif hasattr(self._expansion_service, 'Expand'): - yield self._expansion_service + elif hasattr(expansion_service, 'Expand'): + yield expansion_service else: - with self._expansion_service as stub: + with expansion_service as stub: yield stub def _resolve_artifacts(self, components, service, dest): diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index c567f34330d8..f38876367c39 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -44,6 +44,7 @@ from apache_beam.transforms.external import JavaExternalTransform from apache_beam.transforms.external import JavaJarExpansionService from apache_beam.transforms.external import NamedTupleBasedPayloadBuilder +from apache_beam.transforms.external import SchemaTransformPayloadBuilder from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.utils import proto_utils @@ -445,6 +446,35 @@ class DataclassTransform(beam.ExternalTransform): return get_payload(DataclassTransform(**values)) +class SchemaTransformPayloadBuilderTest(unittest.TestCase): + def test_build_payload(self): + ComplexType = typing.NamedTuple( + "ComplexType", [ + ("str_sub_field", str), + ("int_sub_field", int), + ]) + + payload_builder = SchemaTransformPayloadBuilder( + identifier='dummy_id', + str_field='aaa', + int_field=123, + object_field=ComplexType(str_sub_field="bbb", int_sub_field=456)) + payload_bytes = payload_builder.payload() + payload_from_bytes = proto_utils.parse_Bytes( + payload_bytes, external_transforms_pb2.SchemaTransformPayload) + + self.assertEqual('dummy_id', payload_from_bytes.identifier) + + expected_coder = RowCoder(payload_from_bytes.configuration_schema) + schema_transform_config = expected_coder.decode( + payload_from_bytes.configuration_row) + + self.assertEqual('aaa', schema_transform_config.str_field) + self.assertEqual(123, schema_transform_config.int_field) + self.assertEqual('bbb', schema_transform_config.object_field.str_sub_field) + self.assertEqual(456, schema_transform_config.object_field.int_sub_field) + + class JavaClassLookupPayloadBuilderTest(unittest.TestCase): def _verify_row(self, schema, row_payload, expected_values): row = RowCoder(schema).decode(row_payload)