diff --git a/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java b/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java
index a18062da56..47335bdbc9 100644
--- a/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java
+++ b/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java
@@ -44,7 +44,6 @@
 import com.google.cloud.dataflow.sdk.values.PInput;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Function;
-import com.google.common.base.Joiner;
 import com.google.common.base.Optional;
 import com.google.common.collect.ComparisonChain;
 import com.google.common.collect.ImmutableList;
@@ -714,24 +713,12 @@ public UnboundedKafkaSource(
       this.consumerConfig = consumerConfig;
     }
 
-    /**
-     * The partitions are evenly distributed among the splits. The number of splits returned is
-     * {@code min(desiredNumSplits, totalNumPartitions)}, though better not to depend on the exact
-     * count.
-     *
-     * <p>It is important to assign the partitions deterministically so that we can support
-     * resuming a split from last checkpoint. The Kafka partitions are sorted by {@code <topic,
-     * partition>} and then assigned to splits in round-robin order.
-     */
-    @Override
-    public List<UnboundedKafkaSource<K, V>> generateInitialSplits(
-        int desiredNumSplits, PipelineOptions options) throws Exception {
+    private List<TopicPartition> fetchKafkaPartitions() {
 
       List<TopicPartition> partitions = new ArrayList<>(assignedPartitions);
 
       // (a) fetch partitions for each topic
       // (b) sort by <topic, partition>
-      // (c) round-robin assign the partitions to splits
 
       if (partitions.isEmpty()) {
         try (Consumer<?, ?> consumer = consumerFactoryFn.apply(consumerConfig)) {
@@ -754,37 +741,39 @@ public int compare(TopicPartition tp1, TopicPartition tp2) {
         }
       });
 
-      checkArgument(desiredNumSplits > 0);
       checkState(
           partitions.size() > 0,
           "Could not find any partitions. Please check Kafka configuration and topic names");
 
-      int numSplits = Math.min(desiredNumSplits, partitions.size());
-      List<List<TopicPartition>> assignments = new ArrayList<>(numSplits);
+      return partitions;
+    }
 
-      for (int i = 0; i < numSplits; i++) {
-        assignments.add(new ArrayList<TopicPartition>());
-      }
-      for (int i = 0; i < partitions.size(); i++) {
-        assignments.get(i % numSplits).add(partitions.get(i));
-      }
+    /**
+     * Returns one split for each of the Kafka partitions.
+     *
+     * <p>It is important to sort the partitions deterministically so that we can support
+     * resuming a split from last checkpoint. The Kafka partitions are sorted by {@code <topic,
+     * partition>}.
+     */
+    @Override
+    public List<UnboundedKafkaSource<K, V>> generateInitialSplits(
+        int desiredNumSplits, PipelineOptions options) throws Exception {
 
-      List<UnboundedKafkaSource<K, V>> result = new ArrayList<>(numSplits);
+      List<TopicPartition> partitions = fetchKafkaPartitions();
 
-      for (int i = 0; i < numSplits; i++) {
-        List<TopicPartition> assignedToSplit = assignments.get(i);
+      List<UnboundedKafkaSource<K, V>> result = new ArrayList<>(partitions.size());
+
+      // one split for each partition.
+      for (int i = 0; i < partitions.size(); i++) {
+        TopicPartition partition = partitions.get(i);
 
-        LOG.info(
-            "Partitions assigned to split {} (total {}): {}",
-            i,
-            assignedToSplit.size(),
-            Joiner.on(",").join(assignedToSplit));
+        LOG.info("Partition assigned to split {} : {}", i, partition);
 
         result.add(
-            new UnboundedKafkaSource(
+            new UnboundedKafkaSource<>(
                 i,
                 this.topics,
-                assignedToSplit,
+                ImmutableList.of(partition),
                 this.keyCoder,
                 this.valueCoder,
                 this.timestampFn,
@@ -804,7 +793,17 @@ public UnboundedKafkaReader<K, V> createReader(
         LOG.warn("Looks like generateSplits() is not called. Generate single split.");
         try {
           return new UnboundedKafkaReader<K, V>(
-              generateInitialSplits(1, options).get(0), checkpointMark);
+              new UnboundedKafkaSource<>(
+                  0,
+                  this.topics,
+                  fetchKafkaPartitions(),
+                  this.keyCoder,
+                  this.valueCoder,
+                  this.timestampFn,
+                  this.watermarkFn,
+                  this.consumerFactoryFn,
+                  this.consumerConfig),
+              checkpointMark);
         } catch (Exception e) {
           throw new RuntimeException(e);
         }
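
Note for reviewers: the sketch below illustrates the splitting strategy this patch switches to. Partitions are sorted deterministically by <topic, partition> and each partition then becomes its own split, so split i always maps back to the same partition when it is resumed from a checkpoint, and desiredNumSplits is effectively only a hint. This is not the KafkaIO code itself: the class name, topics and partition counts are made up for illustration, and only the kafka-clients TopicPartition type is assumed.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import org.apache.kafka.common.TopicPartition;

// Standalone sketch of "one split per partition" with a deterministic sort order.
public class PerPartitionSplitSketch {

  public static void main(String[] args) {
    // Hypothetical partitions as they might come back from the broker
    // (2 topics x 3 partitions), deliberately listed out of order.
    List<TopicPartition> partitions = new ArrayList<>();
    for (String topic : new String[] {"topic_b", "topic_a"}) {
      for (int p = 2; p >= 0; p--) {
        partitions.add(new TopicPartition(topic, p));
      }
    }

    // Sort by <topic, partition> so the ordering is stable across runs; this is what
    // makes resuming split i from a checkpoint safe.
    partitions.sort(
        Comparator.comparing(TopicPartition::topic)
            .thenComparingInt(TopicPartition::partition));

    // One split per partition: the split id is simply the index into the sorted list.
    for (int i = 0; i < partitions.size(); i++) {
      System.out.println("split " + i + " -> " + partitions.get(i));
    }
  }
}

With this scheme the old round-robin grouping (and the Joiner-based log line) is no longer needed, which is what the KafkaIO.java hunks above remove.
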
diff --git a/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java b/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java
index af8d674831..19d4336a5b 100644
--- a/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java
+++ b/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java
@@ -26,7 +26,6 @@
 import com.google.cloud.dataflow.sdk.io.Read;
 import com.google.cloud.dataflow.sdk.io.UnboundedSource;
 import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader;
-import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
 import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
 import com.google.cloud.dataflow.sdk.testing.TestPipeline;
 import com.google.cloud.dataflow.sdk.transforms.Count;
@@ -172,8 +171,8 @@ public Consumer<byte[], byte[]> apply(Map<String, Object> config) {
   }
 
   /**
-   * Creates a consumer with two topics, with 5 partitions each. numElements are (round-robin)
-   * assigned all the 10 partitions.
+   * Creates a consumer with two topics, with 10 partitions each. numElements are (round-robin)
+   * assigned all the 20 partitions.
    */
   private static KafkaIO.TypedRead<byte[], Long> mkKafkaReadTransform(
       int numElements, @Nullable SerializableFunction<KV<byte[], Long>, Instant> timestampFn) {
@@ -309,16 +308,13 @@ public void processElement(ProcessContext ctx) throws Exception {
   public void testUnboundedSourceSplits() throws Exception {
     Pipeline p = TestPipeline.create();
     int numElements = 1000;
-    int numSplits = 10;
+    int numSplits = 20;
 
     UnboundedSource<KafkaRecord<byte[], Long>, ?> initial =
         mkKafkaReadTransform(numElements, null).makeSource();
-    List<
-            ? extends
-                UnboundedSource<
-                    com.google.cloud.dataflow.contrib.kafka.KafkaRecord<byte[], Long>, ?>>
-        splits = initial.generateInitialSplits(numSplits, p.getOptions());
-    assertEquals("Expected exact splitting", numSplits, splits.size());
+    List<? extends UnboundedSource<KafkaRecord<byte[], Long>, ?>> splits =
+        initial.generateInitialSplits(1, p.getOptions());
+    assertEquals("KafkaIO should ignore desiredNumSplits", numSplits, splits.size());
 
     long elementsPerSplit = numElements / numSplits;
     assertEquals("Expected even splits", numElements, elementsPerSplit * numSplits);
@@ -368,9 +364,7 @@ public void testUnboundedSourceCheckpointMark() throws Exception {
     UnboundedSource<KafkaRecord<byte[], Long>,
                     com.google.cloud.dataflow.contrib.kafka.KafkaCheckpointMark> source =
         mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
-            .makeSource()
-            .generateInitialSplits(1, PipelineOptionsFactory.fromArgs(new String[0]).create())
-            .get(0);
+            .makeSource();
 
     UnboundedReader<KafkaRecord<byte[], Long>> reader = source.createReader(null, null);
 
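
A small sanity check on the updated test expectations (again only a sketch, not part of the patch): mkKafkaReadTransform() is documented as using two topics with 10 partitions each, so the source now always yields 20 splits no matter what hint is passed to generateInitialSplits(), and the 1000 round-robin-assigned test elements land 50 per split. The class below is hypothetical; it just re-does that arithmetic under the round-robin assumption stated in the test javadoc.

// Standalone check that round-robin assignment of 1000 elements over 20 per-partition
// splits is perfectly even, which is what the test's per-split assertions rely on.
public class EvenSplitCheck {

  public static void main(String[] args) {
    int numElements = 1000;   // mirrors the test constant
    int numSplits = 2 * 10;   // 2 topics x 10 partitions -> one split per partition

    int[] perSplit = new int[numSplits];
    for (int i = 0; i < numElements; i++) {
      perSplit[i % numSplits]++;   // round-robin assignment
    }

    for (int count : perSplit) {
      if (count != numElements / numSplits) {
        throw new AssertionError("Uneven split: " + count);
      }
    }
    System.out.println(numSplits + " splits x " + (numElements / numSplits) + " elements each");
  }
}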