From c62e5c6665e06611d792c892028c8e5e840f8447 Mon Sep 17 00:00:00 2001 From: Kui Xu Date: Tue, 8 Aug 2023 13:11:02 -0700 Subject: [PATCH] feat: add additional checks before using S2A (#2103) --- internal/cba.go | 40 +++++++++++++++++++++++++++++----------- internal/cba_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/internal/cba.go b/internal/cba.go index cecbb9ba115..6923d3a716e 100644 --- a/internal/cba.go +++ b/internal/cba.go @@ -91,16 +91,10 @@ func getTransportConfig(settings *DialSettings) (*transportConfig, error) { s2aMTLSEndpoint: "", } - // Check the env to determine whether to use S2A. - if !isGoogleS2AEnabled() { + if !shouldUseS2A(clientCertSource, settings) { return &defaultTransportConfig, nil } - // If client cert is found, use that over S2A. - // If MTLS is not enabled for the endpoint, skip S2A. - if clientCertSource != nil || !mtlsEndpointEnabledForS2A() { - return &defaultTransportConfig, nil - } s2aMTLSEndpoint := settings.DefaultMTLSEndpoint // If there is endpoint override, honor it. if settings.Endpoint != "" { @@ -118,10 +112,6 @@ func getTransportConfig(settings *DialSettings) (*transportConfig, error) { }, nil } -func isGoogleS2AEnabled() bool { - return strings.ToLower(os.Getenv(googleAPIUseS2AEnv)) == "true" -} - // getClientCertificateSource returns a default client certificate source, if // not provided by the user. // @@ -275,8 +265,36 @@ func GetHTTPTransportConfigAndEndpoint(settings *DialSettings) (cert.Source, fun return nil, dialTLSContextFunc, config.s2aMTLSEndpoint, nil } +func shouldUseS2A(clientCertSource cert.Source, settings *DialSettings) bool { + // If client cert is found, use that over S2A. + if clientCertSource != nil { + return false + } + // If EXPERIMENTAL_GOOGLE_API_USE_S2A is not set to true, skip S2A. + if !isGoogleS2AEnabled() { + return false + } + // If DefaultMTLSEndpoint is not set, skip S2A. + if settings.DefaultMTLSEndpoint == "" { + return false + } + // If MTLS is not enabled for this endpoint, skip S2A. + if !mtlsEndpointEnabledForS2A() { + return false + } + // If custom HTTP client is provided, skip S2A. + if settings.HTTPClient != nil { + return false + } + return true +} + // mtlsEndpointEnabledForS2A checks if the endpoint is indeed MTLS-enabled, so that we can use S2A for MTLS connection. var mtlsEndpointEnabledForS2A = func() bool { // TODO(xmenxk): determine this via discovery config. return true } + +func isGoogleS2AEnabled() bool { + return strings.ToLower(os.Getenv(googleAPIUseS2AEnv)) == "true" +} diff --git a/internal/cba_test.go b/internal/cba_test.go index 7b271af975e..761d8e7d07e 100644 --- a/internal/cba_test.go +++ b/internal/cba_test.go @@ -6,6 +6,7 @@ package internal import ( "crypto/tls" + "net/http" "os" "testing" "time" @@ -278,6 +279,29 @@ func TestGetHTTPTransportConfigAndEndpoint(t *testing.T) { testOverrideEndpoint, false, }, + { + "no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set", + &DialSettings{ + DefaultMTLSEndpoint: "", + DefaultEndpoint: testRegularEndpoint, + }, + validConfigResp, + func() bool { return true }, + testRegularEndpoint, + true, + }, + { + "no client cert, endpoint is MTLS enabled, S2A address not empty, custom HTTP client", + &DialSettings{ + DefaultMTLSEndpoint: testMTLSEndpoint, + DefaultEndpoint: testRegularEndpoint, + HTTPClient: http.DefaultClient, + }, + validConfigResp, + func() bool { return true }, + testRegularEndpoint, + true, + }, } defer setupTest()()