Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure that channel pool ref count never goes negative (take2) #2065

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
* <p>Package-private for internal use.
*/
class ChannelPool extends ManagedChannel {
private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
@VisibleForTesting static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50);

private final ChannelPoolSettings settings;
Expand Down Expand Up @@ -381,14 +381,14 @@ void refresh() {
* Get and retain a Channel Entry. The returned Entry will have its rpc count incremented,
* preventing it from getting recycled.
*/
Entry getRetainedEntry(int affinity) {
RetainedEntry getRetainedEntry(int affinity) {
// The maximum number of concurrent calls to this method for any given time span is at most 2,
// so the loop can actually be 2 times. But going for 5 times for a safety margin for potential
// code evolving
for (int i = 0; i < 5; i++) {
Entry entry = getEntry(affinity);
if (entry.retain()) {
return entry;
return new RetainedEntry(entry);
}
}
// It is unlikely to reach here unless the pool code evolves to increase the maximum possible
Expand All @@ -415,10 +415,37 @@ private Entry getEntry(int affinity) {
return localEntries.get(index);
}

/**
* This represents the reserved refcount of a single RPC using a channel. It the responsibility of
* that RPC to call release exactly once when it completes to release the Channel.
*/
private static class RetainedEntry {
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
private final Entry entry;
private final AtomicBoolean wasReleased;

public RetainedEntry(Entry entry) {
this.entry = entry;
wasReleased = new AtomicBoolean(false);
}

void release() {
if (!wasReleased.compareAndSet(false, true)) {
Exception e = new IllegalStateException("Entry was already released");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another stacktrace we don't want to log, can we just log the WARNING without a stacktrace?
Or if we think this is a true error, then we can throw the Exception?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We definitely need a stacktrace here, it will tell us how the refcount got negative. Otherwise we are no wiser than before.....we are just hiding the problem under the rug. And in this case alert firing woud be correct

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me Entry was already released is a true error, why are we just logging instead of throwing an exception?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because throwing will bubble up to the enduser, who had nothing to do with the error. Instead I want to keep it localized to the double release and protect the rest of the application

LOG.log(Level.WARNING, e.getMessage(), e);
return;
}
entry.release();
igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
}

public Channel getChannel() {
return entry.channel;
}
}

/** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */
private static class Entry {
static class Entry {
private final ManagedChannel channel;
private final AtomicInteger outstandingRpcs = new AtomicInteger(0);
final AtomicInteger outstandingRpcs = new AtomicInteger(0);
private final AtomicInteger maxOutstanding = new AtomicInteger();

// Flag that the channel should be closed once all of the outstanding RPC complete.
Expand Down Expand Up @@ -511,18 +538,19 @@ public String authority() {
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {

Entry entry = getRetainedEntry(affinity);
RetainedEntry entry = getRetainedEntry(affinity);

return new ReleasingClientCall<>(entry.channel.newCall(methodDescriptor, callOptions), entry);
return new ReleasingClientCall<>(
entry.getChannel().newCall(methodDescriptor, callOptions), entry);
}
}

/** ClientCall wrapper that makes sure to decrement the outstanding RPC count on completion. */
static class ReleasingClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
@Nullable private CancellationException cancellationException;
final Entry entry;
final RetainedEntry entry;

public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, Entry entry) {
public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, RetainedEntry entry) {
super(delegate);
this.entry = entry;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
*/
package com.google.api.gax.grpc;

import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_RECOGNIZE;
import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE;
import static com.google.common.truth.Truth.assertThat;

import com.google.api.core.ApiFuture;
import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
Expand All @@ -40,6 +42,8 @@
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand All @@ -63,6 +67,9 @@
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Handler;
import java.util.logging.LogRecord;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -663,4 +670,72 @@ public void onComplete() {}
assertThat(e.getCause()).isInstanceOf(CancellationException.class);
assertThat(e.getMessage()).isEqualTo("Call is already cancelled");
}

@Test
public void testDoubleRelease() throws Exception {
FakeLogHandler logHandler = new FakeLogHandler();
ChannelPool.LOG.addHandler(logHandler);

try {
// Create a fake channel pool thats backed by mock channels that simply record invocations
ClientCall mockClientCall = Mockito.mock(ClientCall.class);
ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));

pool = ChannelPool.create(channelPoolSettings, factory);

// Construct a fake callable to use the channel pool
ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(pool))
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
.build();

UnaryCallSettings<Color, Money> settings =
UnaryCallSettings.<Color, Money>newUnaryCallSettingsBuilder().build();
UnaryCallable<Color, Money> callable =
GrpcCallableFactory.createUnaryCallable(
GrpcCallSettings.create(METHOD_RECOGNIZE), settings, context);

// Start the RPC
ApiFuture<Money> rpcFuture =
callable.futureCall(Color.getDefaultInstance(), context.getDefaultCallContext());

// Get the server side listener and intentionally close it twice
ArgumentCaptor<ClientCall.Listener<?>> clientCallListenerCaptor =
ArgumentCaptor.forClass(ClientCall.Listener.class);
Mockito.verify(mockClientCall).start(clientCallListenerCaptor.capture(), Mockito.any());
clientCallListenerCaptor.getValue().onClose(Status.INTERNAL, new Metadata());
clientCallListenerCaptor.getValue().onClose(Status.UNKNOWN, new Metadata());

// Ensure that the channel pool properly logged the double call and kept the refCount correct
assertThat(logHandler.getAllMessages()).contains("Entry was already released");
assertThat(pool.entries.get()).hasSize(1);
ChannelPool.Entry entry = pool.entries.get().get(0);
assertThat(entry.outstandingRpcs.get()).isEqualTo(0);
} finally {
ChannelPool.LOG.removeHandler(logHandler);
}
}

private static class FakeLogHandler extends Handler {
List<LogRecord> records = new ArrayList<>();

@Override
public void publish(LogRecord record) {
records.add(record);
}

@Override
public void flush() {}

@Override
public void close() throws SecurityException {}

List<String> getAllMessages() {
return records.stream().map(LogRecord::getMessage).collect(Collectors.toList());
}
}
}
Loading