diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index d1b4f65ce826..2797933d8474 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.base.Objects; import java.io.Serializable; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; @@ -620,6 +621,19 @@ public ProcessContinuation withResumeDelay(Duration resumeDelay) { public ProcessContinuation withFutureOutputWatermark(Instant futureOutputWatermark) { return new ProcessContinuation(resumeDelay, futureOutputWatermark); } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } else if (!(obj instanceof ProcessContinuation)) { + return false; + } else { + ProcessContinuation that = (ProcessContinuation) obj; + return Objects.equal(this.resumeDelay, that.resumeDelay) + && Objects.equal(this.futureOutputWatermark, that.futureOutputWatermark); + } + } } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java index b4d6d0e31be9..12dc7fc6135a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkArgument; import com.google.common.reflect.TypeToken; +import java.io.FileOutputStream; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -29,6 +30,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import javax.annotation.Nullable; import net.bytebuddy.ByteBuddy; import net.bytebuddy.NamingStrategy; import net.bytebuddy.description.field.FieldDescription; @@ -45,9 +47,7 @@ import net.bytebuddy.implementation.FixedValue; import net.bytebuddy.implementation.Implementation; import net.bytebuddy.implementation.Implementation.Context; -import net.bytebuddy.implementation.MethodCall; import net.bytebuddy.implementation.MethodDelegation; -import net.bytebuddy.implementation.bind.MethodDelegationBinder; import net.bytebuddy.implementation.bind.annotation.TargetMethodAnnotationDrivenBinder; import net.bytebuddy.implementation.bytecode.ByteCodeAppender; import net.bytebuddy.implementation.bytecode.StackManipulation; @@ -61,6 +61,7 @@ import net.bytebuddy.jar.asm.Label; import net.bytebuddy.jar.asm.MethodVisitor; import net.bytebuddy.jar.asm.Opcodes; +import net.bytebuddy.jar.asm.Type; import net.bytebuddy.matcher.ElementMatchers; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; @@ -223,6 +224,13 @@ public String subclass(TypeDescription.Generic superClass) { .intercept(delegateWithDowncastOrThrow(signature.newTracker())); DynamicType.Unloaded unloaded = builder.make(); + try { + try (FileOutputStream w = new FileOutputStream("/tmp/foo.class")) { + w.write(unloaded.getBytes()); + } + } catch (Exception e) { + e.printStackTrace(); + } @SuppressWarnings("unchecked") Class> res = @@ -249,7 +257,16 @@ private static Implementation delegateWithDowncastOrThrow(DoFnSignature.DoFnMeth /** Implements an invoker method by delegating to a method of the target {@link DoFn}. */ private abstract static class DoFnMethodDelegation implements Implementation { - FieldDescription delegateField; + + protected final MethodDescription targetMethod; + private FieldDescription delegateField; + + private final boolean targetHasReturn; + + public DoFnMethodDelegation(MethodDescription targetMethod) { + this.targetMethod = targetMethod; + targetHasReturn = !TypeDescription.VOID.equals(targetMethod.getReturnType().asErasure()); + } @Override public InstrumentedType prepare(InstrumentedType instrumentedType) { @@ -259,7 +276,6 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) { .getDeclaredFields() .filter(ElementMatchers.named(FN_DELEGATE_FIELD_NAME)) .getOnly(); - // Delegating the method call doesn't require any changes to the instrumented type. return instrumentedType; } @@ -272,25 +288,64 @@ public Size apply( MethodVisitor methodVisitor, Context implementationContext, MethodDescription instrumentedMethod) { - StackManipulation manipulation = - new StackManipulation.Compound( - // Push "this" reference to the stack - MethodVariableAccess.REFERENCE.loadOffset(0), - // Access the delegate field of the the invoker - FieldAccess.forField(delegateField).getter(), - invokeTargetMethod(instrumentedMethod)); + // Figure out how many locals we'll need. This corresponds to "this", the parameters + // of the instrumented method, and an argument to hold the return value if the target + // method has a return value. + int numLocals = 1 + instrumentedMethod.getParameters().size() + (targetHasReturn ? 1 : 0); + + Integer returnVarIndex = null; + if (targetHasReturn) { + // Local comes after formal parameters, so figure out where that is. + returnVarIndex = 1; // "this" + for (Type param : Type.getArgumentTypes(instrumentedMethod.getDescriptor())) { + returnVarIndex += param.getSize(); + } + } + + StackManipulation manipulation = new StackManipulation.Compound( + // Push "this" (DoFnInvoker on top of the stack) + MethodVariableAccess.REFERENCE.loadOffset(0), + // Access this.delegate (DoFn on top of the stack) + FieldAccess.forField(delegateField).getter(), + // Run the beforeDelegation manipulations. + // The arguments necessary to invoke the target are on top of the stack. + beforeDelegation(instrumentedMethod), + // Perform the method delegation. + // This will consume the arguments on top of the stack + // Either the stack is now empty (because the targetMethod returns void) or the + // stack contains the return value. + new UserCodeMethodInvocation(returnVarIndex, targetMethod, instrumentedMethod), + // Run the afterDelegation manipulations. + // Either the stack is now empty (because the instrumentedMethod returns void) + // or the stack contains the return value. + afterDelegation(instrumentedMethod)); + StackManipulation.Size size = manipulation.apply(methodVisitor, implementationContext); - return new Size(size.getMaximalSize(), instrumentedMethod.getStackSize()); + return new Size(size.getMaximalSize(), numLocals); } }; } /** - * Generates code to invoke the target method. When this is called the delegate field will be on - * top of the stack. This should add any necessary arguments to the stack and then perform the - * method invocation. + * Return the code to the prepare the operand stack for the method delegation. + * + *

Before this method is called, the stack delegate will be the only thing on the stack. + * + *

After this method is called, the stack contents should contain exactly the arguments + * necessary to invoke the target method. */ - protected abstract StackManipulation invokeTargetMethod(MethodDescription instrumentedMethod); + protected abstract StackManipulation beforeDelegation(MethodDescription instrumentedMethod); + + /** + * Return the code to execute after the method delegation. + * + *

Before this method is called, the stack will either be empty (if the target method + * returns void) or contain the method return value. + * + *

After this method is called, the stack should either be empty (if the instrumented method + * returns void) or contain the avlue for the instrumented method to return). + */ + protected abstract StackManipulation afterDelegation(MethodDescription instrumentedMethod); } /** @@ -339,16 +394,13 @@ private static final class ProcessElementDelegation extends DoFnMethodDelegation /** Implementation of {@link MethodDelegation} for the {@link ProcessElement} method. */ private ProcessElementDelegation(DoFnSignature.ProcessElementMethod signature) { + super(new MethodDescription.ForLoadedMethod(signature.targetMethod())); this.signature = signature; } @Override - protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMethod) { - MethodDescription targetMethod = - new MethodCall.MethodLocator.ForExplicitMethod( - new MethodDescription.ForLoadedMethod(signature.targetMethod())) - .resolve(instrumentedMethod); - + protected StackManipulation beforeDelegation( + MethodDescription instrumentedMethod) { // Parameters of the wrapper invoker method: // DoFn.ProcessContext, ExtraContextFactory. // Parameters of the wrapped DoFn method: @@ -368,26 +420,21 @@ protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMet // Insert a downcast. (param == DoFnSignature.Parameter.RESTRICTION_TRACKER) ? TypeCasting.to( - new TypeDescription.ForLoadedType(signature.trackerT().getRawType())) + new TypeDescription.ForLoadedType(signature.trackerT().getRawType())) : StackManipulation.Trivial.INSTANCE)); } + return new StackManipulation.Compound(parameters); + } - return new StackManipulation.Compound( - // Push the parameters - new StackManipulation.Compound(parameters), - // Invoke the target method - wrapWithUserCodeException( - MethodDelegationBinder.MethodInvoker.Simple.INSTANCE.invoke(targetMethod), - targetMethod.getReturnType().asErasure(), - instrumentedMethod), - // Return from the instrumented @ProcessElement method: - // if it returns void, then return null (meaning don't resume), - // otherwise return the ProcessContinuation it returned. - signature.hasReturnValue() - ? MethodReturn.returning(targetMethod.getReturnType().asErasure()) - : new StackManipulation.Compound( - MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD), - MethodReturn.REFERENCE)); + @Override + protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) { + if (TypeDescription.VOID.equals(targetMethod.getReturnType())) { + return new StackManipulation.Compound( + MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD), + MethodReturn.REFERENCE); + } else { + return MethodReturn.returning(targetMethod.getReturnType().asErasure()); + } } } @@ -396,33 +443,20 @@ protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMet * downcasting parameters to the proper type. */ private static class SimpleMethodDelegation extends DoFnMethodDelegation { - private final Method method; - SimpleMethodDelegation(Method method) { - this.method = method; + protected SimpleMethodDelegation(Method method) { + super(new MethodDescription.ForLoadedMethod(method)); } @Override - protected StackManipulation invokeTargetMethod(MethodDescription instrumentedMethod) { - MethodDescription targetMethod = - new MethodCall.MethodLocator.ForExplicitMethod( - new MethodDescription.ForLoadedMethod(method)) - .resolve(instrumentedMethod); - return new StackManipulation.Compound( - pushArgumentsOfInstrumentedMethod(targetMethod), - // Invoke the target method - wrapWithUserCodeException( - MethodDelegationBinder.MethodInvoker.Simple.INSTANCE.invoke(targetMethod), - targetMethod.getReturnType().asErasure(), - instrumentedMethod), - new StackManipulation.Compound( - // Return from the instrumented method - TargetMethodAnnotationDrivenBinder.TerminationHandler.Returning.INSTANCE.resolve( - Assigner.DEFAULT, instrumentedMethod, targetMethod))); + protected StackManipulation beforeDelegation(MethodDescription instrumentedMethod) { + return MethodVariableAccess.allArgumentsOf(targetMethod); } - protected StackManipulation pushArgumentsOfInstrumentedMethod(MethodDescription targetMethod) { - return MethodVariableAccess.allArgumentsOf(targetMethod); + @Override + protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) { + return TargetMethodAnnotationDrivenBinder.TerminationHandler.Returning.INSTANCE + .resolve(Assigner.DEFAULT, instrumentedMethod, targetMethod); } } @@ -436,7 +470,7 @@ private static class DowncastingParametersMethodDelegation extends SimpleMethodD } @Override - protected StackManipulation pushArgumentsOfInstrumentedMethod(MethodDescription targetMethod) { + protected StackManipulation beforeDelegation(MethodDescription instrumentedMethod) { List pushParameters = new ArrayList<>(); TypeList.Generic paramTypes = targetMethod.getParameters().asTypeList(); for (int i = 0; i < paramTypes.size(); i++) { @@ -450,15 +484,36 @@ protected StackManipulation pushArgumentsOfInstrumentedMethod(MethodDescription } } - private abstract static class BaseWrapUserCodeException implements StackManipulation { + private static class UserCodeMethodInvocation implements StackManipulation { + + @Nullable + private final Integer returnVarIndex; + private final MethodDescription targetMethod; + private final MethodDescription instrumentedMethod; + private final TypeDescription returnType; + + private final Label wrapStart = new Label(); + private final Label wrapEnd = new Label(); + private final Label tryBlockStart = new Label(); + private final Label tryBlockEnd = new Label(); + private final Label catchBlockStart = new Label(); + private final Label catchBlockEnd = new Label(); - /** {@link MethodDescription} for {@link UserCodeException#wrap}. */ private final MethodDescription createUserCodeException; - private final StackManipulation tryBody; + public UserCodeMethodInvocation( + Integer returnVarIndex, + MethodDescription targetMethod, + MethodDescription instrumentedMethod) { + this.returnVarIndex = returnVarIndex; + this.targetMethod = targetMethod; + this.instrumentedMethod = instrumentedMethod; + this.returnType = targetMethod.getReturnType().asErasure(); + + boolean targetMethodReturnsVoid = TypeDescription.VOID.equals(returnType); + checkArgument((returnVarIndex == null) == targetMethodReturnsVoid, + "returnLocalIndex should be defined if and only if the target method has a return value"); - BaseWrapUserCodeException(StackManipulation tryBody) { - this.tryBody = tryBody; try { createUserCodeException = new MethodDescription.ForLoadedMethod( @@ -470,23 +525,52 @@ private abstract static class BaseWrapUserCodeException implements StackManipula @Override public boolean isValid() { - return tryBody.isValid(); + return true; } - void visitFrameWithThrowableOnStack(MethodVisitor mv) { - String throwableName = new TypeDescription.ForLoadedType(Throwable.class).getInternalName(); - mv.visitFrame(Opcodes.F_SAME1, 0, new Object[] {}, 1, new Object[] {throwableName}); + private Object describeType(Type type) { + switch (type.getSort()) { + case Type.OBJECT: + return type.getDescriptor(); + case Type.INT: + case Type.BYTE: + case Type.BOOLEAN: + case Type.SHORT: + return Opcodes.INTEGER; + case Type.LONG: + return Opcodes.LONG; + case Type.DOUBLE: + return Opcodes.DOUBLE; + case Type.FLOAT: + return Opcodes.FLOAT; + default: + throw new IllegalArgumentException("Unhandled type as method argument: " + type); + } } - protected abstract void visitFrameWithReturnOnStack(MethodVisitor mv); + private void visitFrame(MethodVisitor mv, + boolean localsIncludeReturn, + @Nullable String stackTop) { + boolean hasReturnLocal = (returnVarIndex != null) && localsIncludeReturn; + + Type[] localTypes = Type.getArgumentTypes(instrumentedMethod.getDescriptor()); + Object[] locals = new Object[1 + localTypes.length + (hasReturnLocal ? 1 : 0)]; + locals[0] = instrumentedMethod.getReceiverType().asErasure().getInternalName(); + for (int i = 0; i < localTypes.length; i++) { + locals[i + 1] = describeType(localTypes[i]); + } + if (hasReturnLocal) { + locals[locals.length - 1] = returnType.getInternalName(); + } + + Object[] stack = stackTop == null ? new Object[] {} : new Object[] { stackTop }; + + mv.visitFrame(Opcodes.F_NEW, locals.length, locals, stack.length, stack); + } @Override - public Size apply(MethodVisitor mv, Context implementationContext) { - Label wrapStart = new Label(); - Label tryBlockStart = new Label(); - Label tryBlockEnd = new Label(); - Label catchBlockStart = new Label(); - Label catchBlockEnd = new Label(); + public Size apply(MethodVisitor mv, Context context) { + Size size = new Size(0, 0); mv.visitLabel(wrapStart); @@ -495,55 +579,48 @@ public Size apply(MethodVisitor mv, Context implementationContext) { // The try block attempts to perform the expected operations, then jumps to success mv.visitLabel(tryBlockStart); - Size trySize = tryBody.apply(mv, implementationContext); + size = size.aggregate(MethodInvocation.invoke(targetMethod).apply(mv, context)); - // After try body, should have same locals and the return type on the stack. - visitFrameWithReturnOnStack(mv); + if (returnVarIndex != null) { + mv.visitVarInsn(Opcodes.ASTORE, returnVarIndex); + size = size.aggregate(new Size(-1, 0)); // Reduces the size of the stack + } mv.visitJumpInsn(Opcodes.GOTO, catchBlockEnd); mv.visitLabel(tryBlockEnd); // The handler wraps the exception, and then throws. mv.visitLabel(catchBlockStart); // In catch block, should have same locals and {Throwable} on the stack. - visitFrameWithThrowableOnStack(mv); + visitFrame(mv, false, throwableName); - Size catchSize = + // Create the user code exception and throw + size = size.aggregate( new Compound(MethodInvocation.invoke(createUserCodeException), Throw.INSTANCE) - .apply(mv, implementationContext); + .apply(mv, context)); mv.visitLabel(catchBlockEnd); - // After catch block, should have same locals and the return type on the stack - visitFrameWithReturnOnStack(mv); - return new Size( - trySize.getSizeImpact() /* Same total size impact as wrapped body */, - Math.max(trySize.getMaximalSize(), catchSize.getMaximalSize())); - } - } + // After the catch block we should have the return in scope, but nothing on the stack. + visitFrame(mv, true, null); - /** - * Wraps a given stack manipulation in a try catch block. Any exceptions thrown within the try are - * wrapped with a {@link UserCodeException}. - */ - private static StackManipulation wrapWithUserCodeException( - StackManipulation tryBody, - final TypeDescription returnType, - MethodDescription instrumentedMethod) { - if (TypeDescription.VOID.equals(returnType)) { - return new BaseWrapUserCodeException(tryBody) { - @Override - protected void visitFrameWithReturnOnStack(MethodVisitor mv) { - mv.visitFrame(Opcodes.F_SAME, 0, new Object[] {}, 0, new Object[] {}); - } - }; - } else { - return new BaseWrapUserCodeException(tryBody) { - @Override - protected void visitFrameWithReturnOnStack(MethodVisitor mv) { - mv.visitFrame( - Opcodes.F_SAME1, 0, new Object[] {}, 1, new Object[] {returnType.getInternalName()}); - } - }; + // After catch block, should have same locals and will have the return on the stack. + if (returnVarIndex != null) { + mv.visitVarInsn(Opcodes.ALOAD, returnVarIndex); + size = size.aggregate(new Size(1, 0)); // Increases the size of the stack + } + mv.visitLabel(wrapEnd); + if (returnVarIndex != null) { + // Drop the return type from the locals + mv.visitLocalVariable( + "res", + returnType.getDescriptor(), + returnType.getGenericSignature(), + wrapStart, + wrapEnd, + returnVarIndex); + } + + return size; } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index 20a887aa458f..da44a869420d 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.transforms.reflect; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.UserCodeException; @@ -87,7 +89,8 @@ public RestrictionTracker restrictionTracker() { }; } - private void checkInvokeProcessElementWorks(DoFn fn, Invocations... invocations) + private void checkInvokeProcessElementWorks( + DoFn fn, ProcessContinuation expected, Invocations... invocations) throws Exception { assertTrue("Need at least one invocation to check", invocations.length >= 1); for (Invocations invocation : invocations) { @@ -95,9 +98,10 @@ private void checkInvokeProcessElementWorks(DoFn fn, Invocations "Should not yet have called processElement on " + invocation.name, invocation.wasProcessElementInvoked); } - DoFnInvokers.INSTANCE + ProcessContinuation actual = DoFnInvokers.INSTANCE .newByteBuddyInvoker(fn) .invokeProcessElement(mockContext, extraContextFactory); + assertEquals("Should return proper continuation", expected, actual); for (Invocations invocation : invocations) { assertTrue( "Should have called processElement on " + invocation.name, @@ -182,7 +186,7 @@ public void processElement(ProcessContext c) throws Exception { .processElement() .usesSingleWindow()); - checkInvokeProcessElementWorks(fn, invocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations); } @Test @@ -223,7 +227,7 @@ public void testDoFnWithProcessElementInterface() throws Exception { .getOrParseSignature(fn.getClass()) .processElement() .usesSingleWindow()); - checkInvokeProcessElementWorks(fn, fn.invocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), fn.invocations); } private class IdentityParent extends DoFn { @@ -256,7 +260,7 @@ public void testDoFnWithMethodInSuperclass() throws Exception { .getOrParseSignature(fn.getClass()) .processElement() .usesSingleWindow()); - checkInvokeProcessElementWorks(fn, fn.parentInvocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), fn.parentInvocations); } @Test @@ -267,7 +271,8 @@ public void testDoFnWithMethodInSubclass() throws Exception { .getOrParseSignature(fn.getClass()) .processElement() .usesSingleWindow()); - checkInvokeProcessElementWorks(fn, fn.parentInvocations, fn.childInvocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), + fn.parentInvocations, fn.childInvocations); } @Test @@ -289,7 +294,7 @@ public void processElement(ProcessContext c, BoundedWindow w) throws Exception { .processElement() .usesSingleWindow()); - checkInvokeProcessElementWorks(fn, invocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations); } @Test @@ -311,7 +316,7 @@ public void processElement(ProcessContext c, OutputReceiver o) throws Ex .processElement() .usesSingleWindow()); - checkInvokeProcessElementWorks(fn, invocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations); } @Test @@ -333,7 +338,31 @@ public void processElement(ProcessContext c, InputProvider i) throws Exc .processElement() .usesSingleWindow()); - checkInvokeProcessElementWorks(fn, invocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), invocations); + } + + @Test + public void testDoFnWithReturn() throws Exception { + final Invocations invocations = new Invocations("AnonymousClass"); + DoFn fn = + new DoFn() { + @ProcessElement + public ProcessContinuation processElement( + ProcessContext c, InputProvider i) throws Exception { + invocations.wasProcessElementInvoked = true; + assertSame(c, mockContext); + assertSame(i, mockInputProvider); + return ProcessContinuation.resume(); + } + }; + + assertFalse( + DoFnSignatures.INSTANCE + .getOrParseSignature(fn.getClass()) + .processElement() + .usesSingleWindow()); + + checkInvokeProcessElementWorks(fn, ProcessContinuation.resume(), invocations); } @Test @@ -408,49 +437,61 @@ public void processThis(ProcessContext c) { @Test public void testLocalPrivateDoFnClass() throws Exception { PrivateDoFnClass fn = new PrivateDoFnClass(); - checkInvokeProcessElementWorks(fn, fn.invocations); + checkInvokeProcessElementWorks(fn, ProcessContinuation.stop(), fn.invocations); } @Test public void testStaticPackagePrivateDoFnClass() throws Exception { Invocations invocations = new Invocations("StaticPackagePrivateDoFn"); checkInvokeProcessElementWorks( - DoFnInvokersTestHelper.newStaticPackagePrivateDoFn(invocations), invocations); + DoFnInvokersTestHelper.newStaticPackagePrivateDoFn(invocations), + ProcessContinuation.stop(), + invocations); } @Test public void testInnerPackagePrivateDoFnClass() throws Exception { Invocations invocations = new Invocations("InnerPackagePrivateDoFn"); checkInvokeProcessElementWorks( - new DoFnInvokersTestHelper().newInnerPackagePrivateDoFn(invocations), invocations); + new DoFnInvokersTestHelper().newInnerPackagePrivateDoFn(invocations), + ProcessContinuation.stop(), + invocations); } @Test public void testStaticPrivateDoFnClass() throws Exception { Invocations invocations = new Invocations("StaticPrivateDoFn"); checkInvokeProcessElementWorks( - DoFnInvokersTestHelper.newStaticPrivateDoFn(invocations), invocations); + DoFnInvokersTestHelper.newStaticPrivateDoFn(invocations), + ProcessContinuation.stop(), + invocations); } @Test public void testInnerPrivateDoFnClass() throws Exception { Invocations invocations = new Invocations("StaticInnerDoFn"); checkInvokeProcessElementWorks( - new DoFnInvokersTestHelper().newInnerPrivateDoFn(invocations), invocations); + new DoFnInvokersTestHelper().newInnerPrivateDoFn(invocations), + ProcessContinuation.stop(), + invocations); } @Test public void testAnonymousInnerDoFnInOtherPackage() throws Exception { Invocations invocations = new Invocations("AnonymousInnerDoFnInOtherPackage"); checkInvokeProcessElementWorks( - new DoFnInvokersTestHelper().newInnerAnonymousDoFn(invocations), invocations); + new DoFnInvokersTestHelper().newInnerAnonymousDoFn(invocations), + ProcessContinuation.stop(), + invocations); } @Test public void testStaticAnonymousDoFnInOtherPackage() throws Exception { Invocations invocations = new Invocations("AnonymousStaticDoFnInOtherPackage"); checkInvokeProcessElementWorks( - DoFnInvokersTestHelper.newStaticAnonymousDoFn(invocations), invocations); + DoFnInvokersTestHelper.newStaticAnonymousDoFn(invocations), + ProcessContinuation.stop(), + invocations); } @Test @@ -468,6 +509,22 @@ public void processElement(@SuppressWarnings("unused") ProcessContext c) { DoFnInvokers.INSTANCE.newByteBuddyInvoker(fn).invokeProcessElement(null, null); } + @Test + public void testProcessElementExceptionWithReturn() throws Exception { + DoFn fn = + new DoFn() { + @ProcessElement + public ProcessContinuation processElement(@SuppressWarnings("unused") ProcessContext c) { + throw new IllegalArgumentException("bogus"); + } + }; + + thrown.expect(UserCodeException.class); + thrown.expectMessage("bogus"); + DoFnInvokers.INSTANCE.newByteBuddyInvoker(fn).invokeProcessElement(null, null); + } + + @Test public void testStartBundleException() throws Exception { DoFn fn =