diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSource.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSource.java index 575fe39771..4da8ad1bad 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSource.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSource.java @@ -30,6 +30,7 @@ import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions; import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.options.ValueProvider; +import com.google.cloud.dataflow.sdk.options.ValueProvider.StaticValueProvider; import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.Combine; import com.google.cloud.dataflow.sdk.transforms.DoFn; @@ -1290,6 +1291,7 @@ public String getIdLabel() { @Override public PCollection apply(PBegin input) { + ValueProvider subscriptionPath = subscription; if (subscription == null) { try { try (PubsubClient pubsubClient = @@ -1299,9 +1301,8 @@ public PCollection apply(PBegin input) { .as(DataflowPipelineOptions.class))) { checkState(project.isAccessible(), "createRandomSubscription must be called at runtime."); checkState(topic.isAccessible(), "createRandomSubscription must be called at runtime."); - SubscriptionPath subscriptionPath = - pubsubClient.createRandomSubscription( - project.get(), topic.get(), DEAULT_ACK_TIMEOUT_SEC); + subscriptionPath = StaticValueProvider.of(pubsubClient.createRandomSubscription( + project.get(), topic.get(), DEAULT_ACK_TIMEOUT_SEC)); LOG.warn("Created subscription {} to topic {}." 
+ " Note this subscription WILL NOT be deleted when the pipeline terminates", - subscription, topic); + subscriptionPath, topic); @@ -1314,7 +1315,7 @@ public PCollection apply(PBegin input) { return input.getPipeline().begin() .apply(Read.from(new PubsubSource(this))) .apply(ParDo.named("PubsubUnboundedSource.Stats") - .of(new StatsFn(pubsubFactory, subscription, - timestampLabel, idLabel))); + .of(new StatsFn(pubsubFactory, checkNotNull(subscriptionPath), + timestampLabel, idLabel))); } } diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubTestClient.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubTestClient.java index c3a5a4e959..01831218d4 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubTestClient.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/PubsubTestClient.java @@ -107,6 +107,11 @@ private static class State { */ @Nullable Map ackDeadline; + + /** + * Whether a subscription has been created. + */ + boolean createdSubscription; } private static final State STATE = new State(); @@ -124,12 +129,40 @@ public static PubsubTestClientFactory createFactoryForPublish( final TopicPath expectedTopic, final Iterable expectedOutgoingMessages, final Iterable failingOutgoingMessages) { + return createFactoryForPublishInternal( + expectedTopic, expectedOutgoingMessages, failingOutgoingMessages, false); + } + + /** + * Return a factory for testing publishers. Only one factory may be in-flight at a time. + * The factory must be closed when the test is complete, at which point final validation will + * occur. Additionally, verify that createSubscription was called. 
+ */ + public static PubsubTestClientFactory createFactoryForPublishVerifySubscription( + final TopicPath expectedTopic, + final Iterable expectedOutgoingMessages, + final Iterable failingOutgoingMessages) { + return createFactoryForPublishInternal( + expectedTopic, expectedOutgoingMessages, failingOutgoingMessages, true); + } + + /** + * Return a factory for testing publishers. Only one factory may be in-flight at a time. + * The factory must be closed when the test is complete, at which point final validation will + * occur. If {@code verifySubscriptionCreated}, close also requires a createSubscription call. + */ + public static PubsubTestClientFactory createFactoryForPublishInternal( + final TopicPath expectedTopic, + final Iterable expectedOutgoingMessages, + final Iterable failingOutgoingMessages, + final boolean verifySubscriptionCreated) { synchronized (STATE) { checkState(!STATE.isActive, "Test still in flight"); STATE.expectedTopic = expectedTopic; STATE.remainingExpectedOutgoingMessages = Sets.newHashSet(expectedOutgoingMessages); STATE.remainingFailingOutgoingMessages = Sets.newHashSet(failingOutgoingMessages); STATE.isActive = true; + STATE.createdSubscription = false; } return new PubsubTestClientFactory() { @Override @@ -148,6 +181,9 @@ public String getKind() { @Override public void close() { synchronized (STATE) { + if (verifySubscriptionCreated) { + checkState(STATE.createdSubscription, "Did not call create subscription"); + } checkState(STATE.isActive, "No test still in flight"); checkState(STATE.remainingExpectedOutgoingMessages.isEmpty(), "Still waiting for %s messages to be published", @@ -372,7 +408,9 @@ public List listTopics(ProjectPath project) throws IOException { @Override public void createSubscription( TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException { - throw new UnsupportedOperationException(); + synchronized (STATE) { + STATE.createdSubscription = true; + } } @Override diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSourceTest.java 
b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSourceTest.java index f7e4f863de..65fdf737af 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSourceTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/PubsubUnboundedSourceTest.java @@ -36,7 +36,10 @@ import com.google.cloud.dataflow.sdk.util.CoderUtils; import com.google.cloud.dataflow.sdk.util.PubsubClient; import com.google.cloud.dataflow.sdk.util.PubsubClient.IncomingMessage; +import com.google.cloud.dataflow.sdk.util.PubsubClient.OutgoingMessage; +import com.google.cloud.dataflow.sdk.util.PubsubClient.ProjectPath; import com.google.cloud.dataflow.sdk.util.PubsubClient.SubscriptionPath; +import com.google.cloud.dataflow.sdk.util.PubsubClient.TopicPath; import com.google.cloud.dataflow.sdk.util.PubsubTestClient; import com.google.cloud.dataflow.sdk.util.PubsubTestClient.PubsubTestClientFactory; @@ -60,8 +63,12 @@ */ @RunWith(JUnit4.class) public class PubsubUnboundedSourceTest { + private static final ProjectPath PROJECT = + PubsubClient.projectPathFromId("testProject"); private static final SubscriptionPath SUBSCRIPTION = PubsubClient.subscriptionPathFromName("testProject", "testSubscription"); + private static final TopicPath TOPIC = + PubsubClient.topicPathFromName("testProject", "testTopic"); private static final String DATA = "testData"; private static final long TIMESTAMP = 1234L; private static final long REQ_TIME = 6373L; @@ -320,4 +327,14 @@ public void readManyMessages() throws IOException { assertTrue(dataToMessageNum.isEmpty()); reader.close(); } + + @Test + public void testNullSubscription() throws Exception { + factory = PubsubTestClient.createFactoryForPublishVerifySubscription( + TOPIC, ImmutableList.of(), ImmutableList.of()); + TestPipeline p = TestPipeline.create(); + p.apply(new PubsubUnboundedSource<>( + clock, factory, StaticValueProvider.of(PROJECT), StaticValueProvider.of(TOPIC), + null, StringUtf8Coder.of(), TIMESTAMP_LABEL, 
ID_LABEL)); + } }