diff --git a/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java b/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java
index a18062da56..47335bdbc9 100644
--- a/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java
+++ b/contrib/kafka/src/main/java/com/google/cloud/dataflow/contrib/kafka/KafkaIO.java
@@ -44,7 +44,6 @@
import com.google.cloud.dataflow.sdk.values.PInput;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
-import com.google.common.base.Joiner;
import com.google.common.base.Optional;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
@@ -714,24 +713,12 @@ public UnboundedKafkaSource(
this.consumerConfig = consumerConfig;
}
- /**
- * The partitions are evenly distributed among the splits. The number of splits returned is
- * {@code min(desiredNumSplits, totalNumPartitions)}, though better not to depend on the exact
- * count.
- *
- * <p>It is important to assign the partitions deterministically so that we can support
- * resuming a split from last checkpoint. The Kafka partitions are sorted by
- * {@code <topic, partition>} and then assigned to splits in round-robin order.
- */
- @Override
- public List<UnboundedKafkaSource<K, V>> generateInitialSplits(
- int desiredNumSplits, PipelineOptions options) throws Exception {
+ private List<TopicPartition> fetchKafkaPartitions() {
List<TopicPartition> partitions = new ArrayList<>(assignedPartitions);
// (a) fetch partitions for each topic
// (b) sort by <topic, partition>
- // (c) round-robin assign the partitions to splits
if (partitions.isEmpty()) {
try (Consumer<?, ?> consumer = consumerFactoryFn.apply(consumerConfig)) {
@@ -754,37 +741,39 @@ public int compare(TopicPartition tp1, TopicPartition tp2) {
}
});
- checkArgument(desiredNumSplits > 0);
checkState(
partitions.size() > 0,
"Could not find any partitions. Please check Kafka configuration and topic names");
- int numSplits = Math.min(desiredNumSplits, partitions.size());
- List<List<TopicPartition>> assignments = new ArrayList<>(numSplits);
+ return partitions;
+ }
- for (int i = 0; i < numSplits; i++) {
- assignments.add(new ArrayList<TopicPartition>());
- }
- for (int i = 0; i < partitions.size(); i++) {
- assignments.get(i % numSplits).add(partitions.get(i));
- }
+ /**
+ * Returns one split for each of the Kafka partitions.
+ *
+ * <p>It is important to sort the partitions deterministically so that we can support
+ * resuming a split from last checkpoint. The Kafka partitions are sorted by
+ * {@code <topic, partition>}.
+ */
+ @Override
+ public List<UnboundedKafkaSource<K, V>> generateInitialSplits(
+ int desiredNumSplits, PipelineOptions options) throws Exception {
- List<UnboundedKafkaSource<K, V>> result = new ArrayList<>(numSplits);
+ List<TopicPartition> partitions = fetchKafkaPartitions();
- for (int i = 0; i < numSplits; i++) {
- List<TopicPartition> assignedToSplit = assignments.get(i);
+ List<UnboundedKafkaSource<K, V>> result = new ArrayList<>(partitions.size());
+
+ // one split for each partition.
+ for (int i = 0; i < partitions.size(); i++) {
+ TopicPartition partition = partitions.get(i);
- LOG.info(
- "Partitions assigned to split {} (total {}): {}",
- i,
- assignedToSplit.size(),
- Joiner.on(",").join(assignedToSplit));
+ LOG.info("Partition assigned to split {} : {}", i, partition);
result.add(
- new UnboundedKafkaSource<K, V>(
+ new UnboundedKafkaSource<>(
i,
this.topics,
- assignedToSplit,
+ ImmutableList.of(partition),
this.keyCoder,
this.valueCoder,
this.timestampFn,
@@ -804,7 +793,17 @@ public UnboundedKafkaReader<K, V> createReader(
LOG.warn("Looks like generateSplits() is not called. Generate single split.");
try {
return new UnboundedKafkaReader<K, V>(
- generateInitialSplits(1, options).get(0), checkpointMark);
+ new UnboundedKafkaSource<>(
+ 0,
+ this.topics,
+ fetchKafkaPartitions(),
+ this.keyCoder,
+ this.valueCoder,
+ this.timestampFn,
+ this.watermarkFn,
+ this.consumerFactoryFn,
+ this.consumerConfig),
+ checkpointMark);
} catch (Exception e) {
throw new RuntimeException(e);
}
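The core behavioral change above, restated as a standalone sketch for comparison: the old generateInitialSplits() dealt the sorted partitions round-robin into min(desiredNumSplits, numPartitions) splits, while the new code ignores desiredNumSplits and emits exactly one split per partition. The class and method names below are hypothetical illustrations, not code from the patch.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.kafka.common.TopicPartition;

final class PartitionAssignmentSketch {

  // New behavior: every partition becomes its own split, so the runner's
  // desiredNumSplits hint no longer affects the result.
  static List<List<TopicPartition>> onePerPartition(List<TopicPartition> partitions) {
    List<List<TopicPartition>> splits = new ArrayList<>(partitions.size());
    for (TopicPartition tp : partitions) {
      splits.add(Collections.singletonList(tp));
    }
    return splits;
  }

  // Old behavior: partitions were dealt round-robin into
  // min(desiredNumSplits, partitions.size()) splits.
  static List<List<TopicPartition>> roundRobin(
      List<TopicPartition> partitions, int desiredNumSplits) {
    int numSplits = Math.min(desiredNumSplits, partitions.size());
    List<List<TopicPartition>> splits = new ArrayList<>(numSplits);
    for (int i = 0; i < numSplits; i++) {
      splits.add(new ArrayList<TopicPartition>());
    }
    for (int i = 0; i < partitions.size(); i++) {
      splits.get(i % numSplits).add(partitions.get(i));
    }
    return splits;
  }
}

Either way the partitions are first sorted by topic and partition number, which is what makes resuming a split from its last checkpoint deterministic.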
diff --git a/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java b/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java
index af8d674831..19d4336a5b 100644
--- a/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java
+++ b/contrib/kafka/src/test/java/com/google/cloud/dataflow/contrib/kafka/KafkaIOTest.java
@@ -26,7 +26,6 @@
import com.google.cloud.dataflow.sdk.io.Read;
import com.google.cloud.dataflow.sdk.io.UnboundedSource;
import com.google.cloud.dataflow.sdk.io.UnboundedSource.UnboundedReader;
-import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
import com.google.cloud.dataflow.sdk.testing.TestPipeline;
import com.google.cloud.dataflow.sdk.transforms.Count;
@@ -172,8 +171,8 @@ public Consumer<byte[], byte[]> apply(Map<String, Object> config) {
}
/**
- * Creates a consumer with two topics, with 5 partitions each. numElements are (round-robin)
- * assigned all the 10 partitions.
+ * Creates a consumer with two topics, with 10 partitions each. numElements are (round-robin)
+ * assigned to all the 20 partitions.
*/
private static KafkaIO.TypedRead<Integer, Long> mkKafkaReadTransform(
int numElements, @Nullable SerializableFunction<KV<Integer, Long>, Instant> timestampFn) {
@@ -309,16 +308,13 @@ public void processElement(ProcessContext ctx) throws Exception {
public void testUnboundedSourceSplits() throws Exception {
Pipeline p = TestPipeline.create();
int numElements = 1000;
- int numSplits = 10;
+ int numSplits = 20;
UnboundedSource<KafkaRecord<Integer, Long>, ?> initial =
mkKafkaReadTransform(numElements, null).makeSource();
- List<
- ? extends
- UnboundedSource<
- com.google.cloud.dataflow.contrib.kafka.KafkaRecord<Integer, Long>, ?>>
- splits = initial.generateInitialSplits(numSplits, p.getOptions());
- assertEquals("Expected exact splitting", numSplits, splits.size());
+ List<? extends UnboundedSource<KafkaRecord<Integer, Long>, ?>> splits =
+ initial.generateInitialSplits(1, p.getOptions());
+ assertEquals("KafkaIO should ignore desiredNumSplits", numSplits, splits.size());
long elementsPerSplit = numElements / numSplits;
assertEquals("Expected even splits", numElements, elementsPerSplit * numSplits);
@@ -368,9 +364,7 @@ public void testUnboundedSourceCheckpointMark() throws Exception {
com.google.cloud.dataflow.contrib.kafka.KafkaCheckpointMark>
source =
mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
- .makeSource()
- .generateInitialSplits(1, PipelineOptionsFactory.fromArgs(new String[0]).create())
- .get(0);
+ .makeSource();
UnboundedReader<KafkaRecord<Integer, Long>> reader =
source.createReader(null, null);
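The updated assertion in testUnboundedSourceSplits follows directly from the one-split-per-partition rule: the test fixture has two topics with 10 partitions each and 1000 elements, so the split count is fixed by the topic layout even when desiredNumSplits is 1. The snippet below is hypothetical arithmetic mirroring that fixture, not code from the patch.

public class SplitCountSketch {
  public static void main(String[] args) {
    int numTopics = 2;
    int partitionsPerTopic = 10;
    int numElements = 1000;

    // One split per partition: the desiredNumSplits hint (1 in the test) is ignored.
    int expectedSplits = numTopics * partitionsPerTopic;   // 20
    int elementsPerSplit = numElements / expectedSplits;   // 50, so the data divides evenly

    System.out.println(expectedSplits + " splits, " + elementsPerSplit + " elements each");
  }
}

Similarly, testUnboundedSourceCheckpointMark no longer calls generateInitialSplits(): createReader() on an unsplit source now falls back to a single UnboundedKafkaSource covering all fetched partitions, so reading the source directly exercises the same path.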