Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
feat: add api key support (#1436)
Browse files Browse the repository at this point in the history
* feat: add api key support

* Update gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java

Co-authored-by: Chanseok Oh <[email protected]>

* Update gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java

Co-authored-by: Chanseok Oh <[email protected]>

* update

* update

* Update gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java

Co-authored-by: Chanseok Oh <[email protected]>

* Update gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java

Co-authored-by: Chanseok Oh <[email protected]>

Co-authored-by: Chanseok Oh <[email protected]>
  • Loading branch information
arithmetic1728 and chanseokoh authored Jan 20, 2022
1 parent f631a25 commit 5081ec6
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 11 deletions.
58 changes: 48 additions & 10 deletions gax/src/main/java/com/google/api/gax/rpc/ClientContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@
import com.google.api.gax.core.BackgroundResource;
import com.google.api.gax.core.ExecutorAsBackgroundResource;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.rpc.internal.EnvironmentProvider;
import com.google.api.gax.rpc.internal.QuotaProjectIdHidingCredentials;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.api.gax.tracing.ApiTracerFactory;
import com.google.api.gax.tracing.BaseApiTracerFactory;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
Expand All @@ -65,6 +67,7 @@
@AutoValue
public abstract class ClientContext {
private static final String QUOTA_PROJECT_ID_HEADER_KEY = "x-goog-user-project";
private static final String API_KEY_HEADER_KEY = "x-goog-api-key";

/**
* The objects that need to be closed in order to clean up the resources created in the process of
Expand Down Expand Up @@ -159,6 +162,32 @@ static String getEndpoint(
return endpoint;
}

/**
* Retrieves the API key value and add it to the headers if API key exists. It first tries to
* retrieve the value from the stub settings. If not found, it then tries the load the
* GOOGLE_API_KEY environment variable. An IOException will be thrown if both GOOGLE_API_KEY and
* GOOGLE_APPLICATION_CREDENTIALS environment variables are set.
*/
@VisibleForTesting
static void addApiKeyToHeaders(
StubSettings settings, EnvironmentProvider environmentProvider, Map<String, String> headers)
throws IOException {
if (settings.getApiKey() != null) {
headers.put(API_KEY_HEADER_KEY, settings.getApiKey());
return;
}

String apiKey = environmentProvider.getenv("GOOGLE_API_KEY");
String applicationCredentials = environmentProvider.getenv("GOOGLE_APPLICATION_CREDENTIALS");
if (apiKey != null && applicationCredentials != null) {
throw new IOException(
"Environment variables GOOGLE_API_KEY and GOOGLE_APPLICATION_CREDENTIALS are mutually exclusive");
}
if (apiKey != null) {
headers.put(API_KEY_HEADER_KEY, apiKey);
}
}

/**
* Instantiates the executor, credentials, and transport context based on the given client
* settings.
Expand All @@ -169,14 +198,21 @@ public static ClientContext create(StubSettings settings) throws IOException {
ExecutorProvider backgroundExecutorProvider = settings.getBackgroundExecutorProvider();
final ScheduledExecutorService backgroundExecutor = backgroundExecutorProvider.getExecutor();

Credentials credentials = settings.getCredentialsProvider().getCredentials();
Credentials credentials = null;
Map<String, String> headers = getHeadersFromSettingsAndEnvironment(settings, System::getenv);

if (settings.getQuotaProjectId() != null) {
// If the quotaProjectId is set, wrap original credentials with correct quotaProjectId as
// QuotaProjectIdHidingCredentials.
// Ensure that a custom set quota project id takes priority over one detected by credentials.
// Avoid the backend receiving possibly conflict values of quotaProjectId
credentials = new QuotaProjectIdHidingCredentials(credentials);
boolean hasApiKey = headers.containsKey(API_KEY_HEADER_KEY);
if (!hasApiKey) {
credentials = settings.getCredentialsProvider().getCredentials();

if (settings.getQuotaProjectId() != null) {
// If the quotaProjectId is set, wrap original credentials with correct quotaProjectId as
// QuotaProjectIdHidingCredentials.
// Ensure that a custom set quota project id takes priority over one detected by
// credentials.
// Avoid the backend receiving possibly conflict values of quotaProjectId
credentials = new QuotaProjectIdHidingCredentials(credentials);
}
}

TransportChannelProvider transportChannelProvider = settings.getTransportChannelProvider();
Expand All @@ -186,11 +222,11 @@ public static ClientContext create(StubSettings settings) throws IOException {
if (transportChannelProvider.needsExecutor() && settings.getExecutorProvider() != null) {
transportChannelProvider = transportChannelProvider.withExecutor(backgroundExecutor);
}
Map<String, String> headers = getHeadersFromSettings(settings);

if (transportChannelProvider.needsHeaders()) {
transportChannelProvider = transportChannelProvider.withHeaders(headers);
}
if (transportChannelProvider.needsCredentials() && credentials != null) {
if (!hasApiKey && transportChannelProvider.needsCredentials()) {
transportChannelProvider = transportChannelProvider.withCredentials(credentials);
}
String endpoint =
Expand Down Expand Up @@ -260,7 +296,8 @@ public static ClientContext create(StubSettings settings) throws IOException {
* Getting a header map from HeaderProvider and InternalHeaderProvider from settings with Quota
* Project Id.
*/
private static Map<String, String> getHeadersFromSettings(StubSettings settings) {
private static Map<String, String> getHeadersFromSettingsAndEnvironment(
StubSettings settings, EnvironmentProvider environmentProvider) throws IOException {
// Resolve conflicts when merging headers from multiple sources
Map<String, String> userHeaders = settings.getHeaderProvider().getHeaders();
Map<String, String> internalHeaders = settings.getInternalHeaderProvider().getHeaders();
Expand All @@ -286,6 +323,7 @@ private static Map<String, String> getHeadersFromSettings(StubSettings settings)
effectiveHeaders.putAll(internalHeaders);
effectiveHeaders.putAll(userHeaders);
effectiveHeaders.putAll(conflictResolution);
addApiKeyToHeaders(settings, environmentProvider, effectiveHeaders);

return ImmutableMap.copyOf(effectiveHeaders);
}
Expand Down
20 changes: 20 additions & 0 deletions gax/src/main/java/com/google/api/gax/rpc/StubSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public abstract class StubSettings<SettingsT extends StubSettings<SettingsT>> {
private final String endpoint;
private final String mtlsEndpoint;
private final String quotaProjectId;
private final String apiKey;
@Nullable private final WatchdogProvider streamWatchdogProvider;
@Nonnull private final Duration streamWatchdogCheckInterval;
@Nonnull private final ApiTracerFactory tracerFactory;
Expand All @@ -99,6 +100,7 @@ protected StubSettings(Builder builder) {
this.mtlsEndpoint = builder.mtlsEndpoint;
this.switchToMtlsEndpointAllowed = builder.switchToMtlsEndpointAllowed;
this.quotaProjectId = builder.quotaProjectId;
this.apiKey = builder.apiKey;
this.streamWatchdogProvider = builder.streamWatchdogProvider;
this.streamWatchdogCheckInterval = builder.streamWatchdogCheckInterval;
this.tracerFactory = builder.tracerFactory;
Expand Down Expand Up @@ -154,6 +156,10 @@ public final String getQuotaProjectId() {
return quotaProjectId;
}

public final String getApiKey() {
return apiKey;
}

@BetaApi("The surface for streaming is not stable yet and may change in the future.")
@Nullable
public final WatchdogProvider getStreamWatchdogProvider() {
Expand Down Expand Up @@ -189,6 +195,7 @@ public String toString() {
.add("mtlsEndpoint", mtlsEndpoint)
.add("switchToMtlsEndpointAllowed", switchToMtlsEndpointAllowed)
.add("quotaProjectId", quotaProjectId)
.add("apiKey", apiKey)
.add("streamWatchdogProvider", streamWatchdogProvider)
.add("streamWatchdogCheckInterval", streamWatchdogCheckInterval)
.add("tracerFactory", tracerFactory)
Expand All @@ -209,6 +216,7 @@ public abstract static class Builder<
private String endpoint;
private String mtlsEndpoint;
private String quotaProjectId;
private String apiKey;
@Nullable private WatchdogProvider streamWatchdogProvider;
@Nonnull private Duration streamWatchdogCheckInterval;
@Nonnull private ApiTracerFactory tracerFactory;
Expand All @@ -234,6 +242,7 @@ protected Builder(StubSettings settings) {
this.mtlsEndpoint = settings.mtlsEndpoint;
this.switchToMtlsEndpointAllowed = settings.switchToMtlsEndpointAllowed;
this.quotaProjectId = settings.quotaProjectId;
this.apiKey = settings.apiKey;
this.streamWatchdogProvider = settings.streamWatchdogProvider;
this.streamWatchdogCheckInterval = settings.streamWatchdogCheckInterval;
this.tracerFactory = settings.tracerFactory;
Expand All @@ -258,6 +267,7 @@ private static String getQuotaProjectIdFromClientContext(ClientContext clientCon
}

protected Builder(ClientContext clientContext) {
this.apiKey = null;
if (clientContext == null) {
this.backgroundExecutorProvider = InstantiatingExecutorProvider.newBuilder().build();
this.transportChannelProvider = null;
Expand Down Expand Up @@ -432,6 +442,11 @@ public B setQuotaProjectId(String quotaProjectId) {
return self();
}

public B setApiKey(String apiKey) {
this.apiKey = apiKey;
return self();
}

/**
* Sets how often the {@link Watchdog} will check ongoing streaming RPCs. Defaults to 10 secs.
* Use {@link Duration#ZERO} to disable.
Expand Down Expand Up @@ -513,6 +528,10 @@ public String getQuotaProjectId() {
return quotaProjectId;
}

public String getApiKey() {
return apiKey;
}

@BetaApi("The surface for streaming is not stable yet and may change in the future.")
@Nonnull
public Duration getStreamWatchdogCheckInterval() {
Expand Down Expand Up @@ -549,6 +568,7 @@ public String toString() {
.add("mtlsEndpoint", mtlsEndpoint)
.add("switchToMtlsEndpointAllowed", switchToMtlsEndpointAllowed)
.add("quotaProjectId", quotaProjectId)
.add("apiKey", apiKey)
.add("streamWatchdogProvider", streamWatchdogProvider)
.add("streamWatchdogCheckInterval", streamWatchdogCheckInterval)
.add("tracerFactory", tracerFactory)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2021 Google LLC
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google LLC nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.google.api.gax.rpc.internal;

import com.google.api.core.InternalExtensionOnly;

/** Provides an interface to provide the environment variable values. */
@InternalExtensionOnly
public interface EnvironmentProvider {
/** Returns the environment variable value. */
String getenv(String name);
}
74 changes: 73 additions & 1 deletion gax/src/test/java/com/google/api/gax/rpc/ClientContextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

Expand All @@ -41,6 +42,7 @@
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.core.FixedExecutorProvider;
import com.google.api.gax.rpc.internal.EnvironmentProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider.MtlsEndpointUsagePolicy;
import com.google.api.gax.rpc.testing.FakeChannel;
Expand All @@ -54,6 +56,7 @@
import com.google.common.truth.Truth;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
Expand Down Expand Up @@ -176,7 +179,7 @@ public TransportChannelProvider withPoolSize(int size) {

@Override
public TransportChannel getTransportChannel() throws IOException {
if (needsCredentials()) {
if (needsCredentials() && !headers.containsKey("x-goog-api-key")) {
throw new IllegalStateException("Needs Credentials");
}
transport.setExecutor(executor);
Expand Down Expand Up @@ -769,4 +772,73 @@ public void testExecutorSettings() throws Exception {
transportChannel = (FakeTransportChannel) context.getTransportChannel();
assertThat(transportChannel.getExecutor()).isSameInstanceAs(executorProvider.getExecutor());
}

@Test
public void testAddApiKeyToHeadersFromStubSettings() throws IOException {
StubSettings settings = new FakeStubSettings.Builder().setApiKey("stub-setting-key").build();
EnvironmentProvider environmentProvider =
name -> name.equals("GOOGLE_API_KEY") ? "env-key" : null;
Map<String, String> headers = new HashMap<>();
ClientContext.addApiKeyToHeaders(settings, environmentProvider, headers);
assertThat(headers).containsEntry("x-goog-api-key", "stub-setting-key");
}

@Test
public void testAddApiKeyToHeadersFromEnvironmentProvider() throws IOException {
StubSettings settings = new FakeStubSettings.Builder().build();
EnvironmentProvider environmentProvider =
name -> name.equals("GOOGLE_API_KEY") ? "env-key" : null;
Map<String, String> headers = new HashMap<>();
ClientContext.addApiKeyToHeaders(settings, environmentProvider, headers);
assertThat(headers).containsEntry("x-goog-api-key", "env-key");
}

@Test
public void testAddApiKeyToHeadersNoApiKey() throws IOException {
StubSettings settings = new FakeStubSettings.Builder().build();
EnvironmentProvider environmentProvider = name -> null;
Map<String, String> headers = new HashMap<>();
ClientContext.addApiKeyToHeaders(settings, environmentProvider, headers);
assertThat(headers).doesNotContainKey("x-goog-api-key");
}

@Test
public void testAddApiKeyToHeadersThrows() throws IOException {
StubSettings settings = new FakeStubSettings.Builder().build();
EnvironmentProvider environmentProvider =
name -> name.equals("GOOGLE_API_KEY") ? "env-key" : "/path/to/adc/json";
Map<String, String> headers = new HashMap<>();
Exception ex =
assertThrows(
IOException.class,
() -> ClientContext.addApiKeyToHeaders(settings, environmentProvider, headers));
assertThat(ex)
.hasMessageThat()
.contains(
"Environment variables GOOGLE_API_KEY and GOOGLE_APPLICATION_CREDENTIALS are mutually exclusive");
}

@Test
public void testApiKey() throws IOException {
FakeStubSettings.Builder builder = new FakeStubSettings.Builder();

FakeTransportChannel transportChannel = FakeTransportChannel.create(new FakeChannel());
FakeTransportProvider transportProvider =
new FakeTransportProvider(transportChannel, null, true, null, null);
builder.setTransportChannelProvider(transportProvider);

HeaderProvider headerProvider = Mockito.mock(HeaderProvider.class);
Mockito.when(headerProvider.getHeaders()).thenReturn(ImmutableMap.of());
builder.setHeaderProvider(headerProvider);

// Set API key.
builder.setApiKey("key");

ClientContext context = ClientContext.create(builder.build());

// Check API key is in the transport channel's header.
List<BackgroundResource> resources = context.getBackgroundResources();
FakeTransportChannel fakeTransportChannel = (FakeTransportChannel) resources.get(0);
assertThat(fakeTransportChannel.getHeaders()).containsEntry("x-goog-api-key", "key");
}
}

0 comments on commit 5081ec6

Please sign in to comment.