From c6dfdcf893c3f971eba15026c12db0a960ae81f2 Mon Sep 17 00:00:00 2001 From: Andy Zhao Date: Fri, 28 Jun 2024 07:08:47 -0700 Subject: [PATCH] fix(auth): Update http and grpc transports to support token exchange over mTLS (#10397) This fixes a regression in the new Golang Auth lib. The default http client used for OAuth2 token exchange (for both http and grpc transport) should be configured for mTLS (including switching endpoints) when certificate source is available (whether from default cert source or explicitly configured). --- auth/credentials/detect.go | 3 ++ auth/grpctransport/grpctransport.go | 30 +++++++++-- auth/httptransport/httptransport.go | 9 ++++ auth/internal/transport/cba.go | 6 +-- auth/internal/transport/cba_test.go | 74 ++++++++++++++++++++++++++++ auth/internal/transport/transport.go | 27 ++++++++++ 6 files changed, 143 insertions(+), 6 deletions(-) diff --git a/auth/credentials/detect.go b/auth/credentials/detect.go index c4728da3a41c..cfa0c88f8105 100644 --- a/auth/credentials/detect.go +++ b/auth/credentials/detect.go @@ -37,6 +37,9 @@ const ( googleAuthURL = "https://accounts.google.com/o/oauth2/auth" googleTokenURL = "https://oauth2.googleapis.com/token" + // GoogleMTLSTokenURL is Google's default OAuth2.0 mTLS endpoint. + GoogleMTLSTokenURL = "https://oauth2.mtls.googleapis.com/token" + // Help on default credentials adcSetupURL = "https://cloud.google.com/docs/authentication/external/set-up-adc" ) diff --git a/auth/grpctransport/grpctransport.go b/auth/grpctransport/grpctransport.go index 75bda4c63897..b4ba3dcbfa9a 100644 --- a/auth/grpctransport/grpctransport.go +++ b/auth/grpctransport/grpctransport.go @@ -16,6 +16,7 @@ package grpctransport import ( "context" + "crypto/tls" "errors" "fmt" "net/http" @@ -45,6 +46,11 @@ var ( timeoutDialerOption grpc.DialOption ) +// ClientCertProvider is a function that returns a TLS client certificate to be +// used when opening TLS connections. It follows the same semantics as +// [crypto/tls.Config.GetClientCertificate]. +type ClientCertProvider = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + // Options used to configure a [GRPCClientConnPool] from [Dial]. type Options struct { // DisableTelemetry disables default telemetry (OpenTelemetry). An example @@ -69,6 +75,10 @@ type Options struct { // Credentials used to add Authorization metadata to all requests. If set // DetectOpts are ignored. Credentials *auth.Credentials + // ClientCertProvider is a function that returns a TLS client certificate to + // be used when opening TLS connections. It follows the same semantics as + // crypto/tls.Config.GetClientCertificate. + ClientCertProvider ClientCertProvider // DetectOpts configures settings for detect Application Default // Credentials. DetectOpts *credentials.DetectOptions @@ -125,6 +135,13 @@ func (o *Options) resolveDetectOptions() *credentials.DetectOptions { if len(do.Scopes) == 0 && do.Audience == "" && io != nil { do.Audience = o.InternalOptions.DefaultAudience } + if o.ClientCertProvider != nil { + tlsConfig := &tls.Config{ + GetClientCertificate: o.ClientCertProvider, + } + do.Client = transport.DefaultHTTPClientWithTLS(tlsConfig) + do.TokenURL = credentials.GoogleMTLSTokenURL + } return do } @@ -189,9 +206,10 @@ func Dial(ctx context.Context, secure bool, opts *Options) (GRPCClientConnPool, // return a GRPCClientConnPool if pool == 1 or else a pool of of them if >1 func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, error) { tOpts := &transport.Options{ - Endpoint: opts.Endpoint, - Client: opts.client(), - UniverseDomain: opts.UniverseDomain, + Endpoint: opts.Endpoint, + ClientCertProvider: opts.ClientCertProvider, + Client: opts.client(), + UniverseDomain: opts.UniverseDomain, } if io := opts.InternalOptions; io != nil { tOpts.DefaultEndpointTemplate = io.DefaultEndpointTemplate @@ -213,6 +231,12 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er grpc.WithTransportCredentials(transportCreds), } + // Ensure the token exchange HTTP transport uses the same ClientCertProvider as the GRPC API transport. + opts.ClientCertProvider, err = transport.GetClientCertificateProvider(tOpts) + if err != nil { + return nil, err + } + // Authentication can only be sent when communicating over a secure connection. if !opts.DisableAuthentication { metadata := opts.Metadata diff --git a/auth/httptransport/httptransport.go b/auth/httptransport/httptransport.go index ef09c1b75238..969c8d4d2008 100644 --- a/auth/httptransport/httptransport.go +++ b/auth/httptransport/httptransport.go @@ -116,6 +116,13 @@ func (o *Options) resolveDetectOptions() *detect.DetectOptions { if len(do.Scopes) == 0 && do.Audience == "" && io != nil { do.Audience = o.InternalOptions.DefaultAudience } + if o.ClientCertProvider != nil { + tlsConfig := &tls.Config{ + GetClientCertificate: o.ClientCertProvider, + } + do.Client = transport.DefaultHTTPClientWithTLS(tlsConfig) + do.TokenURL = detect.GoogleMTLSTokenURL + } return do } @@ -195,6 +202,8 @@ func NewClient(opts *Options) (*http.Client, error) { if baseRoundTripper == nil { baseRoundTripper = defaultBaseTransport(clientCertProvider, dialTLSContext) } + // Ensure the token exchange transport uses the same ClientCertProvider as the API transport. + opts.ClientCertProvider = clientCertProvider trans, err := newTransport(baseRoundTripper, opts) if err != nil { return nil, err diff --git a/auth/internal/transport/cba.go b/auth/internal/transport/cba.go index 6ef88311a249..d94e0af08a35 100644 --- a/auth/internal/transport/cba.go +++ b/auth/internal/transport/cba.go @@ -176,7 +176,7 @@ func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, } func getTransportConfig(opts *Options) (*transportConfig, error) { - clientCertSource, err := getClientCertificateSource(opts) + clientCertSource, err := GetClientCertificateProvider(opts) if err != nil { return nil, err } @@ -210,13 +210,13 @@ func getTransportConfig(opts *Options) (*transportConfig, error) { }, nil } -// getClientCertificateSource returns a default client certificate source, if +// GetClientCertificateProvider returns a default client certificate source, if // not provided by the user. // // A nil default source can be returned if the source does not exist. Any exceptions // encountered while initializing the default source will be reported as client // error (ex. corrupt metadata file). -func getClientCertificateSource(opts *Options) (cert.Provider, error) { +func GetClientCertificateProvider(opts *Options) (cert.Provider, error) { if !isClientCertificateEnabled(opts) { return nil, nil } else if opts.ClientCertProvider != nil { diff --git a/auth/internal/transport/cba_test.go b/auth/internal/transport/cba_test.go index 4c1f562327c8..6be9c69e49bd 100644 --- a/auth/internal/transport/cba_test.go +++ b/auth/internal/transport/cba_test.go @@ -21,6 +21,9 @@ import ( "net/http" "testing" "time" + + "cloud.google.com/go/auth/internal" + "cloud.google.com/go/auth/internal/transport/cert" ) const ( @@ -627,3 +630,74 @@ func TestGetGRPCTransportCredsAndEndpoint_UniverseDomain(t *testing.T) { }) } } + +func TestGetClientCertificateProvider(t *testing.T) { + testCases := []struct { + name string + opts *Options + useCertEnvVar string + wantCertProvider cert.Provider + wantErr error + }{ + { + name: "UseCertEnvVar false, Domain is GDU", + opts: &Options{ + UniverseDomain: internal.DefaultUniverseDomain, + ClientCertProvider: fakeClientCertSource, + Endpoint: testRegularEndpoint, + }, + useCertEnvVar: "false", + wantCertProvider: nil, + }, + { + name: "UseCertEnvVar unset, Domain is not GDU", + opts: &Options{ + UniverseDomain: testUniverseDomain, + ClientCertProvider: fakeClientCertSource, + Endpoint: testOverrideEndpoint, + }, + useCertEnvVar: "unset", + wantCertProvider: nil, + }, + { + name: "UseCertEnvVar unset, Domain is GDU", + opts: &Options{ + UniverseDomain: internal.DefaultUniverseDomain, + ClientCertProvider: fakeClientCertSource, + Endpoint: testRegularEndpoint, + }, + useCertEnvVar: "unset", + wantCertProvider: fakeClientCertSource, + }, + { + name: "UseCertEnvVar true, Domain is not GDU", + opts: &Options{ + UniverseDomain: testUniverseDomain, + ClientCertProvider: fakeClientCertSource, + Endpoint: testOverrideEndpoint, + }, + useCertEnvVar: "true", + wantCertProvider: fakeClientCertSource, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.useCertEnvVar != "unset" { + t.Setenv(googleAPIUseCertSource, tc.useCertEnvVar) + } + certProvider, err := GetClientCertificateProvider(tc.opts) + if err != nil { + if err != tc.wantErr { + t.Fatalf("err: %v", err) + } + } else { + want := fmt.Sprintf("%v", tc.wantCertProvider) + got := fmt.Sprintf("%v", certProvider) + if want != got { + t.Errorf("want cert provider: %v, got %v", want, got) + } + } + }) + } +} diff --git a/auth/internal/transport/transport.go b/auth/internal/transport/transport.go index b76386d3c0df..2e2451c57645 100644 --- a/auth/internal/transport/transport.go +++ b/auth/internal/transport/transport.go @@ -17,7 +17,11 @@ package transport import ( + "crypto/tls" "fmt" + "net" + "net/http" + "time" "cloud.google.com/go/auth/credentials" ) @@ -74,3 +78,26 @@ func ValidateUniverseDomain(clientUniverseDomain, credentialsUniverseDomain stri } return nil } + +// DefaultHTTPClientWithTLS constructs an HTTPClient using the provided tlsConfig, to support mTLS. +func DefaultHTTPClientWithTLS(tlsConfig *tls.Config) *http.Client { + trans := baseTransport() + trans.TLSClientConfig = tlsConfig + return &http.Client{Transport: trans} +} + +func baseTransport() *http.Transport { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +}