Skip to content

Commit

Permalink
fix: Handle cancel in ReleasingClientCall and rethrow the exception i…
Browse files Browse the repository at this point in the history
…n start (#1221)

* fix: Handle cancel in ReleasingClientCall and rethrow the exception in start

* address comments
  • Loading branch information
mutianf authored Jan 11, 2023
1 parent 84a1355 commit 8a61249
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand All @@ -53,6 +54,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import org.threeten.bp.Duration;

/**
Expand Down Expand Up @@ -517,6 +519,7 @@ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(

/** ClientCall wrapper that makes sure to decrement the outstanding RPC count on completion. */
static class ReleasingClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
@Nullable private CancellationException cancellationException;
final Entry entry;

public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, Entry entry) {
Expand All @@ -526,6 +529,9 @@ public ReleasingClientCall(ClientCall<ReqT, RespT> delegate, Entry entry) {

@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
if (cancellationException != null) {
throw new IllegalStateException("Call is already cancelled", cancellationException);
}
try {
super.start(
new SimpleForwardingClientCallListener<RespT>(responseListener) {
Expand All @@ -542,7 +548,14 @@ public void onClose(Status status, Metadata trailers) {
} catch (Exception e) {
// In case start failed, make sure to release
entry.release();
throw e;
}
}

@Override
public void cancel(@Nullable String message, @Nullable Throwable cause) {
this.cancellationException = new CancellationException(message);
super.cancel(message, cause);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@
*/
package com.google.api.gax.grpc;

import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE;
import static com.google.common.truth.Truth.assertThat;

import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.type.Color;
Expand All @@ -49,12 +55,14 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -595,4 +603,50 @@ public void removedActiveChannelsAreShutdown() throws Exception {
// Now the channel should be closed
Mockito.verify(channels.get(1), Mockito.times(1)).shutdown();
}

@Test
public void testReleasingClientCallCancelEarly() throws IOException {
ClientCall mockClientCall = Mockito.mock(ClientCall.class);
Mockito.doAnswer(invocation -> null).when(mockClientCall).cancel(Mockito.any(), Mockito.any());
ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
ChannelPool channelPool = ChannelPool.create(channelPoolSettings, factory);
ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(channelPool))
.setDefaultCallContext(GrpcCallContext.of(channelPool, CallOptions.DEFAULT))
.build();
ServerStreamingCallSettings settings =
ServerStreamingCallSettings.<Color, Money>newBuilder().build();
ServerStreamingCallable streamingCallable =
GrpcCallableFactory.createServerStreamingCallable(
GrpcCallSettings.create(METHOD_SERVER_STREAMING_RECOGNIZE), settings, context);
Color request = Color.newBuilder().setRed(0.5f).build();

IllegalStateException e =
Assert.assertThrows(
IllegalStateException.class,
() ->
streamingCallable.call(
request,
new ResponseObserver() {
@Override
public void onStart(StreamController controller) {
controller.cancel();
}

@Override
public void onResponse(Object response) {}

@Override
public void onError(Throwable t) {}

@Override
public void onComplete() {}
}));
assertThat(e.getCause()).isInstanceOf(CancellationException.class);
assertThat(e.getMessage()).isEqualTo("Call is already cancelled");
}
}

0 comments on commit 8a61249

Please sign in to comment.