Commit: Pull ack batching out of the source task

dpcollins-google committed Mar 29, 2021
1 parent 1859512 commit c2eb727

Showing 3 changed files with 132 additions and 80 deletions.
AckBatchingSubscriber.java (new file)
@@ -0,0 +1,79 @@
package com.google.pubsub.kafka.source;

import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutureCallback;
import com.google.api.core.ApiFutures;
import com.google.api.core.SettableApiFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.protobuf.Empty;
import com.google.pubsub.v1.ReceivedMessage;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.tuple.Pair;

/**
* A {@link CloudPubSubSubscriber} wrapper that queues acked ids and flushes them to the
* underlying subscriber in periodic batches, completing each caller's future when its batch
* resolves.
*/
public class AckBatchingSubscriber implements CloudPubSubSubscriber {

private final CloudPubSubSubscriber underlying;
@GuardedBy("this")
private final Deque<Pair<Collection<String>, SettableApiFuture<Empty>>> toSend = new ArrayDeque<>();
private final ScheduledFuture<?> alarm;

public AckBatchingSubscriber(CloudPubSubSubscriber underlying,
ScheduledExecutorService executor) {
this.underlying = underlying;
this.alarm = executor.scheduleAtFixedRate(this::flush, 100, 100, TimeUnit.MILLISECONDS);
}

@Override
public ApiFuture<List<ReceivedMessage>> pull() {
return underlying.pull();
}

@Override
public synchronized ApiFuture<Empty> ackMessages(Collection<String> ackIds) {
SettableApiFuture<Empty> result = SettableApiFuture.create();
toSend.add(Pair.of(ackIds, result));
return result;
}

private void flush() {
List<String> ackIds = new ArrayList<>();
List<SettableApiFuture<Empty>> futures = new ArrayList<>();
synchronized (this) {
if (toSend.isEmpty()) {
return;
}
toSend.forEach(pair -> {
ackIds.addAll(pair.getLeft());
futures.add(pair.getRight());
});
toSend.clear();
}
ApiFuture<Empty> response = underlying.ackMessages(ackIds);
ApiFutures.addCallback(response, new ApiFutureCallback<Empty>() {
@Override
public void onFailure(Throwable t) {
futures.forEach(future -> future.setException(t));
}

@Override
public void onSuccess(Empty result) {
futures.forEach(future -> future.set(result));
}
}, MoreExecutors.directExecutor());
}

@Override
public void close() {
alarm.cancel(false);
flush();
underlying.close();
}
}
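
For orientation, a minimal usage sketch of the new class (the rawSubscriber variable and the single-thread executor below are illustrative assumptions; the actual wiring into the connector appears in the next file):

    // Sketch only: wrap an existing subscriber so individual ackMessages() calls
    // are coalesced into one underlying RPC per 100ms flush.
    ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    CloudPubSubSubscriber batching = new AckBatchingSubscriber(rawSubscriber, executor);

    // Each caller gets its own future; all futures in a flushed batch complete together.
    ApiFuture<Empty> done = batching.ackMessages(ImmutableList.of("ack-id-1"));

    // close() cancels the periodic alarm, flushes anything still queued, and
    // closes the underlying subscriber.
    batching.close();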
CloudPubSubSourceTask.java
@@ -15,22 +15,20 @@
////////////////////////////////////////////////////////////////////////////////
package com.google.pubsub.kafka.source;

import com.google.api.core.ApiFuture;
import com.google.api.gax.batching.FlowControlSettings;
import com.google.api.gax.batching.FlowController.LimitExceededBehavior;
import com.google.cloud.pubsub.v1.Subscriber;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import com.google.protobuf.util.Timestamps;
import com.google.pubsub.kafka.common.ConnectorUtils;
import com.google.pubsub.kafka.common.ConnectorCredentialsProvider;
import com.google.pubsub.kafka.source.CloudPubSubSourceConnector.PartitionScheme;
import com.google.pubsub.v1.ProjectSubscriptionName;
import com.google.pubsub.v1.PubsubMessage;
import com.google.pubsub.v1.ReceivedMessage;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
@@ -39,6 +37,8 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import org.apache.kafka.connect.data.Field;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
@@ -51,13 +51,15 @@

/**
* A {@link SourceTask} used by a {@link CloudPubSubSourceConnector} to write messages to <a
* href="http://kafka.apache.org/">Apache Kafka</a>. Due to at-last-once semantics in Google
* Cloud Pub/Sub duplicates in Kafka are possible.
* href="http://kafka.apache.org/">Apache Kafka</a>. Due to at-last-once semantics in Google Cloud
* Pub/Sub duplicates in Kafka are possible.
*/
public class CloudPubSubSourceTask extends SourceTask {

private static final Logger log = LoggerFactory.getLogger(CloudPubSubSourceTask.class);
private static final int NUM_CPS_SUBSCRIBERS = 10;
private static final ScheduledExecutorService ACK_EXECUTOR =
MoreExecutors.getExitingScheduledExecutorService(new ScheduledThreadPoolExecutor(4));

private String kafkaTopic;
private ProjectSubscriptionName cpsSubscription;
@@ -68,15 +70,12 @@ public class CloudPubSubSourceTask extends SourceTask {
private PartitionScheme kafkaPartitionScheme;
// Keeps track of the current partition to publish to if the partition scheme is round robin.
private int currentRoundRobinPartition = -1;
// Keep track of all ack ids that have not been correctly acked yet.
private final Set<String> deliveredAckIds = Collections.synchronizedSet(new HashSet<String>());
private CloudPubSubSubscriber subscriber;
private final Set<String> ackIdsInFlight = Collections.synchronizedSet(new HashSet<String>());
private final Set<String> standardAttributes = new HashSet<>();
private boolean useKafkaHeaders;
private final Executor ackExecutor = Executors.newCachedThreadPool();

public CloudPubSubSourceTask() {}
public CloudPubSubSourceTask() {
}

@VisibleForTesting
public CloudPubSubSourceTask(CloudPubSubSubscriber subscriber) {
@@ -90,8 +89,6 @@ public String version() {

@Override
public void start(Map<String, String> props) {
deliveredAckIds.clear();
ackIdsInFlight.clear();
Map<String, Object> validatedProps = new CloudPubSubSourceConnector().config().parse(props);
cpsSubscription = ProjectSubscriptionName.newBuilder()
.setProject(validatedProps.get(ConnectorUtils.CPS_PROJECT_CONFIG).toString())
@@ -152,9 +149,10 @@ public void start(Map<String, String> props) {
.setEndpoint(cpsEndpoint)
.build());
} else {
subscriber = new CloudPubSubRoundRobinSubscriber(NUM_CPS_SUBSCRIBERS,
gcpCredentialsProvider,
cpsEndpoint, cpsSubscription, cpsMaxBatchSize);
subscriber = new AckBatchingSubscriber(
new CloudPubSubRoundRobinSubscriber(NUM_CPS_SUBSCRIBERS,
gcpCredentialsProvider,
cpsEndpoint, cpsSubscription, cpsMaxBatchSize), ACK_EXECUTOR);
}
}
standardAttributes.add(kafkaMessageKeyAttribute);
@@ -164,7 +162,6 @@

@Override
public List<SourceRecord> poll() throws InterruptedException {
ackMessages();
log.debug("Polling...");
try {
List<ReceivedMessage> response = subscriber.pull().get();
@@ -182,7 +179,7 @@ public List<SourceRecord> poll() throws InterruptedException {
key = messageAttributes.get(kafkaMessageKeyAttribute);
}
Long timestamp = getLongValue(messageAttributes.get(kafkaMessageTimestampAttribute));
if (timestamp == null){
if (timestamp == null) {
timestamp = Timestamps.toMillis(message.getPublishTime());
}
ByteString messageData = message.getData();
@@ -205,16 +202,16 @@ record =
}
} else {
record =
new SourceRecord(
null,
ack,
kafkaTopic,
selectPartition(key, messageBytes, orderingKey),
Schema.OPTIONAL_STRING_SCHEMA,
key,
Schema.BYTES_SCHEMA,
messageBytes,
timestamp);
new SourceRecord(
null,
ack,
kafkaTopic,
selectPartition(key, messageBytes, orderingKey),
Schema.OPTIONAL_STRING_SCHEMA,
key,
Schema.BYTES_SCHEMA,
messageBytes,
timestamp);
}
sourceRecords.add(record);
}
@@ -234,7 +231,7 @@ private SourceRecord createRecordWithHeaders(
Long timestamp) {
ConnectHeaders headers = new ConnectHeaders();
for (Entry<String, String> attribute :
messageAttributes.entrySet()) {
messageAttributes.entrySet()) {
if (!attribute.getKey().equals(kafkaMessageKeyAttribute)) {
headers.addString(attribute.getKey(), attribute.getValue());
}
@@ -244,16 +241,16 @@
}

return new SourceRecord(
null,
ack,
kafkaTopic,
selectPartition(key, messageBytes, orderingKey),
Schema.OPTIONAL_STRING_SCHEMA,
key,
Schema.BYTES_SCHEMA,
messageBytes,
timestamp,
headers);
null,
ack,
kafkaTopic,
selectPartition(key, messageBytes, orderingKey),
Schema.OPTIONAL_STRING_SCHEMA,
key,
Schema.BYTES_SCHEMA,
messageBytes,
timestamp,
headers);
}

private SourceRecord createRecordWithStruct(
@@ -292,49 +289,20 @@ private SourceRecord createRecordWithStruct(
}
}
return new SourceRecord(
null,
ack,
kafkaTopic,
selectPartition(key, value, orderingKey),
Schema.OPTIONAL_STRING_SCHEMA,
key,
valueSchema,
value,
timestamp);
}

@Override
public void commit() throws InterruptedException {
ackMessages();
null,
ack,
kafkaTopic,
selectPartition(key, value, orderingKey),
Schema.OPTIONAL_STRING_SCHEMA,
key,
valueSchema,
value,
timestamp);
}

/**
* Attempt to ack all ids in {@link #deliveredAckIds}.
* Return the partition a message should go to based on {@link #kafkaPartitionScheme}.
*/
private void ackMessages() {
if (deliveredAckIds.size() != 0) {
final Set<String> ackIdsBatch = new HashSet<>();
synchronized (deliveredAckIds) {
ackIdsInFlight.addAll(deliveredAckIds);
ackIdsBatch.addAll(deliveredAckIds);
deliveredAckIds.clear();
}
final ApiFuture<Empty> response = subscriber.ackMessages(ackIdsBatch);
response.addListener(() -> {
try {
response.get();
log.trace("Successfully acked a set of messages. {}", ackIdsBatch.size());
} catch (Exception e) {
deliveredAckIds.addAll(ackIdsBatch);
log.error("An exception occurred acking messages: " + e);
} finally {
ackIdsInFlight.removeAll(ackIdsBatch);
}
}, ackExecutor);
}
}

/** Return the partition a message should go to based on {@link #kafkaPartitionScheme}. */
private Integer selectPartition(Object key, Object value, String orderingKey) {
if (kafkaPartitionScheme.equals(PartitionScheme.HASH_KEY)) {
return key == null ? 0 : Math.abs(key.hashCode()) % kafkaPartitions;
@@ -343,7 +311,7 @@ private Integer selectPartition(Object key, Object value, String orderingKey) {
} else if (kafkaPartitionScheme.equals(PartitionScheme.KAFKA_PARTITIONER)) {
return null;
} else if (kafkaPartitionScheme.equals(PartitionScheme.ORDERING_KEY) && orderingKey != null &&
!orderingKey.isEmpty()) {
!orderingKey.isEmpty()) {
return Math.abs(orderingKey.hashCode()) % kafkaPartitions;
} else {
currentRoundRobinPartition = ++currentRoundRobinPartition % kafkaPartitions;
@@ -373,7 +341,7 @@ public void stop() {
@Override
public void commitRecord(SourceRecord record) {
String ackId = record.sourceOffset().get(cpsSubscription.toString()).toString();
deliveredAckIds.add(ackId);
subscriber.ackMessages(ImmutableList.of(ackId));
log.trace("Committed {}", ackId);
}
}
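
Taken together, the changes in this file drop the task's own ack bookkeeping (deliveredAckIds, ackIdsInFlight, and the ackMessages() sweep formerly run from poll() and commit()) in favor of acking each record as it is committed. The new path, condensed from the diff above, relies on the AckBatchingSubscriber wrapper to turn these one-id calls into periodic bulk acks:

    @Override
    public void commitRecord(SourceRecord record) {
      // One ack id per committed record; batching now happens in the subscriber.
      String ackId = record.sourceOffset().get(cpsSubscription.toString()).toString();
      subscriber.ackMessages(ImmutableList.of(ackId));
      log.trace("Committed {}", ackId);
    }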
AckBatchingSubscriberTest.java (new file)
@@ -0,0 +1,5 @@
package com.google.pubsub.kafka.source;

public class AckBatchingSubscriberTest {

}
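
The test class is added as an empty stub. As a rough sketch of what it might eventually cover, assuming the CloudPubSubSubscriber interface consists of just the pull(), ackMessages(), and close() methods exercised above (the FakeSubscriber helper and JUnit wiring are illustrative assumptions, not part of the commit):

    package com.google.pubsub.kafka.source;

    import static org.junit.Assert.assertTrue;

    import com.google.api.core.ApiFuture;
    import com.google.api.core.ApiFutures;
    import com.google.common.collect.ImmutableList;
    import com.google.protobuf.Empty;
    import com.google.pubsub.v1.ReceivedMessage;
    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.List;
    import java.util.concurrent.Executors;
    import java.util.stream.Collectors;
    import org.junit.Test;

    public class AckBatchingSubscriberTest {

      // Hypothetical fake recording which ack ids reach the underlying subscriber.
      static class FakeSubscriber implements CloudPubSubSubscriber {
        final List<Collection<String>> ackCalls = new ArrayList<>();

        @Override
        public ApiFuture<List<ReceivedMessage>> pull() {
          return ApiFutures.immediateFuture(ImmutableList.of());
        }

        @Override
        public ApiFuture<Empty> ackMessages(Collection<String> ackIds) {
          ackCalls.add(ackIds);
          return ApiFutures.immediateFuture(Empty.getDefaultInstance());
        }

        @Override
        public void close() {}
      }

      @Test
      public void closeFlushesPendingAcks() throws Exception {
        FakeSubscriber fake = new FakeSubscriber();
        AckBatchingSubscriber subscriber =
            new AckBatchingSubscriber(fake, Executors.newSingleThreadScheduledExecutor());
        ApiFuture<Empty> first = subscriber.ackMessages(ImmutableList.of("a"));
        ApiFuture<Empty> second = subscriber.ackMessages(ImmutableList.of("b"));
        subscriber.close();
        // close() flushes anything still queued, so both futures complete and both
        // ids reach the fake (normally in a single batch, unless a 100ms alarm
        // tick intervened).
        assertTrue(first.isDone());
        assertTrue(second.isDone());
        assertTrue(fake.ackCalls.stream().flatMap(Collection::stream)
            .collect(Collectors.toSet()).containsAll(ImmutableList.of("a", "b")));
      }
    }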
