Skip to content

Commit

Permalink
fix(auth): Update http and grpc transports to support token exchange …
Browse files Browse the repository at this point in the history
…over mTLS (#10397)

This fixes a regression in the new Golang Auth lib. The default http client used for OAuth2 token exchange (for both http and grpc transport) should be configured for mTLS (including switching endpoints) when certificate source is available (whether from default cert source or explicitly configured).
  • Loading branch information
andyrzhao authored Jun 28, 2024
1 parent a1e198a commit c6dfdcf
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 6 deletions.
3 changes: 3 additions & 0 deletions auth/credentials/detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ const (
googleAuthURL = "https://accounts.google.com/o/oauth2/auth"
googleTokenURL = "https://oauth2.googleapis.com/token"

// GoogleMTLSTokenURL is Google's default OAuth2.0 mTLS endpoint.
GoogleMTLSTokenURL = "https://oauth2.mtls.googleapis.com/token"

// Help on default credentials
adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc"
)
Expand Down
30 changes: 27 additions & 3 deletions auth/grpctransport/grpctransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package grpctransport

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -45,6 +46,11 @@ var (
timeoutDialerOption grpc.DialOption
)

// ClientCertProvider is a function that returns a TLS client certificate to be
// used when opening TLS connections. It follows the same semantics as
// [crypto/tls.Config.GetClientCertificate].
type ClientCertProvider = func(*tls.CertificateRequestInfo) (*tls.Certificate, error)

// Options used to configure a [GRPCClientConnPool] from [Dial].
type Options struct {
// DisableTelemetry disables default telemetry (OpenTelemetry). An example
Expand All @@ -69,6 +75,10 @@ type Options struct {
// Credentials used to add Authorization metadata to all requests. If set
// DetectOpts are ignored.
Credentials *auth.Credentials
// ClientCertProvider is a function that returns a TLS client certificate to
// be used when opening TLS connections. It follows the same semantics as
// crypto/tls.Config.GetClientCertificate.
ClientCertProvider ClientCertProvider
// DetectOpts configures settings for detect Application Default
// Credentials.
DetectOpts *credentials.DetectOptions
Expand Down Expand Up @@ -125,6 +135,13 @@ func (o *Options) resolveDetectOptions() *credentials.DetectOptions {
if len(do.Scopes) == 0 && do.Audience == "" && io != nil {
do.Audience = o.InternalOptions.DefaultAudience
}
if o.ClientCertProvider != nil {
tlsConfig := &tls.Config{
GetClientCertificate: o.ClientCertProvider,
}
do.Client = transport.DefaultHTTPClientWithTLS(tlsConfig)
do.TokenURL = credentials.GoogleMTLSTokenURL
}
return do
}

Expand Down Expand Up @@ -189,9 +206,10 @@ func Dial(ctx context.Context, secure bool, opts *Options) (GRPCClientConnPool,
// return a GRPCClientConnPool if pool == 1 or else a pool of of them if >1
func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, error) {
tOpts := &transport.Options{
Endpoint: opts.Endpoint,
Client: opts.client(),
UniverseDomain: opts.UniverseDomain,
Endpoint: opts.Endpoint,
ClientCertProvider: opts.ClientCertProvider,
Client: opts.client(),
UniverseDomain: opts.UniverseDomain,
}
if io := opts.InternalOptions; io != nil {
tOpts.DefaultEndpointTemplate = io.DefaultEndpointTemplate
Expand All @@ -213,6 +231,12 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er
grpc.WithTransportCredentials(transportCreds),
}

// Ensure the token exchange HTTP transport uses the same ClientCertProvider as the GRPC API transport.
opts.ClientCertProvider, err = transport.GetClientCertificateProvider(tOpts)
if err != nil {
return nil, err
}

// Authentication can only be sent when communicating over a secure connection.
if !opts.DisableAuthentication {
metadata := opts.Metadata
Expand Down
9 changes: 9 additions & 0 deletions auth/httptransport/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ func (o *Options) resolveDetectOptions() *detect.DetectOptions {
if len(do.Scopes) == 0 && do.Audience == "" && io != nil {
do.Audience = o.InternalOptions.DefaultAudience
}
if o.ClientCertProvider != nil {
tlsConfig := &tls.Config{
GetClientCertificate: o.ClientCertProvider,
}
do.Client = transport.DefaultHTTPClientWithTLS(tlsConfig)
do.TokenURL = detect.GoogleMTLSTokenURL
}
return do
}

Expand Down Expand Up @@ -195,6 +202,8 @@ func NewClient(opts *Options) (*http.Client, error) {
if baseRoundTripper == nil {
baseRoundTripper = defaultBaseTransport(clientCertProvider, dialTLSContext)
}
// Ensure the token exchange transport uses the same ClientCertProvider as the API transport.
opts.ClientCertProvider = clientCertProvider
trans, err := newTransport(baseRoundTripper, opts)
if err != nil {
return nil, err
Expand Down
6 changes: 3 additions & 3 deletions auth/internal/transport/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context,
}

func getTransportConfig(opts *Options) (*transportConfig, error) {
clientCertSource, err := getClientCertificateSource(opts)
clientCertSource, err := GetClientCertificateProvider(opts)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -210,13 +210,13 @@ func getTransportConfig(opts *Options) (*transportConfig, error) {
}, nil
}

// getClientCertificateSource returns a default client certificate source, if
// GetClientCertificateProvider returns a default client certificate source, if
// not provided by the user.
//
// A nil default source can be returned if the source does not exist. Any exceptions
// encountered while initializing the default source will be reported as client
// error (ex. corrupt metadata file).
func getClientCertificateSource(opts *Options) (cert.Provider, error) {
func GetClientCertificateProvider(opts *Options) (cert.Provider, error) {
if !isClientCertificateEnabled(opts) {
return nil, nil
} else if opts.ClientCertProvider != nil {
Expand Down
74 changes: 74 additions & 0 deletions auth/internal/transport/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"net/http"
"testing"
"time"

"cloud.google.com/go/auth/internal"
"cloud.google.com/go/auth/internal/transport/cert"
)

const (
Expand Down Expand Up @@ -627,3 +630,74 @@ func TestGetGRPCTransportCredsAndEndpoint_UniverseDomain(t *testing.T) {
})
}
}

func TestGetClientCertificateProvider(t *testing.T) {
testCases := []struct {
name string
opts *Options
useCertEnvVar string
wantCertProvider cert.Provider
wantErr error
}{
{
name: "UseCertEnvVar false, Domain is GDU",
opts: &Options{
UniverseDomain: internal.DefaultUniverseDomain,
ClientCertProvider: fakeClientCertSource,
Endpoint: testRegularEndpoint,
},
useCertEnvVar: "false",
wantCertProvider: nil,
},
{
name: "UseCertEnvVar unset, Domain is not GDU",
opts: &Options{
UniverseDomain: testUniverseDomain,
ClientCertProvider: fakeClientCertSource,
Endpoint: testOverrideEndpoint,
},
useCertEnvVar: "unset",
wantCertProvider: nil,
},
{
name: "UseCertEnvVar unset, Domain is GDU",
opts: &Options{
UniverseDomain: internal.DefaultUniverseDomain,
ClientCertProvider: fakeClientCertSource,
Endpoint: testRegularEndpoint,
},
useCertEnvVar: "unset",
wantCertProvider: fakeClientCertSource,
},
{
name: "UseCertEnvVar true, Domain is not GDU",
opts: &Options{
UniverseDomain: testUniverseDomain,
ClientCertProvider: fakeClientCertSource,
Endpoint: testOverrideEndpoint,
},
useCertEnvVar: "true",
wantCertProvider: fakeClientCertSource,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.useCertEnvVar != "unset" {
t.Setenv(googleAPIUseCertSource, tc.useCertEnvVar)
}
certProvider, err := GetClientCertificateProvider(tc.opts)
if err != nil {
if err != tc.wantErr {
t.Fatalf("err: %v", err)
}
} else {
want := fmt.Sprintf("%v", tc.wantCertProvider)
got := fmt.Sprintf("%v", certProvider)
if want != got {
t.Errorf("want cert provider: %v, got %v", want, got)
}
}
})
}
}
27 changes: 27 additions & 0 deletions auth/internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
package transport

import (
"crypto/tls"
"fmt"
"net"
"net/http"
"time"

"cloud.google.com/go/auth/credentials"
)
Expand Down Expand Up @@ -74,3 +78,26 @@ func ValidateUniverseDomain(clientUniverseDomain, credentialsUniverseDomain stri
}
return nil
}

// DefaultHTTPClientWithTLS constructs an HTTPClient using the provided tlsConfig, to support mTLS.
func DefaultHTTPClientWithTLS(tlsConfig *tls.Config) *http.Client {
trans := baseTransport()
trans.TLSClientConfig = tlsConfig
return &http.Client{Transport: trans}
}

func baseTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

0 comments on commit c6dfdcf

Please sign in to comment.