diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index d1b4f65ce826..2797933d8474 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -21,6 +21,7 @@
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
+import com.google.common.base.Objects;
import java.io.Serializable;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
@@ -620,6 +621,19 @@ public ProcessContinuation withResumeDelay(Duration resumeDelay) {
public ProcessContinuation withFutureOutputWatermark(Instant futureOutputWatermark) {
return new ProcessContinuation(resumeDelay, futureOutputWatermark);
}
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ } else if (!(obj instanceof ProcessContinuation)) {
+ return false;
+ } else {
+ ProcessContinuation that = (ProcessContinuation) obj;
+ return Objects.equal(this.resumeDelay, that.resumeDelay)
+ && Objects.equal(this.futureOutputWatermark, that.futureOutputWatermark);
+ }
+ }
}
/**
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java
index b4d6d0e31be9..12dc7fc6135a 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java
@@ -20,6 +20,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.reflect.TypeToken;
+import java.io.FileOutputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
@@ -29,6 +30,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import javax.annotation.Nullable;
import net.bytebuddy.ByteBuddy;
import net.bytebuddy.NamingStrategy;
import net.bytebuddy.description.field.FieldDescription;
@@ -45,9 +47,7 @@
import net.bytebuddy.implementation.FixedValue;
import net.bytebuddy.implementation.Implementation;
import net.bytebuddy.implementation.Implementation.Context;
-import net.bytebuddy.implementation.MethodCall;
import net.bytebuddy.implementation.MethodDelegation;
-import net.bytebuddy.implementation.bind.MethodDelegationBinder;
import net.bytebuddy.implementation.bind.annotation.TargetMethodAnnotationDrivenBinder;
import net.bytebuddy.implementation.bytecode.ByteCodeAppender;
import net.bytebuddy.implementation.bytecode.StackManipulation;
@@ -61,6 +61,7 @@
import net.bytebuddy.jar.asm.Label;
import net.bytebuddy.jar.asm.MethodVisitor;
import net.bytebuddy.jar.asm.Opcodes;
+import net.bytebuddy.jar.asm.Type;
import net.bytebuddy.matcher.ElementMatchers;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
@@ -223,6 +224,13 @@ public String subclass(TypeDescription.Generic superClass) {
.intercept(delegateWithDowncastOrThrow(signature.newTracker()));
DynamicType.Unloaded> unloaded = builder.make();
+ try {
+ try (FileOutputStream w = new FileOutputStream("/tmp/foo.class")) {
+ w.write(unloaded.getBytes());
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
@SuppressWarnings("unchecked")
Class extends DoFnInvoker, ?>> res =
@@ -249,7 +257,16 @@ private static Implementation delegateWithDowncastOrThrow(DoFnSignature.DoFnMeth
/** Implements an invoker method by delegating to a method of the target {@link DoFn}. */
private abstract static class DoFnMethodDelegation implements Implementation {
- FieldDescription delegateField;
+
+ protected final MethodDescription targetMethod;
+ private FieldDescription delegateField;
+
+ private final boolean targetHasReturn;
+
+ public DoFnMethodDelegation(MethodDescription targetMethod) {
+ this.targetMethod = targetMethod;
+ targetHasReturn = !TypeDescription.VOID.equals(targetMethod.getReturnType().asErasure());
+ }
@Override
public InstrumentedType prepare(InstrumentedType instrumentedType) {
@@ -259,7 +276,6 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) {
.getDeclaredFields()
.filter(ElementMatchers.named(FN_DELEGATE_FIELD_NAME))
.getOnly();
-
// Delegating the method call doesn't require any changes to the instrumented type.
return instrumentedType;
}
@@ -272,25 +288,64 @@ public Size apply(
MethodVisitor methodVisitor,
Context implementationContext,
MethodDescription instrumentedMethod) {
- StackManipulation manipulation =
- new StackManipulation.Compound(
- // Push "this" reference to the stack
- MethodVariableAccess.REFERENCE.loadOffset(0),
- // Access the delegate field of the the invoker
- FieldAccess.forField(delegateField).getter(),
- invokeTargetMethod(instrumentedMethod));
+ // Figure out how many locals we'll need. This corresponds to "this", the parameters
+ // of the instrumented method, and an argument to hold the return value if the target
+ // method has a return value.
+ int numLocals = 1 + instrumentedMethod.getParameters().size() + (targetHasReturn ? 1 : 0);
+
+ Integer returnVarIndex = null;
+ if (targetHasReturn) {
+ // Local comes after formal parameters, so figure out where that is.
+ returnVarIndex = 1; // "this"
+ for (Type param : Type.getArgumentTypes(instrumentedMethod.getDescriptor())) {
+ returnVarIndex += param.getSize();
+ }
+ }
+
+ StackManipulation manipulation = new StackManipulation.Compound(
+ // Push "this" (DoFnInvoker on top of the stack)
+ MethodVariableAccess.REFERENCE.loadOffset(0),
+ // Access this.delegate (DoFn on top of the stack)
+ FieldAccess.forField(delegateField).getter(),
+ // Run the beforeDelegation manipulations.
+ // The arguments necessary to invoke the target are on top of the stack.
+ beforeDelegation(instrumentedMethod),
+ // Perform the method delegation.
+ // This will consume the arguments on top of the stack
+ // Either the stack is now empty (because the targetMethod returns void) or the
+ // stack contains the return value.
+ new UserCodeMethodInvocation(returnVarIndex, targetMethod, instrumentedMethod),
+ // Run the afterDelegation manipulations.
+ // Either the stack is now empty (because the instrumentedMethod returns void)
+ // or the stack contains the return value.
+ afterDelegation(instrumentedMethod));
+
StackManipulation.Size size = manipulation.apply(methodVisitor, implementationContext);
- return new Size(size.getMaximalSize(), instrumentedMethod.getStackSize());
+ return new Size(size.getMaximalSize(), numLocals);
}
};
}
/**
- * Generates code to invoke the target method. When this is called the delegate field will be on
- * top of the stack. This should add any necessary arguments to the stack and then perform the
- * method invocation.
+ * Return the code to the prepare the operand stack for the method delegation.
+ *
+ *
Before this method is called, the stack delegate will be the only thing on the stack.
+ *
+ *
After this method is called, the stack contents should contain exactly the arguments
+ * necessary to invoke the target method.
*/
- protected abstract StackManipulation invokeTargetMethod(MethodDescription instrumentedMethod);
+ protected abstract StackManipulation beforeDelegation(MethodDescription instrumentedMethod);
+
+ /**
+ * Return the code to execute after the method delegation.
+ *
+ *
Before this method is called, the stack will either be empty (if the target method
+ * returns void) or contain the method return value.
+ *
+ *
After this method is called, the stack should either be empty (if the instrumented method
+ * returns void) or contain the avlue for the instrumented method to return).
+ */
+ protected abstract StackManipulation afterDelegation(MethodDescription instrumentedMethod);
}
/**
@@ -339,16 +394,13 @@ private static final class ProcessElementDelegation extends DoFnMethodDelegation
/** Implementation of {@link MethodDelegation} for the {@link ProcessElement} method. */
private ProcessElementDelegation(DoFnSignature.ProcessElementMethod signature) {
+ super(new MethodDescription.ForLoadedMethod(signature.targetMethod()));
this.signature = signature;
}
@Override
- protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMethod) {
- MethodDescription targetMethod =
- new MethodCall.MethodLocator.ForExplicitMethod(
- new MethodDescription.ForLoadedMethod(signature.targetMethod()))
- .resolve(instrumentedMethod);
-
+ protected StackManipulation beforeDelegation(
+ MethodDescription instrumentedMethod) {
// Parameters of the wrapper invoker method:
// DoFn.ProcessContext, ExtraContextFactory.
// Parameters of the wrapped DoFn method:
@@ -368,26 +420,21 @@ protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMet
// Insert a downcast.
(param == DoFnSignature.Parameter.RESTRICTION_TRACKER)
? TypeCasting.to(
- new TypeDescription.ForLoadedType(signature.trackerT().getRawType()))
+ new TypeDescription.ForLoadedType(signature.trackerT().getRawType()))
: StackManipulation.Trivial.INSTANCE));
}
+ return new StackManipulation.Compound(parameters);
+ }
- return new StackManipulation.Compound(
- // Push the parameters
- new StackManipulation.Compound(parameters),
- // Invoke the target method
- wrapWithUserCodeException(
- MethodDelegationBinder.MethodInvoker.Simple.INSTANCE.invoke(targetMethod),
- targetMethod.getReturnType().asErasure(),
- instrumentedMethod),
- // Return from the instrumented @ProcessElement method:
- // if it returns void, then return null (meaning don't resume),
- // otherwise return the ProcessContinuation it returned.
- signature.hasReturnValue()
- ? MethodReturn.returning(targetMethod.getReturnType().asErasure())
- : new StackManipulation.Compound(
- MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD),
- MethodReturn.REFERENCE));
+ @Override
+ protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) {
+ if (TypeDescription.VOID.equals(targetMethod.getReturnType())) {
+ return new StackManipulation.Compound(
+ MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD),
+ MethodReturn.REFERENCE);
+ } else {
+ return MethodReturn.returning(targetMethod.getReturnType().asErasure());
+ }
}
}
@@ -396,33 +443,20 @@ protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMet
* downcasting parameters to the proper type.
*/
private static class SimpleMethodDelegation extends DoFnMethodDelegation {
- private final Method method;
- SimpleMethodDelegation(Method method) {
- this.method = method;
+ protected SimpleMethodDelegation(Method method) {
+ super(new MethodDescription.ForLoadedMethod(method));
}
@Override
- protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMethod) {
- MethodDescription targetMethod =
- new MethodCall.MethodLocator.ForExplicitMethod(
- new MethodDescription.ForLoadedMethod(method))
- .resolve(instrumentedMethod);
- return new StackManipulation.Compound(
- pushArgumentsOfInstrumentedMethod(targetMethod),
- // Invoke the target method
- wrapWithUserCodeException(
- MethodDelegationBinder.MethodInvoker.Simple.INSTANCE.invoke(targetMethod),
- targetMethod.getReturnType().asErasure(),
- instrumentedMethod),
- new StackManipulation.Compound(
- // Return from the instrumented method
- TargetMethodAnnotationDrivenBinder.TerminationHandler.Returning.INSTANCE.resolve(
- Assigner.DEFAULT, instrumentedMethod, targetMethod)));
+ protected StackManipulation beforeDelegation(MethodDescription instrumentedMethod) {
+ return MethodVariableAccess.allArgumentsOf(targetMethod);
}
- protected StackManipulation pushArgumentsOfInstrumentedMethod(MethodDescription targetMethod) {
- return MethodVariableAccess.allArgumentsOf(targetMethod);
+ @Override
+ protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) {
+ return TargetMethodAnnotationDrivenBinder.TerminationHandler.Returning.INSTANCE
+ .resolve(Assigner.DEFAULT, instrumentedMethod, targetMethod);
}
}
@@ -436,7 +470,7 @@ private static class DowncastingParametersMethodDelegation extends SimpleMethodD
}
@Override
- protected StackManipulation pushArgumentsOfInstrumentedMethod(MethodDescription targetMethod) {
+ protected StackManipulation beforeDelegation(MethodDescription instrumentedMethod) {
List pushParameters = new ArrayList<>();
TypeList.Generic paramTypes = targetMethod.getParameters().asTypeList();
for (int i = 0; i < paramTypes.size(); i++) {
@@ -450,15 +484,36 @@ protected StackManipulation pushArgumentsOfInstrumentedMethod(MethodDescription
}
}
- private abstract static class BaseWrapUserCodeException implements StackManipulation {
+ private static class UserCodeMethodInvocation implements StackManipulation {
+
+ @Nullable
+ private final Integer returnVarIndex;
+ private final MethodDescription targetMethod;
+ private final MethodDescription instrumentedMethod;
+ private final TypeDescription returnType;
+
+ private final Label wrapStart = new Label();
+ private final Label wrapEnd = new Label();
+ private final Label tryBlockStart = new Label();
+ private final Label tryBlockEnd = new Label();
+ private final Label catchBlockStart = new Label();
+ private final Label catchBlockEnd = new Label();
- /** {@link MethodDescription} for {@link UserCodeException#wrap}. */
private final MethodDescription createUserCodeException;
- private final StackManipulation tryBody;
+ public UserCodeMethodInvocation(
+ Integer returnVarIndex,
+ MethodDescription targetMethod,
+ MethodDescription instrumentedMethod) {
+ this.returnVarIndex = returnVarIndex;
+ this.targetMethod = targetMethod;
+ this.instrumentedMethod = instrumentedMethod;
+ this.returnType = targetMethod.getReturnType().asErasure();
+
+ boolean targetMethodReturnsVoid = TypeDescription.VOID.equals(returnType);
+ checkArgument((returnVarIndex == null) == targetMethodReturnsVoid,
+ "returnLocalIndex should be defined if and only if the target method has a return value");
- BaseWrapUserCodeException(StackManipulation tryBody) {
- this.tryBody = tryBody;
try {
createUserCodeException =
new MethodDescription.ForLoadedMethod(
@@ -470,23 +525,52 @@ private abstract static class BaseWrapUserCodeException implements StackManipula
@Override
public boolean isValid() {
- return tryBody.isValid();
+ return true;
}
- void visitFrameWithThrowableOnStack(MethodVisitor mv) {
- String throwableName = new TypeDescription.ForLoadedType(Throwable.class).getInternalName();
- mv.visitFrame(Opcodes.F_SAME1, 0, new Object[] {}, 1, new Object[] {throwableName});
+ private Object describeType(Type type) {
+ switch (type.getSort()) {
+ case Type.OBJECT:
+ return type.getDescriptor();
+ case Type.INT:
+ case Type.BYTE:
+ case Type.BOOLEAN:
+ case Type.SHORT:
+ return Opcodes.INTEGER;
+ case Type.LONG:
+ return Opcodes.LONG;
+ case Type.DOUBLE:
+ return Opcodes.DOUBLE;
+ case Type.FLOAT:
+ return Opcodes.FLOAT;
+ default:
+ throw new IllegalArgumentException("Unhandled type as method argument: " + type);
+ }
}
- protected abstract void visitFrameWithReturnOnStack(MethodVisitor mv);
+ private void visitFrame(MethodVisitor mv,
+ boolean localsIncludeReturn,
+ @Nullable String stackTop) {
+ boolean hasReturnLocal = (returnVarIndex != null) && localsIncludeReturn;
+
+ Type[] localTypes = Type.getArgumentTypes(instrumentedMethod.getDescriptor());
+ Object[] locals = new Object[1 + localTypes.length + (hasReturnLocal ? 1 : 0)];
+ locals[0] = instrumentedMethod.getReceiverType().asErasure().getInternalName();
+ for (int i = 0; i < localTypes.length; i++) {
+ locals[i + 1] = describeType(localTypes[i]);
+ }
+ if (hasReturnLocal) {
+ locals[locals.length - 1] = returnType.getInternalName();
+ }
+
+ Object[] stack = stackTop == null ? new Object[] {} : new Object[] { stackTop };
+
+ mv.visitFrame(Opcodes.F_NEW, locals.length, locals, stack.length, stack);
+ }
@Override
- public Size apply(MethodVisitor mv, Context implementationContext) {
- Label wrapStart = new Label();
- Label tryBlockStart = new Label();
- Label tryBlockEnd = new Label();
- Label catchBlockStart = new Label();
- Label catchBlockEnd = new Label();
+ public Size apply(MethodVisitor mv, Context context) {
+ Size size = new Size(0, 0);
mv.visitLabel(wrapStart);
@@ -495,55 +579,48 @@ public Size apply(MethodVisitor mv, Context implementationContext) {
// The try block attempts to perform the expected operations, then jumps to success
mv.visitLabel(tryBlockStart);
- Size trySize = tryBody.apply(mv, implementationContext);
+ size = size.aggregate(MethodInvocation.invoke(targetMethod).apply(mv, context));
- // After try body, should have same locals and the return type on the stack.
- visitFrameWithReturnOnStack(mv);
+ if (returnVarIndex != null) {
+ mv.visitVarInsn(Opcodes.ASTORE, returnVarIndex);
+ size = size.aggregate(new Size(-1, 0)); // Reduces the size of the stack
+ }
mv.visitJumpInsn(Opcodes.GOTO, catchBlockEnd);
mv.visitLabel(tryBlockEnd);
// The handler wraps the exception, and then throws.
mv.visitLabel(catchBlockStart);
// In catch block, should have same locals and {Throwable} on the stack.
- visitFrameWithThrowableOnStack(mv);
+ visitFrame(mv, false, throwableName);
- Size catchSize =
+ // Create the user code exception and throw
+ size = size.aggregate(
new Compound(MethodInvocation.invoke(createUserCodeException), Throw.INSTANCE)
- .apply(mv, implementationContext);
+ .apply(mv, context));
mv.visitLabel(catchBlockEnd);
- // After catch block, should have same locals and the return type on the stack
- visitFrameWithReturnOnStack(mv);
- return new Size(
- trySize.getSizeImpact() /* Same total size impact as wrapped body */,
- Math.max(trySize.getMaximalSize(), catchSize.getMaximalSize()));
- }
- }
+ // After the catch block we should have the return in scope, but nothing on the stack.
+ visitFrame(mv, true, null);
- /**
- * Wraps a given stack manipulation in a try catch block. Any exceptions thrown within the try are
- * wrapped with a {@link UserCodeException}.
- */
- private static StackManipulation wrapWithUserCodeException(
- StackManipulation tryBody,
- final TypeDescription returnType,
- MethodDescription instrumentedMethod) {
- if (TypeDescription.VOID.equals(returnType)) {
- return new BaseWrapUserCodeException(tryBody) {
- @Override
- protected void visitFrameWithReturnOnStack(MethodVisitor mv) {
- mv.visitFrame(Opcodes.F_SAME, 0, new Object[] {}, 0, new Object[] {});
- }
- };
- } else {
- return new BaseWrapUserCodeException(tryBody) {
- @Override
- protected void visitFrameWithReturnOnStack(MethodVisitor mv) {
- mv.visitFrame(
- Opcodes.F_SAME1, 0, new Object[] {}, 1, new Object[] {returnType.getInternalName()});
- }
- };
+ // After catch block, should have same locals and will have the return on the stack.
+ if (returnVarIndex != null) {
+ mv.visitVarInsn(Opcodes.ALOAD, returnVarIndex);
+ size = size.aggregate(new Size(1, 0)); // Increases the size of the stack
+ }
+ mv.visitLabel(wrapEnd);
+ if (returnVarIndex != null) {
+ // Drop the return type from the locals
+ mv.visitLocalVariable(
+ "res",
+ returnType.getDescriptor(),
+ returnType.getGenericSignature(),
+ wrapStart,
+ wrapEnd,
+ returnVarIndex);
+ }
+
+ return size;
}
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
index 20a887aa458f..da44a869420d 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
@@ -17,11 +17,13 @@
*/
package org.apache.beam.sdk.transforms.reflect;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.UserCodeException;
@@ -87,7 +89,8 @@ public RestrictionTracker restrictionTracker() {
};
}
- private void checkInvokeProcessElementWorks(DoFn fn, Invocations... invocations)
+ private void checkInvokeProcessElementWorks(
+ DoFn fn, ProcessContinuation expected, Invocations... invocations)
throws Exception {
assertTrue("Need at least one invocation to check", invocations.length >= 1);
for (Invocations invocation : invocations) {
@@ -95,9 +98,10 @@ private void checkInvokeProcessElementWorks(DoFn fn, Invocations
"Should not yet have called processElement on " + invocation.name,
invocation.wasProcessElementInvoked);
}
- DoFnInvokers.INSTANCE
+ ProcessContinuation actual = DoFnInvokers.INSTANCE
.newByteBuddyInvoker(fn)
.invokeProcessElement(mockContext, extraContextFactory);
+ assertEquals("Should return proper continuation", expected, actual);
for (Invocations invocation : invocations) {
assertTrue(
"Should have called processElement on " + invocation.name,
@@ -182,7 +186,7 @@ public void processElement(ProcessContext c) throws Exception {
.processElement()
.usesSingleWindow());
- checkInvokeProcessElementWorks(fn, invocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations);
}
@Test
@@ -223,7 +227,7 @@ public void testDoFnWithProcessElementInterface() throws Exception {
.getOrParseSignature(fn.getClass())
.processElement()
.usesSingleWindow());
- checkInvokeProcessElementWorks(fn, fn.invocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), fn.invocations);
}
private class IdentityParent extends DoFn {
@@ -256,7 +260,7 @@ public void testDoFnWithMethodInSuperclass() throws Exception {
.getOrParseSignature(fn.getClass())
.processElement()
.usesSingleWindow());
- checkInvokeProcessElementWorks(fn, fn.parentInvocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), fn.parentInvocations);
}
@Test
@@ -267,7 +271,8 @@ public void testDoFnWithMethodInSubclass() throws Exception {
.getOrParseSignature(fn.getClass())
.processElement()
.usesSingleWindow());
- checkInvokeProcessElementWorks(fn, fn.parentInvocations, fn.childInvocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(),
+ fn.parentInvocations, fn.childInvocations);
}
@Test
@@ -289,7 +294,7 @@ public void processElement(ProcessContext c, BoundedWindow w) throws Exception {
.processElement()
.usesSingleWindow());
- checkInvokeProcessElementWorks(fn, invocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations);
}
@Test
@@ -311,7 +316,7 @@ public void processElement(ProcessContext c, OutputReceiver o) throws Ex
.processElement()
.usesSingleWindow());
- checkInvokeProcessElementWorks(fn, invocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations);
}
@Test
@@ -333,7 +338,31 @@ public void processElement(ProcessContext c, InputProvider i) throws Exc
.processElement()
.usesSingleWindow());
- checkInvokeProcessElementWorks(fn, invocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations);
+ }
+
+ @Test
+ public void testDoFnWithReturn() throws Exception {
+ final Invocations invocations = new Invocations("AnonymousClass");
+ DoFn fn =
+ new DoFn() {
+ @ProcessElement
+ public ProcessContinuation processElement(
+ ProcessContext c, InputProvider i) throws Exception {
+ invocations.wasProcessElementInvoked = true;
+ assertSame(c, mockContext);
+ assertSame(i, mockInputProvider);
+ return ProcessContinuation.resume();
+ }
+ };
+
+ assertFalse(
+ DoFnSignatures.INSTANCE
+ .getOrParseSignature(fn.getClass())
+ .processElement()
+ .usesSingleWindow());
+
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.resume(), invocations);
}
@Test
@@ -408,49 +437,61 @@ public void processThis(ProcessContext c) {
@Test
public void testLocalPrivateDoFnClass() throws Exception {
PrivateDoFnClass fn = new PrivateDoFnClass();
- checkInvokeProcessElementWorks(fn, fn.invocations);
+ checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), fn.invocations);
}
@Test
public void testStaticPackagePrivateDoFnClass() throws Exception {
Invocations invocations = new Invocations("StaticPackagePrivateDoFn");
checkInvokeProcessElementWorks(
- DoFnInvokersTestHelper.newStaticPackagePrivateDoFn(invocations), invocations);
+ DoFnInvokersTestHelper.newStaticPackagePrivateDoFn(invocations),
+ ProcessContinuation.stop(),
+ invocations);
}
@Test
public void testInnerPackagePrivateDoFnClass() throws Exception {
Invocations invocations = new Invocations("InnerPackagePrivateDoFn");
checkInvokeProcessElementWorks(
- new DoFnInvokersTestHelper().newInnerPackagePrivateDoFn(invocations), invocations);
+ new DoFnInvokersTestHelper().newInnerPackagePrivateDoFn(invocations),
+ ProcessContinuation.stop(),
+ invocations);
}
@Test
public void testStaticPrivateDoFnClass() throws Exception {
Invocations invocations = new Invocations("StaticPrivateDoFn");
checkInvokeProcessElementWorks(
- DoFnInvokersTestHelper.newStaticPrivateDoFn(invocations), invocations);
+ DoFnInvokersTestHelper.newStaticPrivateDoFn(invocations),
+ ProcessContinuation.stop(),
+ invocations);
}
@Test
public void testInnerPrivateDoFnClass() throws Exception {
Invocations invocations = new Invocations("StaticInnerDoFn");
checkInvokeProcessElementWorks(
- new DoFnInvokersTestHelper().newInnerPrivateDoFn(invocations), invocations);
+ new DoFnInvokersTestHelper().newInnerPrivateDoFn(invocations),
+ ProcessContinuation.stop(),
+ invocations);
}
@Test
public void testAnonymousInnerDoFnInOtherPackage() throws Exception {
Invocations invocations = new Invocations("AnonymousInnerDoFnInOtherPackage");
checkInvokeProcessElementWorks(
- new DoFnInvokersTestHelper().newInnerAnonymousDoFn(invocations), invocations);
+ new DoFnInvokersTestHelper().newInnerAnonymousDoFn(invocations),
+ ProcessContinuation.stop(),
+ invocations);
}
@Test
public void testStaticAnonymousDoFnInOtherPackage() throws Exception {
Invocations invocations = new Invocations("AnonymousStaticDoFnInOtherPackage");
checkInvokeProcessElementWorks(
- DoFnInvokersTestHelper.newStaticAnonymousDoFn(invocations), invocations);
+ DoFnInvokersTestHelper.newStaticAnonymousDoFn(invocations),
+ ProcessContinuation.stop(),
+ invocations);
}
@Test
@@ -468,6 +509,22 @@ public void processElement(@SuppressWarnings("unused") ProcessContext c) {
DoFnInvokers.INSTANCE.newByteBuddyInvoker(fn).invokeProcessElement(null, null);
}
+ @Test
+ public void testProcessElementExceptionWithReturn() throws Exception {
+ DoFn fn =
+ new DoFn() {
+ @ProcessElement
+ public ProcessContinuation processElement(@SuppressWarnings("unused") ProcessContext c) {
+ throw new IllegalArgumentException("bogus");
+ }
+ };
+
+ thrown.expect(UserCodeException.class);
+ thrown.expectMessage("bogus");
+ DoFnInvokers.INSTANCE.newByteBuddyInvoker(fn).invokeProcessElement(null, null);
+ }
+
+
@Test
public void testStartBundleException() throws Exception {
DoFn fn =