Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Validate the Universe Domain #2330

Merged
merged 14 commits into from
Jan 12, 2024

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ public static <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
channel = ClientInterceptors.intercept(channel, interceptor);
}

// Validate the Universe Domain prior to the call. Only allow the call to go through
// if the Universe Domain is valid.
grpcContext.validateUniverseDomain();

try (Scope ignored = grpcContext.getTracer().inScope()) {
return channel.newCall(descriptor, callOptions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
Expand Down Expand Up @@ -628,10 +629,15 @@ public void testReleasingClientCallCancelEarly() throws IOException {
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
pool = ChannelPool.create(channelPoolSettings, factory);

EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.when(endpointContext.hasValidUniverseDomain(Mockito.any())).thenReturn(true);

ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(pool))
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
.setDefaultCallContext(
GrpcCallContext.of(pool, CallOptions.DEFAULT).withEndpointContext(endpointContext))
.build();
ServerStreamingCallSettings settings =
ServerStreamingCallSettings.<Color, Money>newBuilder().build();
Expand Down Expand Up @@ -680,11 +686,17 @@ public void testDoubleRelease() throws Exception {

pool = ChannelPool.create(channelPoolSettings, factory);

EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.when(endpointContext.hasValidUniverseDomain(Mockito.any())).thenReturn(true);
Mockito.when(endpointContext.merge(Mockito.any())).thenReturn(endpointContext);

// Construct a fake callable to use the channel pool
ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(pool))
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
.setDefaultCallContext(
GrpcCallContext.of(pool, CallOptions.DEFAULT)
.withEndpointContext(endpointContext))
.build();

UnaryCallSettings<Color, Money> settings =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.StatusCode;
import com.google.api.gax.rpc.testing.FakeCallContext;
import com.google.api.gax.rpc.testing.FakeChannel;
Expand All @@ -46,6 +47,7 @@
import io.grpc.CallOptions;
import io.grpc.ManagedChannel;
import io.grpc.Metadata.Key;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -373,7 +375,14 @@ public void testWithOptions() {
}

@Test
public void testMergeOptions() {
public void testEndpointContext() throws IOException {
EndpointContext endpointContext = EndpointContext.newBuilder().setServiceName("test").build();
GrpcCallContext context = GrpcCallContext.createDefault().withEndpointContext(endpointContext);
assertEquals(context.getEndpointContext(), endpointContext);
}

@Test
public void testMergeOptions() throws IOException {
GrpcCallContext emptyCallContext = GrpcCallContext.createDefault();
ApiCallContext.Key<String> contextKey1 = ApiCallContext.Key.create("testKey1");
ApiCallContext.Key<String> contextKey2 = ApiCallContext.Key.create("testKey2");
Expand All @@ -391,9 +400,13 @@ public void testMergeOptions() {
.withOption(contextKey1, testContextOverwrite)
.withOption(contextKey3, testContext3);
ApiCallContext mergedContext = context1.merge(context2);
EndpointContext endpointContext = EndpointContext.newBuilder().setServiceName("test").build();
ApiCallContext context3 = emptyCallContext.withEndpointContext(endpointContext);
mergedContext = mergedContext.merge(context3);
assertEquals(testContextOverwrite, mergedContext.getOption(contextKey1));
assertEquals(testContext2, mergedContext.getOption(contextKey2));
assertEquals(testContext3, mergedContext.getOption(contextKey3));
assertEquals(mergedContext.getEndpointContext(), endpointContext);
}

private static Map<String, List<String>> createTestExtraHeaders(String... keyValues) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import com.google.api.gax.grpc.testing.FakeServiceImpl;
import com.google.api.gax.grpc.testing.InProcessServer;
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.InvalidArgumentException;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
Expand Down Expand Up @@ -74,10 +76,15 @@ public void setUp() throws Exception {
inprocessServer.start();

channel = InProcessChannelBuilder.forName(serverName).directExecutor().usePlaintext().build();
EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.when(endpointContext.hasValidUniverseDomain(Mockito.any())).thenReturn(true);
Mockito.when(endpointContext.merge(Mockito.any())).thenReturn(endpointContext);
clientContext =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(channel))
.setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT))
.setDefaultCallContext(
GrpcCallContext.of(channel, CallOptions.DEFAULT)
.withEndpointContext(endpointContext))
.build();
}

Expand Down Expand Up @@ -106,11 +113,10 @@ public void createServerStreamingCallableRetryableExceptions() {
GrpcCallableFactory.createServerStreamingCallable(
grpcCallSettings, nonRetryableSettings, clientContext);

ApiCallContext defaultCallContext = clientContext.getDefaultCallContext();
Throwable actualError = null;
try {
nonRetryableCallable
.first()
.call(Color.getDefaultInstance(), clientContext.getDefaultCallContext());
nonRetryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext);
} catch (Throwable e) {
actualError = e;
}
Expand All @@ -134,9 +140,7 @@ public void createServerStreamingCallableRetryableExceptions() {

Throwable actualError2 = null;
try {
retryableCallable
.first()
.call(Color.getDefaultInstance(), clientContext.getDefaultCallContext());
retryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext);
} catch (Throwable e) {
actualError2 = e;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@
package com.google.api.gax.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.verify;

import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.PermissionDeniedException;
import com.google.api.gax.rpc.UnavailableException;
import com.google.auth.Credentials;
import com.google.auth.Retryable;
import com.google.common.collect.ImmutableList;
import com.google.common.truth.Truth;
import com.google.type.Color;
Expand All @@ -51,12 +57,50 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.threeten.bp.Duration;

public class GrpcClientCallsTest {

// Auth Library's GoogleAuthException is package-private. Copy basic functionality for tests
private static class GoogleAuthException extends IOException implements Retryable {

private final boolean isRetryable;

private GoogleAuthException(boolean isRetryable) {
this.isRetryable = isRetryable;
}

@Override
public boolean isRetryable() {
return isRetryable;
}

@Override
public int getRetryCount() {
return 0;
}
}

private GrpcCallContext defaultCallContext;
private EndpointContext endpointContext;
private Credentials credentials;
private Channel mockChannel;

@Before
public void setUp() throws IOException {
credentials = Mockito.mock(Credentials.class);
endpointContext = Mockito.mock(EndpointContext.class);
mockChannel = Mockito.mock(Channel.class);

defaultCallContext = GrpcCallContext.createDefault().withEndpointContext(endpointContext);

Mockito.when(endpointContext.hasValidUniverseDomain(Mockito.any())).thenReturn(true);
}

@Test
public void testAffinity() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
Expand All @@ -78,7 +122,7 @@ public void testAffinity() throws IOException {
ChannelPool.create(
ChannelPoolSettings.staticallySized(2),
new FakeChannelFactory(Arrays.asList(channel0, channel1)));
GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool);
GrpcCallContext context = defaultCallContext.withChannel(pool);

ClientCall<Color, Money> gotCallA =
GrpcClientCalls.newCall(descriptor, context.withChannelAffinity(0));
Expand All @@ -92,7 +136,7 @@ public void testAffinity() throws IOException {
}

@Test
public void testExtraHeaders() {
public void testExtraHeaders() throws IOException {
Metadata emptyHeaders = new Metadata();
final Map<String, List<String>> extraHeaders = new HashMap<>();
extraHeaders.put(
Expand Down Expand Up @@ -128,12 +172,12 @@ public void testExtraHeaders() {
.thenReturn(mockClientCall);

GrpcCallContext context =
GrpcCallContext.createDefault().withChannel(mockChannel).withExtraHeaders(extraHeaders);
defaultCallContext.withChannel(mockChannel).withExtraHeaders(extraHeaders);
GrpcClientCalls.newCall(descriptor, context).start(mockListener, emptyHeaders);
}

@Test
public void testTimeoutToDeadlineConversion() {
public void testTimeoutToDeadlineConversion() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;

@SuppressWarnings("unchecked")
Expand All @@ -152,8 +196,7 @@ public void testTimeoutToDeadlineConversion() {
Duration timeout = Duration.ofSeconds(10);
Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS);

GrpcCallContext context =
GrpcCallContext.createDefault().withChannel(mockChannel).withTimeout(timeout);
GrpcCallContext context = defaultCallContext.withChannel(mockChannel).withTimeout(timeout);

GrpcClientCalls.newCall(descriptor, context).start(mockListener, new Metadata());

Expand All @@ -164,7 +207,7 @@ public void testTimeoutToDeadlineConversion() {
}

@Test
public void testTimeoutAfterDeadline() {
public void testTimeoutAfterDeadline() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;

@SuppressWarnings("unchecked")
Expand All @@ -185,7 +228,7 @@ public void testTimeoutAfterDeadline() {
Duration timeout = Duration.ofSeconds(10);

GrpcCallContext context =
GrpcCallContext.createDefault()
defaultCallContext
.withChannel(mockChannel)
.withCallOptions(CallOptions.DEFAULT.withDeadline(priorDeadline))
.withTimeout(timeout);
Expand All @@ -197,7 +240,7 @@ public void testTimeoutAfterDeadline() {
}

@Test
public void testTimeoutBeforeDeadline() {
public void testTimeoutBeforeDeadline() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;

@SuppressWarnings("unchecked")
Expand All @@ -219,7 +262,7 @@ public void testTimeoutBeforeDeadline() {
Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS);

GrpcCallContext context =
GrpcCallContext.createDefault()
defaultCallContext
.withChannel(mockChannel)
.withCallOptions(CallOptions.DEFAULT.withDeadline(subsequentDeadline))
.withTimeout(timeout);
Expand All @@ -232,4 +275,63 @@ public void testTimeoutBeforeDeadline() {
Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtLeast(minExpectedDeadline);
Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtMost(maxExpectedDeadline);
}

@Test
public void testValidUniverseDomain() throws IOException {
Mockito.when(endpointContext.hasValidUniverseDomain(credentials)).thenReturn(true);
GrpcCallContext context =
GrpcCallContext.createDefault()
.withChannel(mockChannel)
.withCredentials(credentials)
.withEndpointContext(endpointContext);

CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
GrpcClientCalls.newCall(descriptor, context);
Mockito.verify(mockChannel, Mockito.times(1)).newCall(descriptor, callOptions);
}

// This test is when the universe domain does not match
@Test
public void testInvalidUniverseDomain() throws IOException {
Mockito.when(endpointContext.hasValidUniverseDomain(credentials)).thenReturn(false);
GrpcCallContext context =
GrpcCallContext.createDefault()
.withChannel(mockChannel)
.withCredentials(credentials)
.withEndpointContext(endpointContext);

CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
PermissionDeniedException exception =
assertThrows(
PermissionDeniedException.class, () -> GrpcClientCalls.newCall(descriptor, context));
assertThat(exception.getStatusCode().getCode())
.isEqualTo(GrpcStatusCode.Code.PERMISSION_DENIED);
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
}

// This test is when the MDS is unable to return a valid universe domain
@Test
public void testUniverseDomainNotReady_shouldRetry() throws IOException {
Mockito.when(endpointContext.hasValidUniverseDomain(credentials))
.thenThrow(new GoogleAuthException(true));
GrpcCallContext context =
GrpcCallContext.createDefault()
.withChannel(mockChannel)
.withCredentials(credentials)
.withEndpointContext(endpointContext);

CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
UnavailableException exception =
assertThrows(
UnavailableException.class, () -> GrpcClientCalls.newCall(descriptor, context));
assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAVAILABLE);
Truth.assertThat(exception.isRetryable()).isTrue();
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.api.gax.grpc.testing.InProcessServer;
import com.google.api.gax.rpc.ApiException;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStream;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
Expand Down Expand Up @@ -63,6 +64,7 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class GrpcDirectServerStreamingCallableTest {
Expand All @@ -85,11 +87,16 @@ public void setUp() throws InstantiationException, IllegalAccessException, IOExc
inprocessServer = new InProcessServer<>(serviceImpl, serverName);
inprocessServer.start();

EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.when(endpointContext.hasValidUniverseDomain(Mockito.any())).thenReturn(true);

channel = InProcessChannelBuilder.forName(serverName).directExecutor().usePlaintext().build();
clientContext =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(channel))
.setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT))
.setDefaultCallContext(
GrpcCallContext.of(channel, CallOptions.DEFAULT)
.withEndpointContext(endpointContext))
.build();
streamingCallSettings = ServerStreamingCallSettings.<Color, Money>newBuilder().build();
streamingCallable =
Expand Down
Loading
Loading