Updates ExpansionService to support dynamically discovering and expanding SchemaTransforms
chamikaramj committed Sep 29, 2022
1 parent b59df6c commit 6ec0d64
Showing 8 changed files with 963 additions and 46 deletions.
@@ -30,6 +30,7 @@ option java_package = "org.apache.beam.model.expansion.v1";
option java_outer_classname = "ExpansionApi";

import "org/apache/beam/model/pipeline/v1/beam_runner_api.proto";
import "org/apache/beam/model/pipeline/v1/schema.proto";

message ExpansionRequest {
// Set of components needed to interpret the transform, or which
@@ -72,7 +73,34 @@ message ExpansionResponse {
string error = 10;
}

message DiscoverSchemaTransformRequest {
}

message SchemaTransformConfig {
// Config schema of the SchemaTransform
org.apache.beam.model.pipeline.v1.Schema config_schema = 1;

// Names of input PCollections
repeated string input_pcollection_names = 2;

// Names of output PCollections
repeated string output_pcollection_names = 3;
}

message DiscoverSchemaTransformResponse {
// A mapping from SchemaTransform ID to schema transform config of discovered
// SchemaTransforms
map<string, SchemaTransformConfig> schema_transform_configs = 1;

// If the list of identifiers is empty, this may contain an error.
string error = 2;
}

// Job Service for constructing pipelines
service ExpansionService {
rpc Expand (ExpansionRequest) returns (ExpansionResponse);

// An RPC to discover already registered SchemaTransformProviders.
// See https://s.apache.org/easy-multi-language for more details.
rpc DiscoverSchemaTransform (DiscoverSchemaTransformRequest) returns (DiscoverSchemaTransformResponse);
}
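
For illustration only (not part of this commit): a caller could exercise the new RPC through the generated Java stub. This is a minimal sketch that assumes the Beam-vendored gRPC channel classes and an expansion service already listening on localhost:8097.

import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformRequest;
import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformResponse;
import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc;
import org.apache.beam.vendor.grpc.v1p48p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p48p1.io.grpc.ManagedChannelBuilder;

public class DiscoverSchemaTransformsExample {
  public static void main(String[] args) {
    // Assumption: an expansion service is running locally on port 8097.
    ManagedChannel channel =
        ManagedChannelBuilder.forAddress("localhost", 8097).usePlaintext().build();
    try {
      DiscoverSchemaTransformResponse response =
          ExpansionServiceGrpc.newBlockingStub(channel)
              .discoverSchemaTransform(DiscoverSchemaTransformRequest.getDefaultInstance());
      // Each entry maps a SchemaTransform identifier to its configuration schema and
      // the names of its input and output PCollections.
      response
          .getSchemaTransformConfigsMap()
          .forEach((id, config) -> System.out.println(id + " -> " + config.getConfigSchema()));
    } finally {
      channel.shutdownNow();
    }
  }
}
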
@@ -51,6 +51,11 @@ message ExpansionMethods {
// Transform payload will be of type JavaClassLookupPayload.
JAVA_CLASS_LOOKUP = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) =
"beam:expansion:payload:java_class_lookup:v1"];

// Expanding a SchemaTransform identified by the expansion service.
// Transform payload will be of type SchemaTransformPayload.
SCHEMATRANSFORM = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) =
"beam:expansion:payload:schematransform:v1"];
}
}

@@ -106,4 +111,16 @@ message BuilderMethod {
bytes payload = 3;
}

message SchemaTransformPayload {
// The identifier of the SchemaTransform (typically a URN).
string identifier = 1;

// The configuration schema of the SchemaTransform.
Schema configuration_schema = 2;

// The configuration of the SchemaTransform.
// Should be decodable via beam:coder:row:v1.
// The schema of the Row should be compatible with the schema of the
// SchemaTransform.
bytes configuration_row = 3;
}
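
As a sketch of how a SchemaTransformPayload could be assembled by an SDK (the identifier, configuration schema, and field values below are hypothetical; SchemaTranslation.schemaToProto and RowCoder are the pieces this commit relies on):

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;

public class SchemaTransformPayloadExample {
  public static SchemaTransformPayload buildPayload() throws IOException {
    // Hypothetical configuration schema for some SchemaTransform.
    Schema configSchema =
        Schema.builder().addStringField("table").addInt64Field("batchSize").build();
    Row configRow =
        Row.withSchema(configSchema)
            .withFieldValue("table", "my_table")
            .withFieldValue("batchSize", 100L)
            .build();

    // Encode the configuration Row with beam:coder:row:v1, as required by configuration_row.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    RowCoder.of(configSchema).encode(configRow, bytes);

    return SchemaTransformPayload.newBuilder()
        // Hypothetical URN of the target SchemaTransform.
        .setIdentifier("beam:schematransform:org.example:some_transform:v1")
        .setConfigurationSchema(SchemaTranslation.schemaToProto(configSchema, true))
        .setConfigurationRow(ByteString.copyFrom(bytes.toByteArray()))
        .build();
  }
}
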
@@ -35,6 +35,9 @@
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.model.expansion.v1.ExpansionApi;
import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformRequest;
import org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformResponse;
import org.apache.beam.model.expansion.v1.ExpansionApi.SchemaTransformConfig;
import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExternalConfigurationPayload;
@@ -63,6 +66,7 @@
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.transforms.ExternalTransformBuilder;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -436,6 +440,10 @@ private Map<String, TransformProvider> getRegisteredTransforms() {
return registeredTransforms;
}

private Iterable<SchemaTransformProvider> getRegisteredSchemaTransforms() {
return ExpansionServiceSchemaTransformProvider.of().getAllProviders();
}

private Map<String, TransformProvider> loadRegisteredTransforms() {
ImmutableMap.Builder<String, TransformProvider> registeredTransformsBuilder =
ImmutableMap.builder();
@@ -500,6 +508,8 @@ private Map<String, TransformProvider> loadRegisteredTransforms() {
pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist();
assert allowList != null;
transformProvider = new JavaClassLookupTransformProvider(allowList);
} else if (getUrn(ExpansionMethods.Enum.SCHEMATRANSFORM).equals(urn)) {
transformProvider = ExpansionServiceSchemaTransformProvider.of();
} else {
transformProvider = getRegisteredTransforms().get(urn);
if (transformProvider == null) {
@@ -604,6 +614,42 @@ public void expand(
}
}

DiscoverSchemaTransformResponse discover(DiscoverSchemaTransformRequest request) {
ExpansionServiceSchemaTransformProvider transformProvider =
ExpansionServiceSchemaTransformProvider.of();
DiscoverSchemaTransformResponse.Builder responseBuilder =
DiscoverSchemaTransformResponse.newBuilder();
for (org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider provider :
transformProvider.getAllProviders()) {
SchemaTransformConfig.Builder schemaTransformConfigBuilder =
SchemaTransformConfig.newBuilder();
schemaTransformConfigBuilder.setConfigSchema(
SchemaTranslation.schemaToProto(provider.configurationSchema(), true));
schemaTransformConfigBuilder.addAllInputPcollectionNames(provider.inputCollectionNames());
schemaTransformConfigBuilder.addAllOutputPcollectionNames(provider.outputCollectionNames());
responseBuilder.putSchemaTransformConfigs(
provider.identifier(), schemaTransformConfigBuilder.build());
}

return responseBuilder.build();
}

@Override
public void discoverSchemaTransform(
DiscoverSchemaTransformRequest request,
StreamObserver<DiscoverSchemaTransformResponse> responseObserver) {
try {
responseObserver.onNext(discover(request));
responseObserver.onCompleted();
} catch (RuntimeException exn) {
responseObserver.onNext(
ExpansionApi.DiscoverSchemaTransformResponse.newBuilder()
.setError(Throwables.getStackTraceAsString(exn))
.build());
responseObserver.onCompleted();
}
}

@Override
public void close() throws Exception {
// Nothing to do because the expansion service is stateless.
@@ -618,9 +664,36 @@ public static void main(String[] args) throws Exception {

@SuppressWarnings("nullness")
ExpansionService service = new ExpansionService(Arrays.copyOfRange(args, 1, args.length));

StringBuilder registeredTransformsLog = new StringBuilder();
boolean registeredTransformsFound = false;
registeredTransformsLog.append("\n");
registeredTransformsLog.append("Registered transforms:");

for (Map.Entry<String, TransformProvider> entry :
service.getRegisteredTransforms().entrySet()) {
System.out.println("\t" + entry.getKey() + ": " + entry.getValue());
registeredTransformsFound = true;
registeredTransformsLog.append("\n\t" + entry.getKey() + ": " + entry.getValue());
}

StringBuilder registeredSchemaTransformProvidersLog = new StringBuilder();
boolean registeredSchemaTransformProvidersFound = false;
registeredSchemaTransformProvidersLog.append("\n");
registeredSchemaTransformProvidersLog.append("Registered SchemaTransformProviders:");

for (SchemaTransformProvider provider : service.getRegisteredSchemaTransforms()) {
registeredSchemaTransformProvidersFound = true;
registeredSchemaTransformProvidersLog.append("\n\t" + provider.identifier());
}

if (registeredTransformsFound) {
System.out.println(registeredTransformsLog.toString());
}
if (registeredSchemaTransformProvidersFound) {
System.out.println(registeredSchemaTransformProvidersLog.toString());
}
if (!registeredTransformsFound && !registeredSchemaTransformProvidersFound) {
System.out.println("\nDid not find any registered transforms or SchemaTransforms.\n");
}

Server server =
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.expansion.service;

import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.ServiceLoader;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.InvalidProtocolBufferException;

@SuppressWarnings({"rawtypes"})
public class ExpansionServiceSchemaTransformProvider implements TransformProvider {

static final String DEFAULT_INPUT_TAG = "INPUT";

private Map<String, org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider>
schemaTransformProviders = new HashMap<>();
private static ExpansionServiceSchemaTransformProvider transformProvider = null;

private void loadSchemaTransforms() {
try {
for (org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider schemaTransformProvider :
ServiceLoader.load(
org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider.class)) {
if (schemaTransformProviders.containsKey(schemaTransformProvider.identifier())) {
throw new IllegalArgumentException(
"Found multiple SchemaTransformProvider implementations with the same identifier "
+ schemaTransformProvider.identifier());
}
schemaTransformProviders.put(schemaTransformProvider.identifier(), schemaTransformProvider);
}
} catch (Exception e) {
// Preserve the original exception as the cause rather than only its message.
throw new RuntimeException(e);
}
}

private ExpansionServiceSchemaTransformProvider() {
loadSchemaTransforms();
}

public static ExpansionServiceSchemaTransformProvider of() {
if (transformProvider == null) {
transformProvider = new ExpansionServiceSchemaTransformProvider();
}

return transformProvider;
}

static class RowTransform extends PTransform {

private PTransform<PCollectionRowTuple, PCollectionRowTuple> rowTuplePTransform;

public RowTransform(PTransform<PCollectionRowTuple, PCollectionRowTuple> rowTuplePTransform) {
this.rowTuplePTransform = rowTuplePTransform;
}

@Override
public POutput expand(PInput input) {
PCollectionRowTuple inputRowTuple;

if (input instanceof PCollectionRowTuple) {
inputRowTuple = (PCollectionRowTuple) input;
} else if (input instanceof PCollection) {
inputRowTuple = PCollectionRowTuple.of(DEFAULT_INPUT_TAG, (PCollection) input);
} else if (input instanceof PBegin) {
inputRowTuple = PCollectionRowTuple.empty(input.getPipeline());
} else if (input instanceof PCollectionTuple) {
inputRowTuple = PCollectionRowTuple.empty(input.getPipeline());
PCollectionTuple inputTuple = (PCollectionTuple) input;
for (TupleTag<?> tag : inputTuple.getAll().keySet()) {
inputRowTuple = inputRowTuple.and(tag.getId(), (PCollection<Row>) inputTuple.get(tag));
}
} else {
throw new RuntimeException(String.format("Unsupported input type: %s", input));
}
PCollectionRowTuple output = inputRowTuple.apply(this.rowTuplePTransform);

if (output.getAll().size() > 1) {
PCollectionTuple pcTuple = PCollectionTuple.empty(input.getPipeline());
for (String key : output.getAll().keySet()) {
pcTuple = pcTuple.and(key, output.get(key));
}
return pcTuple;
} else if (output.getAll().size() == 1) {
return output.getAll().values().iterator().next();
} else {
return PDone.in(input.getPipeline());
}
}
}

@Override
public PTransform getTransform(FunctionSpec spec) {
SchemaTransformPayload payload;
try {
payload = SchemaTransformPayload.parseFrom(spec.getPayload());
String identifier = payload.getIdentifier();
if (!schemaTransformProviders.containsKey(identifier)) {
throw new RuntimeException(
"Did not find a SchemaTransformProvider with the identifier " + identifier);
}

} catch (InvalidProtocolBufferException e) {
throw new IllegalArgumentException(
"Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.SCHEMATRANSFORM), e);
}

String identifier = payload.getIdentifier();
org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider provider =
schemaTransformProviders.get(identifier);

Schema configSchemaFromRequest =
SchemaTranslation.schemaFromProto((payload.getConfigurationSchema()));
Schema configSchemaFromProvider = provider.configurationSchema();

if (!configSchemaFromRequest.assignableTo(configSchemaFromProvider)) {
throw new IllegalArgumentException(
String.format(
"Config schema provided with the expansion request %s is not compatible with the "
+ "config of the Schema transform %s.",
configSchemaFromRequest, configSchemaFromProvider));
}

Row configRow;
try {
configRow =
RowCoder.of(provider.configurationSchema())
.decode(payload.getConfigurationRow().newInput());
} catch (IOException e) {
throw new RuntimeException("Error decoding payload", e);
}

return new RowTransform(provider.from(configRow).buildTransform());
}

Iterable<org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider> getAllProviders() {
return schemaTransformProviders.values();
}
}
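
For context (not part of this commit): a provider picked up by the ServiceLoader lookup above might look roughly like the following hypothetical example. It assumes the SchemaTransformProvider and SchemaTransform interfaces from org.apache.beam.sdk.schemas.transforms and AutoService for ServiceLoader registration; the identifier, schema, and behavior are made up for illustration.

import com.google.auto.service.AutoService;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;

// Registered via ServiceLoader so the expansion service can discover it at startup.
@AutoService(SchemaTransformProvider.class)
public class PassThroughSchemaTransformProvider implements SchemaTransformProvider {

  // Hypothetical configuration schema; real providers define their own.
  private static final Schema CONFIG_SCHEMA = Schema.builder().addStringField("comment").build();

  @Override
  public String identifier() {
    // Hypothetical URN; providers choose a stable, unique identifier.
    return "beam:schematransform:org.example:pass_through:v1";
  }

  @Override
  public Schema configurationSchema() {
    return CONFIG_SCHEMA;
  }

  @Override
  public SchemaTransform from(Row configuration) {
    return new SchemaTransform() {
      @Override
      public PTransform<PCollectionRowTuple, PCollectionRowTuple> buildTransform() {
        return new PTransform<PCollectionRowTuple, PCollectionRowTuple>() {
          @Override
          public PCollectionRowTuple expand(PCollectionRowTuple input) {
            // A real transform would act on the rows; this sketch forwards them unchanged.
            return PCollectionRowTuple.of("OUTPUT", input.get("INPUT"));
          }
        };
      }
    };
  }

  @Override
  public List<String> inputCollectionNames() {
    return Collections.singletonList("INPUT");
  }

  @Override
  public List<String> outputCollectionNames() {
    return Collections.singletonList("OUTPUT");
  }
}

When a provider like this is on the expansion service's classpath, DiscoverSchemaTransform reports its identifier, configuration schema, and collection names, and Expand accepts a SchemaTransformPayload carrying that identifier.
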
