Skip to content

Commit

Permalink
Do not add source messages to a checkpoint until after it is emitted (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-kwong authored May 29, 2024
1 parent bf07d37 commit 116046e
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
import com.google.pubsub.flink.internal.source.enumerator.PubSubCheckpointSerializer;
import com.google.pubsub.flink.internal.source.enumerator.PubSubSplitEnumerator;
import com.google.pubsub.flink.internal.source.reader.AckTracker;
import com.google.pubsub.flink.internal.source.reader.PubSubAckTracker;
import com.google.pubsub.flink.internal.source.reader.PubSubNotifyingPullSubscriber;
import com.google.pubsub.flink.internal.source.reader.PubSubRecordEmitter;
import com.google.pubsub.flink.internal.source.reader.PubSubSourceReader;
import com.google.pubsub.flink.internal.source.reader.PubSubSplitReader;
import com.google.pubsub.flink.internal.source.split.SubscriptionSplit;
Expand Down Expand Up @@ -156,7 +156,8 @@ public UserCodeClassLoader getUserCodeClassLoader() {
}
});
return new PubSubSourceReader<>(
new PubSubRecordEmitter<>(schema),
schema,
new PubSubAckTracker(),
this::createSplitReader,
new Configuration(),
readerContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,38 @@

import com.google.cloud.pubsub.v1.AckReplyConsumer;

/** This class tracks the lifecycle of messages in {@link PubSubSource}. */
public interface AckTracker {
void addPendingAck(AckReplyConsumer ackReplyConsumer);
/**
* Track a new pending ack. Acks are pending when a message has been received but not yet
* processed by the Flink pipeline.
*
* <p>If there is already a pending ack for {@code messageId}, the existing ack is replaced.
*/
void addPendingAck(String messageId, AckReplyConsumer ackReplyConsumer);

/**
* Stage a pending ack for the next checkpoint snapshot. Staged acks indicate that a message has
* been emitted to the Flink pipeline and should be included in the next checkpoint.
*/
void stagePendingAck(String messageId);

/**
* Prepare all staged acks to be acknowledged to Google Cloud Pub/Sub when checkpoint {@code
* checkpointId} completes.
*/
void addCheckpoint(long checkpointId);

/**
* Acknowledge all staged acks in checkpoint {@code checkpointId} and stop tracking them in this
* {@link AckTracker}.
*/
void notifyCheckpointComplete(long checkpointId);

/**
* Negatively acknowledge (nack) and stop tracking all acks currently tracked by this {@link
* AckTracker}. Nacked messages are eligible for redelivery by Google Cloud Pub/Sub before the
* message's ack deadline expires.
*/
void nackAll();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,39 @@
import com.google.cloud.pubsub.v1.AckReplyConsumer;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

public class PubSubAckTracker implements AckTracker {
@GuardedBy("this")
private final List<AckReplyConsumer> pendingAcks = new ArrayList<>();
private final Map<String, AckReplyConsumer> pendingAcks = new HashMap<>();

@GuardedBy("this")
private final List<AckReplyConsumer> stagedAcks = new ArrayList<>();

@GuardedBy("this")
private final SortedMap<Long, List<AckReplyConsumer>> checkpoints = new TreeMap<>();

@Override
public synchronized void addPendingAck(AckReplyConsumer ackReplyConsumer) {
pendingAcks.add(ackReplyConsumer);
public synchronized void addPendingAck(String messageId, AckReplyConsumer ackReplyConsumer) {
pendingAcks.put(messageId, ackReplyConsumer);
}

@Override
public synchronized void stagePendingAck(String messageId) {
AckReplyConsumer ackToStage = pendingAcks.remove(messageId);
if (ackToStage != null) {
stagedAcks.add(ackToStage);
}
}

@Override
public synchronized void addCheckpoint(long checkpointId) {
checkpoints.put(checkpointId, new ArrayList<>(pendingAcks));
pendingAcks.clear();
checkpoints.put(checkpointId, new ArrayList<>(stagedAcks));
stagedAcks.clear();
}

@Override
Expand All @@ -51,8 +64,10 @@ public synchronized void notifyCheckpointComplete(long checkpointId) {

@Override
public synchronized void nackAll() {
pendingAcks.forEach((ackReplyConsumer) -> ackReplyConsumer.nack());
pendingAcks.values().forEach((ackReplyConsumer) -> ackReplyConsumer.nack());
pendingAcks.clear();
stagedAcks.forEach((ackReplyConsumer) -> ackReplyConsumer.nack());
stagedAcks.clear();
checkpoints
.values()
.forEach(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import com.google.pubsub.v1.PubsubMessage;
import java.util.ArrayDeque;
import java.util.Deque;
import org.apache.flink.api.java.tuple.Tuple2;

public class PubSubNotifyingPullSubscriber implements NotifyingPullSubscriber {
public static class SubscriberWakeupException extends Exception {}
Expand All @@ -50,7 +49,7 @@ public interface SubscriberFactory {
private Optional<SettableApiFuture<Void>> notification = Optional.absent();

@GuardedBy("this")
private final Deque<Tuple2<PubsubMessage, AckReplyConsumer>> messages = new ArrayDeque<>();
private final Deque<PubsubMessage> messages = new ArrayDeque<>();

private final AckTracker ackTracker;

Expand Down Expand Up @@ -91,9 +90,7 @@ public synchronized Optional<PubsubMessage> pullMessage() throws Throwable {
if (messages.isEmpty()) {
return Optional.absent();
}
Tuple2<PubsubMessage, AckReplyConsumer> message = messages.pop();
ackTracker.addPendingAck(message.f1);
return Optional.of(message.f0);
return Optional.of(messages.pop());
}

@Override
Expand All @@ -105,9 +102,8 @@ public void interruptNotify() {
public void shutdown() {
setPermanentError(new SubscriberShutdownException());
completeNotification(permanentError);
// Nack all outstanding messages, so that they are redelivered quickly.
messages.forEach((tuple) -> tuple.f1.nack());
messages.clear();
// Nack all outstanding messages, so that they are redelivered quickly.
ackTracker.nackAll();
subscriber.stopAsync().awaitTerminated();
}
Expand All @@ -119,7 +115,8 @@ private synchronized void receiveMessage(
completeNotification(permanentError);
return;
}
messages.add(Tuple2.of(message, ackReplyConsumer));
ackTracker.addPendingAck(message.getMessageId(), ackReplyConsumer);
messages.add(message);
completeNotification(Optional.absent());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@
public class PubSubRecordEmitter<T>
implements RecordEmitter<PubsubMessage, T, SubscriptionSplitState> {
private final PubSubDeserializationSchema<T> deserializationSchema;
private final AckTracker ackTracker;

public PubSubRecordEmitter(PubSubDeserializationSchema<T> deserializationSchema) {
public PubSubRecordEmitter(
PubSubDeserializationSchema<T> deserializationSchema, AckTracker ackTracker) {
this.deserializationSchema = deserializationSchema;
this.ackTracker = ackTracker;
}

@Override
Expand All @@ -39,6 +42,7 @@ public void emitRecord(
sourceOutput.collect(
deserializationSchema.deserialize(message),
Timestamps.toMillis(message.getPublishTime()));
ackTracker.stagePendingAck(message.getMessageId());
} catch (Exception e) {
throw new IOException("Failed to deserialize PubsubMessage", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.google.pubsub.flink.internal.source.reader;

import com.google.common.annotations.VisibleForTesting;
import com.google.pubsub.flink.PubSubDeserializationSchema;
import com.google.pubsub.flink.internal.source.split.SubscriptionSplit;
import com.google.pubsub.flink.internal.source.split.SubscriptionSplitState;
import com.google.pubsub.v1.PubsubMessage;
Expand All @@ -36,23 +37,18 @@ public interface SplitReaderFactory {

private final AckTracker ackTracker;

@VisibleForTesting
PubSubSourceReader(
RecordEmitter<PubsubMessage, T, SubscriptionSplitState> recordEmitter,
SplitReaderFactory splitReaderFactory,
Configuration config,
SourceReaderContext context,
AckTracker ackTracker) {
super(() -> splitReaderFactory.create(ackTracker), recordEmitter, config, context);
this.ackTracker = ackTracker;
}

public PubSubSourceReader(
RecordEmitter<PubsubMessage, T, SubscriptionSplitState> recordEmitter,
PubSubDeserializationSchema<T> schema,
AckTracker ackTracker,
SplitReaderFactory splitReaderFactory,
Configuration config,
SourceReaderContext context) {
this(recordEmitter, splitReaderFactory, config, context, new PubSubAckTracker());
super(
() -> splitReaderFactory.create(ackTracker),
new PubSubRecordEmitter<>(schema, ackTracker),
config,
context);
this.ackTracker = ackTracker;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.google.pubsub.flink.internal.source.reader;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.cloud.pubsub.v1.AckReplyConsumer;
Expand All @@ -36,18 +37,35 @@ public void doBeforeEachTest() {
@Test
public void singleAck_ackedOnCheckpoint() throws Exception {
AckReplyConsumer mockAck = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck);
ackTracker.addPendingAck("message-id", mockAck);
ackTracker.stagePendingAck("message-id");
ackTracker.addCheckpoint(1L);
ackTracker.notifyCheckpointComplete(1L);
verify(mockAck).ack();
}

@Test
public void singleAck_ackLatestMessageDelivery() throws Exception {
AckReplyConsumer mockAck1 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck("message-id", mockAck1);
AckReplyConsumer mockAck2 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck("message-id", mockAck2);

ackTracker.stagePendingAck("message-id");
ackTracker.addCheckpoint(1L);
ackTracker.notifyCheckpointComplete(1L);
verify(mockAck1, times(0)).ack();
verify(mockAck2, times(1)).ack();
}

@Test
public void manyAcks_ackedOnCheckpoint() throws Exception {
AckReplyConsumer mockAck1 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck1);
ackTracker.addPendingAck("message1-id", mockAck1);
ackTracker.stagePendingAck("message1-id");
AckReplyConsumer mockAck2 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck2);
ackTracker.addPendingAck("message2-id", mockAck2);
ackTracker.stagePendingAck("message2-id");
ackTracker.addCheckpoint(1L);
ackTracker.notifyCheckpointComplete(1L);
verify(mockAck1).ack();
Expand All @@ -57,28 +75,34 @@ public void manyAcks_ackedOnCheckpoint() throws Exception {
@Test
public void manyCheckpoints_completedOneByOne() throws Exception {
AckReplyConsumer mockAck1 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck1);
ackTracker.addPendingAck("message1-id", mockAck1);
ackTracker.stagePendingAck("message1-id");
ackTracker.addCheckpoint(1L);

AckReplyConsumer mockAck2 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck2);
ackTracker.addPendingAck("message2-id", mockAck2);
ackTracker.stagePendingAck("message2-id");
ackTracker.addCheckpoint(2L);

ackTracker.notifyCheckpointComplete(1L);
verify(mockAck1).ack();
verify(mockAck1, times(1)).ack();
verify(mockAck2, times(0)).ack();

ackTracker.notifyCheckpointComplete(2L);
verify(mockAck2).ack();
verify(mockAck1, times(1)).ack();
verify(mockAck2, times(1)).ack();
}

@Test
public void manyCheckpoints_completedTogether() throws Exception {
AckReplyConsumer mockAck1 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck1);
ackTracker.addPendingAck("message1-id", mockAck1);
ackTracker.stagePendingAck("message1-id");
ackTracker.addCheckpoint(1L);

AckReplyConsumer mockAck2 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck2);
ackTracker.addPendingAck("message2-id", mockAck2);
ackTracker.stagePendingAck("message2-id");
ackTracker.addCheckpoint(2L);

ackTracker.notifyCheckpointComplete(2L);
Expand All @@ -87,16 +111,24 @@ public void manyCheckpoints_completedTogether() throws Exception {
}

@Test
public void nackAll_pendingAndIncompleteCheckpointAcks() throws Exception {
public void nackAll_pendingStagedAndIncompleteCheckpointAcks() throws Exception {
// mockAck1 is added to checkpoint 1, which never completes.
AckReplyConsumer mockAck1 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck1);
ackTracker.addPendingAck("message1-id", mockAck1);
ackTracker.stagePendingAck("message1-id");
ackTracker.addCheckpoint(1L);
// mockAck2 is staged.
AckReplyConsumer mockAck2 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck(mockAck2);
ackTracker.addPendingAck("message2-id", mockAck2);
ackTracker.stagePendingAck("message2-id");
// mockAck3 is pending.
AckReplyConsumer mockAck3 = mock(AckReplyConsumer.class);
ackTracker.addPendingAck("message3-id", mockAck3);

ackTracker.nackAll();

verify(mockAck1).nack();
verify(mockAck2).nack();
verify(mockAck3).nack();
}
}
Loading

0 comments on commit 116046e

Please sign in to comment.